@@ -13,6 +13,7 @@
 
 from mdio.api.io import to_mdio
 from mdio.builder.schemas.dtype import ScalarType
+from mdio.segy._disaster_recovery_wrapper import get_header_raw_and_transformed
 
 if TYPE_CHECKING:
     from segy.arrays import HeaderArray
@@ -81,7 +82,6 @@ def header_scan_worker(
 
     return cast("HeaderArray", trace_header)
 
-
 def trace_worker(  # noqa: PLR0913
     segy_kw: SegyFileArguments,
     output_path: UPath,
@@ -120,26 +120,30 @@ def trace_worker(  # noqa: PLR0913
     zarr_config.set({"threading.max_workers": 1})
 
     live_trace_indexes = local_grid_map[not_null].tolist()
-    traces = segy_file.trace[live_trace_indexes]
 
     header_key = "headers"
     raw_header_key = "raw_headers"
 
+    # Used to disable the reverse transforms if we aren't going to write the raw headers
+    do_reverse_transforms = False
+
     # Get subset of the dataset that has not yet been saved
     # The headers might not be present in the dataset
     worker_variables = [data_variable_name]
    if header_key in dataset.data_vars:  # Keeping the `if` here to allow for more worker configurations
         worker_variables.append(header_key)
     if raw_header_key in dataset.data_vars:
+
+        do_reverse_transforms = True
         worker_variables.append(raw_header_key)
 
+    raw_headers, transformed_headers, traces = get_header_raw_and_transformed(segy_file, live_trace_indexes, do_reverse_transforms=do_reverse_transforms)
     ds_to_write = dataset[worker_variables]
 
     if header_key in worker_variables:
         # Create temporary array for headers with the correct shape
-        # TODO(BrianMichell): Implement this better so that we can enable fill values without changing the code. #noqa: TD003
         tmp_headers = np.zeros_like(dataset[header_key])
-        tmp_headers[not_null] = traces.header
+        tmp_headers[not_null] = transformed_headers
         # Create a new Variable object to avoid copying the temporary array
         # The ideal solution is to use `ds_to_write[header_key][:] = tmp_headers`
         # but Xarray appears to be copying memory instead of doing direct assignment.
@@ -150,41 +154,19 @@ def trace_worker(  # noqa: PLR0913
             attrs=ds_to_write[header_key].attrs,
             encoding=ds_to_write[header_key].encoding,  # Not strictly necessary, but safer than not doing it.
         )
+        del transformed_headers  # Manage memory
     if raw_header_key in worker_variables:
         tmp_raw_headers = np.zeros_like(dataset[raw_header_key])
-
-        # Get the indices where we need to place results
-        live_mask = not_null
-        live_positions = np.where(live_mask.ravel())[0]
-
-        if len(live_positions) > 0:
-            # Calculate byte ranges for headers
-            header_size = 240
-            trace_offset = segy_file.spec.trace.offset
-            trace_itemsize = segy_file.spec.trace.itemsize
-
-            starts = []
-            ends = []
-            for global_trace_idx in live_trace_indexes:
-                header_start = trace_offset + global_trace_idx * trace_itemsize
-                header_end = header_start + header_size
-                starts.append(header_start)
-                ends.append(header_end)
-
-            # Capture raw bytes
-            raw_header_bytes = merge_cat_file(segy_file.fs, segy_file.url, starts, ends)
-
-            # Convert and place results
-            raw_headers_array = np.frombuffer(bytes(raw_header_bytes), dtype="|V240")
-            tmp_raw_headers.ravel()[live_positions] = raw_headers_array
+        tmp_raw_headers[not_null] = raw_headers.view("|V240")
 
         ds_to_write[raw_header_key] = Variable(
             ds_to_write[raw_header_key].dims,
             tmp_raw_headers,
             attrs=ds_to_write[raw_header_key].attrs,
-            encoding=ds_to_write[raw_header_key].encoding,
-        )
+            encoding=ds_to_write[raw_header_key].encoding,  # Not strictly necessary, but safer than not doing it.
+        )
 
+        del raw_headers  # Manage memory
     data_variable = ds_to_write[data_variable_name]
     fill_value = _get_fill_value(ScalarType(data_variable.dtype.name))
     tmp_samples = np.full_like(data_variable, fill_value=fill_value)
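Note: the body of `get_header_raw_and_transformed` (imported from `mdio.segy._disaster_recovery_wrapper`) is not part of this hunk. The worker only relies on it returning the raw on-disk headers, the transformed headers, and the traces in a single pass, skipping the reverse transforms when raw headers will not be written; the raw records are then stored as opaque 240-byte values via `raw_headers.view("|V240")`. The sketch below is a minimal, hypothetical illustration of that idea (toy header dtype, byte-swap "transform", made-up function name), not the mdio implementation:

import numpy as np

# Hypothetical 240-byte trace-header dtype with two named fields (real SEG-Y
# headers define many more); fields are big-endian on disk.
HEADER_DTYPE = np.dtype(
    {"names": ["inline", "crossline"], "formats": [">i4", ">i4"], "itemsize": 240}
)


def headers_raw_and_transformed(disk_bytes, do_reverse_transforms=True):
    """Return (raw_headers, transformed_headers) from on-disk header bytes.

    The "transform" here is simply a byte-order swap to native endianness; the
    reverse transform swaps back so the raw 240-byte records can be produced
    without a second read from storage.
    """
    raw = np.frombuffer(disk_bytes, dtype=HEADER_DTYPE)
    transformed = raw.astype(HEADER_DTYPE.newbyteorder("="))  # native-endian copy

    if not do_reverse_transforms:
        return None, transformed  # raw headers won't be written; skip the extra work

    # Undo the transform field by field into a zero-initialized big-endian buffer,
    # then expose it as opaque |V240 records, mirroring raw_headers.view("|V240").
    recovered = np.zeros(transformed.shape, dtype=HEADER_DTYPE)
    recovered[:] = transformed
    return recovered.view("|V240"), transformed


# Round-trip check on three fake headers.
fake_disk = np.zeros(3, dtype=HEADER_DTYPE)
fake_disk["inline"] = [10, 11, 12]
raw_headers, transformed_headers = headers_raw_and_transformed(fake_disk.tobytes())
assert transformed_headers["inline"].tolist() == [10, 11, 12]
assert raw_headers.tobytes() == fake_disk.tobytes()

Gating the reverse transform on `do_reverse_transforms` mirrors the diff's intent: datasets without a `raw_headers` variable pay no extra cost, since the recovery work only runs when the raw records will actually be written.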