1313
1414from mdio .api .io import to_mdio
1515from mdio .builder .schemas .dtype import ScalarType
16+ from mdio .segy ._disaster_recovery_wrapper import get_header_raw_and_transformed
1617
1718if TYPE_CHECKING :
1819 from segy .arrays import HeaderArray
@@ -81,7 +82,6 @@ def header_scan_worker(
8182
8283 return cast ("HeaderArray" , trace_header )
8384
84-
8585def trace_worker ( # noqa: PLR0913
8686 segy_kw : SegyFileArguments ,
8787 output_path : UPath ,
@@ -122,26 +122,30 @@ def trace_worker( # noqa: PLR0913
122122 zarr_config .set ({"threading.max_workers" : 1 })
123123
124124 live_trace_indexes = local_grid_map [not_null ].tolist ()
125- traces = segy_file .trace [live_trace_indexes ]
126125
127126 header_key = "headers"
128127 raw_header_key = "raw_headers"
129128
129+ # Used to disable the reverse transforms if we aren't going to write the raw headers
130+ do_reverse_transforms = False
131+
130132 # Get subset of the dataset that has not yet been saved
131133 # The headers might not be present in the dataset
132134 worker_variables = [data_variable_name ]
133135 if header_key in dataset .data_vars : # Keeping the `if` here to allow for more worker configurations
134136 worker_variables .append (header_key )
135137 if raw_header_key in dataset .data_vars :
138+
139+ do_reverse_transforms = True
136140 worker_variables .append (raw_header_key )
137141
142+ raw_headers , transformed_headers , traces = get_header_raw_and_transformed (segy_file , live_trace_indexes , do_reverse_transforms = do_reverse_transforms )
138143 ds_to_write = dataset [worker_variables ]
139144
140145 if header_key in worker_variables :
141146 # Create temporary array for headers with the correct shape
142- # TODO(BrianMichell): Implement this better so that we can enable fill values without changing the code. #noqa: TD003
143147 tmp_headers = np .zeros_like (dataset [header_key ])
144- tmp_headers [not_null ] = traces . header
148+ tmp_headers [not_null ] = transformed_headers
145149 # Create a new Variable object to avoid copying the temporary array
146150 # The ideal solution is to use `ds_to_write[header_key][:] = tmp_headers`
147151 # but Xarray appears to be copying memory instead of doing direct assignment.
@@ -152,41 +156,19 @@ def trace_worker( # noqa: PLR0913
152156 attrs = ds_to_write [header_key ].attrs ,
153157 encoding = ds_to_write [header_key ].encoding , # Not strictly necessary, but safer than not doing it.
154158 )
159+ del transformed_headers # Manage memory
155160 if raw_header_key in worker_variables :
156161 tmp_raw_headers = np .zeros_like (dataset [raw_header_key ])
157-
158- # Get the indices where we need to place results
159- live_mask = not_null
160- live_positions = np .where (live_mask .ravel ())[0 ]
161-
162- if len (live_positions ) > 0 :
163- # Calculate byte ranges for headers
164- header_size = 240
165- trace_offset = segy_file .spec .trace .offset
166- trace_itemsize = segy_file .spec .trace .itemsize
167-
168- starts = []
169- ends = []
170- for global_trace_idx in live_trace_indexes :
171- header_start = trace_offset + global_trace_idx * trace_itemsize
172- header_end = header_start + header_size
173- starts .append (header_start )
174- ends .append (header_end )
175-
176- # Capture raw bytes
177- raw_header_bytes = merge_cat_file (segy_file .fs , segy_file .url , starts , ends )
178-
179- # Convert and place results
180- raw_headers_array = np .frombuffer (bytes (raw_header_bytes ), dtype = "|V240" )
181- tmp_raw_headers .ravel ()[live_positions ] = raw_headers_array
162+ tmp_raw_headers [not_null ] = raw_headers .view ("|V240" )
182163
183164 ds_to_write [raw_header_key ] = Variable (
184165 ds_to_write [raw_header_key ].dims ,
185166 tmp_raw_headers ,
186167 attrs = ds_to_write [raw_header_key ].attrs ,
187- encoding = ds_to_write [raw_header_key ].encoding ,
188- )
168+ encoding = ds_to_write [raw_header_key ].encoding , # Not strictly necessary, but safer than not doing it.
169+
189170
171+ del raw_headers # Manage memory
190172 data_variable = ds_to_write [data_variable_name ]
191173 fill_value = _get_fill_value (ScalarType (data_variable .dtype .name ))
192174 tmp_samples = np .full_like (data_variable , fill_value = fill_value )
0 commit comments