Skip to content

Commit 7ffb6fb

Browse files
committed
Implemented workaround for memory and core utilization regressions
1 parent a8b12f3 commit 7ffb6fb

File tree

1 file changed

+49
-17
lines changed

1 file changed

+49
-17
lines changed

src/mdio/segy/_workers.py

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import numpy as np
1111
from segy import SegyFile
1212

13+
from mdio.schemas import ScalarType
14+
1315
if TYPE_CHECKING:
1416
from segy.arrays import HeaderArray
1517
from segy.config import SegySettings
@@ -19,7 +21,11 @@
1921

2022
from mdio.core.storage_location import StorageLocation
2123

24+
from xarray import Variable
25+
from zarr.core.config import config as zarr_config
26+
2227
from mdio.constants import UINT32_MAX
28+
from mdio.schemas.v1.dataset_serializer import _get_fill_value
2329
from mdio.schemas.v1.stats import CenteredBinHistogram
2430
from mdio.schemas.v1.stats import SummaryStatistics
2531

@@ -97,6 +103,12 @@ def trace_worker( # noqa: PLR0913
97103
Returns:
98104
SummaryStatistics object containing statistics about the written traces.
99105
"""
106+
# Setting the zarr config to 1 thread to ensure we honor the `MDIO__IMPORT__MAX_WORKERS`
107+
# environment variable.
108+
# Since the release of the Zarr 3 engine, it will default to many threads.
109+
# This can cause resource contention and unpredictable memory consumption.
110+
zarr_config.set({"threading.max_workers": 1})
111+
100112
if not dataset.trace_mask.any():
101113
return None
102114

@@ -109,33 +121,53 @@ def trace_worker( # noqa: PLR0913
109121
live_trace_indexes = grid_map[not_null].tolist()
110122
traces = segy_file.trace[live_trace_indexes]
111123

124+
header_key = "headers"
125+
112126
# Get subset of the dataset that has not yet been saved
113127
# The headers might not be present in the dataset
114-
# TODO(Dmitriy Repin): Check, should we overwrite the 'dataset' instead to save the memory
115-
# https://github.com/TGSAI/mdio-python/issues/584
116-
if "headers" in dataset.data_vars:
117-
ds_to_write = dataset[[data_variable_name, "headers"]]
118-
ds_to_write = ds_to_write.reset_coords()
119-
120-
ds_to_write["headers"].data[not_null] = traces.header
121-
ds_to_write["headers"].data[~not_null] = 0
122-
else:
123-
ds_to_write = dataset[[data_variable_name]]
124-
ds_to_write = ds_to_write.reset_coords()
125-
126-
ds_to_write[data_variable_name].data[not_null] = traces.sample
128+
worker_variables = [data_variable_name]
129+
if header_key in dataset.data_vars: # Keeping the `if` here to allow for more worker configurations
130+
worker_variables.append(header_key)
131+
132+
ds_to_write = dataset[worker_variables]
133+
134+
if header_key in worker_variables:
135+
# Create temporary array for headers with the correct shape
136+
# TODO(BrianMichell): Implement this better so that we can enable fill values without changing the code. #noqa: TD003
137+
tmp_headers = np.zeros_like(dataset[header_key])
138+
tmp_headers[not_null] = traces.header
139+
# Create a new Variable object to avoid copying the temporary array
140+
# The ideal solution is to use `ds_to_write[header_key][:] = tmp_headers`
141+
# but Xarray appears to be copying memory instead of doing direct assignment.
142+
# TODO(BrianMichell): #614 Look into this further.
143+
ds_to_write[header_key] = Variable(
144+
ds_to_write[header_key].dims,
145+
tmp_headers,
146+
attrs=ds_to_write[header_key].attrs,
147+
encoding=ds_to_write[header_key].encoding, # Not strictly necessary, but safer than not doing it.
148+
)
149+
150+
data_variable = ds_to_write[data_variable_name]
151+
fill_value = _get_fill_value(ScalarType(data_variable.dtype.name))
152+
tmp_samples = np.full_like(data_variable, fill_value=fill_value)
153+
tmp_samples[not_null] = traces.sample
154+
# Create a new Variable object to avoid copying the temporary array
155+
# TODO(BrianMichell): #614 Look into this further.
156+
ds_to_write[data_variable_name] = Variable(
157+
ds_to_write[data_variable_name].dims,
158+
tmp_samples,
159+
attrs=ds_to_write[data_variable_name].attrs,
160+
encoding=ds_to_write[data_variable_name].encoding, # Not strictly necessary, but safer than not doing it.
161+
)
127162

128-
out_path = output_location.uri
129-
ds_to_write.to_zarr(out_path, region=region, mode="r+", write_empty_chunks=False, zarr_format=2)
163+
ds_to_write.to_zarr(output_location.uri, region=region, mode="r+", write_empty_chunks=False, zarr_format=2)
130164

131165
histogram = CenteredBinHistogram(bin_centers=[], counts=[])
132166
return SummaryStatistics(
133167
count=traces.sample.size,
134168
min=traces.sample.min(),
135169
max=traces.sample.max(),
136170
sum=traces.sample.sum(),
137-
# TODO(Altay): Look at how to do the sum squares statistic correctly
138-
# https://github.com/TGSAI/mdio-python/issues/581
139171
sum_squares=(traces.sample**2).sum(),
140172
histogram=histogram,
141173
)

0 commit comments

Comments
 (0)