Skip to content

Commit 2511d30

Browse files
committed
Fix memory usage regression
1 parent 3209ed5 commit 2511d30

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

src/mdio/segy/_workers.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
from mdio.core.storage_location import StorageLocation
2323

24+
from xarray import Variable
25+
from zarr.core.config import config as zarr_config
2426

2527
from mdio.constants import UINT32_MAX
2628
from mdio.schemas.v1.dataset_serializer import _get_fill_value
@@ -101,6 +103,12 @@ def trace_worker( # noqa: PLR0913
101103
Returns:
102104
SummaryStatistics object containing statistics about the written traces.
103105
"""
106+
# Setting the zarr config to 1 thread to ensure we honor the `MDIO__IMPORT__MAX_WORKERS`
107+
# environment variable.
108+
# Since the release of the Zarr 3 engine, Zarr defaults to using many threads.
109+
# This can cause resource contention and unpredictable memory consumption.
110+
zarr_config.set({"threading.max_workers": 1})
111+
104112
if not dataset.trace_mask.any():
105113
return None
106114

@@ -128,13 +136,29 @@ def trace_worker( # noqa: PLR0913
128136
# TODO(BrianMichell): Implement this better so that we can enable fill values without changing the code. #noqa: TD003
129137
tmp_headers = np.zeros_like(dataset[header_key])
130138
tmp_headers[not_null] = traces.header
131-
ds_to_write[header_key][:] = tmp_headers
139+
# Create a new Variable object to avoid copying the temporary array
140+
# The ideal solution is to use `ds_to_write[header_key][:] = tmp_headers`
141+
# but Xarray appears to be copying memory instead of doing direct assignment.
142+
# TODO(BrianMichell): #614 Look into this further.
143+
ds_to_write[header_key] = Variable(
144+
ds_to_write[header_key].dims,
145+
tmp_headers,
146+
attrs=ds_to_write[header_key].attrs,
147+
encoding=ds_to_write[header_key].encoding, # Not strictly necessary, but safer than not doing it.
148+
)
132149

133150
data_variable = ds_to_write[data_variable_name]
134151
fill_value = _get_fill_value(ScalarType(data_variable.dtype.name))
135152
tmp_samples = np.full_like(data_variable, fill_value=fill_value)
136153
tmp_samples[not_null] = traces.sample
137-
ds_to_write[data_variable_name][:] = tmp_samples
154+
# Create a new Variable object to avoid copying the temporary array
155+
# TODO(BrianMichell): #614 Look into this further.
156+
ds_to_write[data_variable_name] = Variable(
157+
ds_to_write[data_variable_name].dims,
158+
tmp_samples,
159+
attrs=ds_to_write[data_variable_name].attrs,
160+
encoding=ds_to_write[data_variable_name].encoding, # Not strictly necessary, but safer than not doing it.
161+
)
138162

139163
ds_to_write.to_zarr(output_location.uri, region=region, mode="r+", write_empty_chunks=False, zarr_format=2)
140164

0 commit comments

Comments
 (0)