1212from segy import SegyFile
1313from segy .arrays import HeaderArray
1414
15- from mdio .api .io import to_mdio
16- from mdio .builder .schemas .dtype import ScalarType
15+ from mdio .api .io import _normalize_storage_options
1716from mdio .segy ._raw_trace_wrapper import SegyFileRawTraceWrapper
1817from mdio .segy .scalar import _get_coordinate_scalar
1918
2221 from segy .config import SegyHeaderOverrides
2322 from segy .schema import SegySpec
2423 from upath import UPath
25- from xarray import Dataset as xr_Dataset
2624 from zarr import Array as zarr_Array
2725
28- from xarray import Variable
26+ from zarr import open_group as zarr_open_group
2927from zarr .core .config import config as zarr_config
3028
3129from mdio .builder .schemas .v1 .stats import CenteredBinHistogram
3230from mdio .builder .schemas .v1 .stats import SummaryStatistics
33- from mdio .builder .xarray_builder import _get_fill_value
3431from mdio .constants import fill_value_map
3532
3633if TYPE_CHECKING :
@@ -100,7 +97,6 @@ def trace_worker( # noqa: PLR0913
10097 data_variable_name : str ,
10198 region : dict [str , slice ],
10299 grid_map : zarr_Array ,
103- dataset : xr_Dataset ,
104100) -> SummaryStatistics | None :
105101 """Writes a subset of traces from a region of the dataset of Zarr file.
106102
@@ -112,7 +108,6 @@ def trace_worker( # noqa: PLR0913
112108 data_variable_name: Name of the data variable to write.
113109 region: Region of the dataset to write to.
114110 grid_map: Zarr array mapping live traces to their positions in the dataset.
115- dataset: Xarray dataset containing the data to write.
116111
117112 Returns:
118113 SummaryStatistics object containing statistics about the written traces.
@@ -135,16 +130,15 @@ def trace_worker( # noqa: PLR0913
135130
136131 live_trace_indexes = local_grid_map [not_null ].tolist ()
137132
133+ # Open the zarr group to write directly
134+ storage_options = _normalize_storage_options (output_path )
135+ zarr_group = zarr_open_group (output_path .as_posix (), mode = "r+" , storage_options = storage_options )
136+
138137 header_key = "headers"
139138 raw_header_key = "raw_headers"
140139
141- # Get subset of the dataset that has not yet been saved
142- # The headers might not be present in the dataset
143- worker_variables = [data_variable_name ]
144- if header_key in dataset .data_vars : # Keeping the `if` here to allow for more worker configurations
145- worker_variables .append (header_key )
146- if raw_header_key in dataset .data_vars :
147- worker_variables .append (raw_header_key )
140+ # Check which variables exist in the zarr store
141+ available_arrays = list (zarr_group .array_keys ())
148142
149143 # traces = segy_file.trace[live_trace_indexes]
150144 # Raw headers are not intended to remain as a feature of the SEGY ingestion.
@@ -153,51 +147,33 @@ def trace_worker( # noqa: PLR0913
153147 # NOTE: The `raw_header_key` code block should be removed in full as it will become dead code.
154148 traces = SegyFileRawTraceWrapper (segy_file , live_trace_indexes )
155149
156- ds_to_write = dataset [worker_variables ]
150+ # Compute slices once (headers exclude sample dimension)
151+ header_region_slices = region_slices [:- 1 ] # Exclude sample dimension
152+
153+ full_shape = tuple (s .stop - s .start for s in region_slices )
154+ header_shape = tuple (s .stop - s .start for s in header_region_slices )
157155
158- if raw_header_key in worker_variables :
159- tmp_raw_headers = np .zeros_like (dataset [raw_header_key ])
156+ # Write raw headers if they exist
157+ # Headers only have spatial dimensions (no sample dimension)
158+ if raw_header_key in available_arrays :
159+ raw_header_array = zarr_group [raw_header_key ]
160+ tmp_raw_headers = np .full (header_shape , raw_header_array .fill_value )
160161 tmp_raw_headers [not_null ] = traces .raw_header
162+ raw_header_array [header_region_slices ] = tmp_raw_headers
161163
162- ds_to_write [raw_header_key ] = Variable (
163- ds_to_write [raw_header_key ].dims ,
164- tmp_raw_headers ,
165- attrs = ds_to_write [raw_header_key ].attrs ,
166- encoding = ds_to_write [raw_header_key ].encoding , # Not strictly necessary, but safer than not doing it.
167- )
168-
169- if header_key in worker_variables :
170- # TODO(BrianMichell): Implement this better so that we can enable fill values without changing the code
171- # https://github.com/TGSAI/mdio-python/issues/584
172- tmp_headers = np .zeros_like (dataset [header_key ])
164+ # Write headers if they exist
165+ # Headers only have spatial dimensions (no sample dimension)
166+ if header_key in available_arrays :
167+ header_array = zarr_group [header_key ]
168+ tmp_headers = np .full (header_shape , header_array .fill_value )
173169 tmp_headers [not_null ] = traces .header
174- # Create a new Variable object to avoid copying the temporary array
175- # The ideal solution is to use `ds_to_write[header_key][:] = tmp_headers`
176- # but Xarray appears to be copying memory instead of doing direct assignment.
177- # TODO(BrianMichell): #614 Look into this further.
178- # https://github.com/TGSAI/mdio-python/issues/584
179- ds_to_write [header_key ] = Variable (
180- ds_to_write [header_key ].dims ,
181- tmp_headers ,
182- attrs = ds_to_write [header_key ].attrs ,
183- encoding = ds_to_write [header_key ].encoding , # Not strictly necessary, but safer than not doing it.
184- )
185-
186- data_variable = ds_to_write [data_variable_name ]
187- fill_value = _get_fill_value (ScalarType (data_variable .dtype .name ))
188- tmp_samples = np .full_like (data_variable , fill_value = fill_value )
189- tmp_samples [not_null ] = traces .sample
170+ header_array [header_region_slices ] = tmp_headers
190171
191- # TODO(BrianMichell): #614 Look into this further.
192- # https://github.com/TGSAI/mdio-python/issues/584
193- ds_to_write [data_variable_name ] = Variable (
194- ds_to_write [data_variable_name ].dims ,
195- tmp_samples ,
196- attrs = ds_to_write [data_variable_name ].attrs ,
197- encoding = ds_to_write [data_variable_name ].encoding , # Not strictly necessary, but safer than not doing it.
198- )
199-
200- to_mdio (ds_to_write , output_path = output_path , region = region , mode = "r+" )
172+ # Write the data variable
173+ data_array = zarr_group [data_variable_name ]
174+ tmp_samples = np .full (full_shape , data_array .fill_value )
175+ tmp_samples [not_null ] = traces .sample
176+ data_array [region_slices ] = tmp_samples
201177
202178 nonzero_samples = np .ma .masked_values (traces .sample , 0 , copy = False )
203179 histogram = CenteredBinHistogram (bin_centers = [], counts = [])
0 commit comments