 import numpy as np
 from segy import SegyFile

+from mdio.schemas import ScalarType
+
 if TYPE_CHECKING:
     from segy.arrays import HeaderArray
     from segy.config import SegySettings

     from mdio.core.storage_location import StorageLocation

+from xarray import Variable
+from zarr.core.config import config as zarr_config
+
 from mdio.constants import UINT32_MAX
+from mdio.schemas.v1.dataset_serializer import _get_fill_value
 from mdio.schemas.v1.stats import CenteredBinHistogram
 from mdio.schemas.v1.stats import SummaryStatistics

@@ -97,6 +103,12 @@ def trace_worker( # noqa: PLR0913
     Returns:
         SummaryStatistics object containing statistics about the written traces.
     """
+    # Set the zarr config to a single thread so that we honor the `MDIO__IMPORT__MAX_WORKERS`
+    # environment variable. Since the release of the Zarr 3 engine, zarr defaults to many
+    # threads, which can cause resource contention and unpredictable memory consumption.
+    zarr_config.set({"threading.max_workers": 1})
+
     if not dataset.trace_mask.any():
         return None

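Note (not part of the commit): `zarr_config.set` as used above changes the setting for the whole worker process. If zarr's donfig-backed config supports the same context-manager form as `dask.config.set`, the cap could instead be scoped to the write itself. A sketch, where `write_region_single_threaded` is a hypothetical helper:

from zarr.core.config import config as zarr_config


def write_region_single_threaded(ds_to_write, uri, region):
    # Hypothetical helper: the thread cap is reverted when the block exits.
    with zarr_config.set({"threading.max_workers": 1}):
        ds_to_write.to_zarr(uri, region=region, mode="r+", write_empty_chunks=False, zarr_format=2)
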
@@ -109,33 +121,53 @@ def trace_worker( # noqa: PLR0913
     live_trace_indexes = grid_map[not_null].tolist()
     traces = segy_file.trace[live_trace_indexes]

+    header_key = "headers"
+
     # Get subset of the dataset that has not yet been saved
     # The headers might not be present in the dataset
-    # TODO(Dmitriy Repin): Check, should we overwrite the 'dataset' instead to save the memory
-    # https://github.com/TGSAI/mdio-python/issues/584
-    if "headers" in dataset.data_vars:
-        ds_to_write = dataset[[data_variable_name, "headers"]]
-        ds_to_write = ds_to_write.reset_coords()
-
-        ds_to_write["headers"].data[not_null] = traces.header
-        ds_to_write["headers"].data[~not_null] = 0
-    else:
-        ds_to_write = dataset[[data_variable_name]]
-        ds_to_write = ds_to_write.reset_coords()
-
-    ds_to_write[data_variable_name].data[not_null] = traces.sample
+    worker_variables = [data_variable_name]
+    if header_key in dataset.data_vars:  # Keeping the `if` here to allow for more worker configurations
+        worker_variables.append(header_key)
+
+    ds_to_write = dataset[worker_variables]
+
+    if header_key in worker_variables:
+        # Create a temporary array for the headers with the correct shape
+        # TODO(BrianMichell): Implement this better so that we can enable fill values without changing the code.  # noqa: TD003
+        tmp_headers = np.zeros_like(dataset[header_key])
+        tmp_headers[not_null] = traces.header
+        # Create a new Variable object to avoid copying the temporary array.
+        # The ideal solution would be `ds_to_write[header_key][:] = tmp_headers`,
+        # but Xarray appears to copy the data instead of assigning in place.
+        # TODO(BrianMichell): #614 Look into this further.
+        ds_to_write[header_key] = Variable(
+            ds_to_write[header_key].dims,
+            tmp_headers,
+            attrs=ds_to_write[header_key].attrs,
+            encoding=ds_to_write[header_key].encoding,  # Not strictly necessary, but safer than omitting it.
+        )
+
+    data_variable = ds_to_write[data_variable_name]
+    fill_value = _get_fill_value(ScalarType(data_variable.dtype.name))
+    tmp_samples = np.full_like(data_variable, fill_value=fill_value)
+    tmp_samples[not_null] = traces.sample
+    # Create a new Variable object to avoid copying the temporary array.
+    # TODO(BrianMichell): #614 Look into this further.
+    ds_to_write[data_variable_name] = Variable(
+        ds_to_write[data_variable_name].dims,
+        tmp_samples,
+        attrs=ds_to_write[data_variable_name].attrs,
+        encoding=ds_to_write[data_variable_name].encoding,  # Not strictly necessary, but safer than omitting it.
+    )

-    out_path = output_location.uri
-    ds_to_write.to_zarr(out_path, region=region, mode="r+", write_empty_chunks=False, zarr_format=2)
+    ds_to_write.to_zarr(output_location.uri, region=region, mode="r+", write_empty_chunks=False, zarr_format=2)

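The `Variable` wrapping above depends on xarray keeping a reference to the buffer that is handed in rather than copying it. A minimal standalone sketch of that pattern, using toy dimension names and NaN in place of the commit's `_get_fill_value` result:

import numpy as np
import xarray as xr

ds = xr.Dataset({"amplitude": (("inline", "crossline"), np.zeros((4, 4), dtype="float32"))})
mask = np.zeros((4, 4), dtype=bool)
mask[:2, :2] = True  # pretend only these traces are live

buf = np.full(ds["amplitude"].shape, np.nan, dtype="float32")  # full-shape scratch buffer
buf[mask] = 1.0  # only live positions receive real samples
ds["amplitude"] = xr.Variable(ds["amplitude"].dims, buf, attrs=ds["amplitude"].attrs)
print(np.shares_memory(ds["amplitude"].values, buf))  # expected True: assignment did not copy
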
     histogram = CenteredBinHistogram(bin_centers=[], counts=[])
     return SummaryStatistics(
         count=traces.sample.size,
         min=traces.sample.min(),
         max=traces.sample.max(),
         sum=traces.sample.sum(),
-        # TODO(Altay): Look at how to do the sum squares statistic correctly
-        # https://github.com/TGSAI/mdio-python/issues/581
         sum_squares=(traces.sample**2).sum(),
         histogram=histogram,
     )
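With the TODO removed, the worker reports `sum_squares` directly alongside `count`, `sum`, `min`, and `max`, which is exactly what is needed to merge per-worker results into global statistics. A sketch of the usual way such partial statistics combine (illustrative only, not mdio's aggregation code):

import math


def combine(parts: list[dict]) -> dict:
    # Merge per-worker SummaryStatistics-style dictionaries into global figures.
    count = sum(p["count"] for p in parts)
    total = sum(p["sum"] for p in parts)
    total_sq = sum(p["sum_squares"] for p in parts)
    mean = total / count
    variance = total_sq / count - mean**2  # E[x^2] - E[x]^2
    return {
        "count": count,
        "min": min(p["min"] for p in parts),
        "max": max(p["max"] for p in parts),
        "mean": mean,
        "std": math.sqrt(max(variance, 0.0)),  # clamp tiny negative round-off
    }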