Skip to content

Commit e00d6f4

Browse files
BrianMichell and tasansal
authored and committed
Profiled disaster recovery array (#8)
- Avoids the duplicate-read regression issue
- Implements isolated and testable logic
1 parent 87aeffb commit e00d6f4

File tree

4 files changed

+505
-32
lines changed

4 files changed

+505
-32
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""Consumer-side utility to get both raw and transformed header data with single filesystem read."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
if TYPE_CHECKING:
8+
import numpy as np
9+
from segy.file import SegyFile
10+
from segy.transforms import Transform, ByteSwapTransform, IbmFloatTransform
11+
from numpy.typing import NDArray
12+
13+
def _reverse_single_transform(data: NDArray, transform: Transform, endianness: Endianness) -> NDArray:
    """Undo the effect of one pipeline transform on *data*.

    Args:
        data: Header array as produced by the forward transform.
        transform: The forward transform instance to invert.
        endianness: Declared endianness of the SEG-Y file.

    Returns:
        The array with the transform reversed; unknown transform types
        are passed through unchanged.
    """
    # Imported lazily to keep module import light; names are only needed here.
    from segy.schema import Endianness
    from segy.transforms import ByteSwapTransform
    from segy.transforms import IbmFloatTransform

    if isinstance(transform, ByteSwapTransform):
        # A little-endian file required no swap on read, so nothing to undo.
        if endianness == Endianness.LITTLE:
            return data
        return ByteSwapTransform(Endianness.BIG).apply(data)

    if isinstance(transform, IbmFloatTransform):
        # Flip the conversion direction to invert the IBM<->IEEE transform.
        # NOTE(review): the original author flagged this branch with
        # "TODO: This seems incorrect..." — verify before relying on it.
        flipped_direction = "to_ibm" if transform.direction == "to_ieee" else "to_ieee"
        return IbmFloatTransform(flipped_direction, transform.keys).apply(data)

    # Transforms we do not know how to invert pass through untouched.
    return data
36+
37+
def get_header_raw_and_transformed(
    segy_file: SegyFile,
    indices: int | list[int] | NDArray | slice,
    do_reverse_transforms: bool = True
) -> tuple[NDArray | None, NDArray, NDArray]:
    """Fetch headers once and return them in both raw and transformed form.

    A single read through ``segy_file.trace`` yields the transformed
    headers; the raw form is recovered by running the header transform
    pipeline in reverse, avoiding a second filesystem read.

    Args:
        segy_file: The SegyFile instance
        indices: Which headers to retrieve
        do_reverse_transforms: Whether to apply the reverse transform to get raw data

    Returns:
        Tuple of (raw_headers, transformed_headers, traces); ``raw_headers``
        is ``None`` when the reverse transform was skipped.
    """
    traces = segy_file.trace[indices]
    headers = traces.header

    # Only pay for the reverse pipeline when the caller asked for raw data.
    raw = (
        _reverse_transforms(headers, segy_file.header.transform_pipeline, segy_file.spec.endianness)
        if do_reverse_transforms
        else None
    )

    return raw, headers, traces
62+
63+
def _reverse_transforms(transformed_data: NDArray, transform_pipeline, endianness: Endianness) -> NDArray:
    """Walk the transform pipeline backwards to recover the raw data.

    Args:
        transformed_data: Header array after the forward pipeline ran.
        transform_pipeline: Pipeline object exposing a ``transforms`` sequence.
        endianness: Declared endianness of the SEG-Y file.

    Returns:
        The data with every pipeline step reversed, applied last-to-first.
    """
    # Copy when possible so the caller's array is left untouched.
    result = transformed_data.copy() if hasattr(transformed_data, "copy") else transformed_data

    # Invert steps in the opposite order from how they were applied.
    for step in reversed(transform_pipeline.transforms):
        result = _reverse_single_transform(result, step, endianness)

    return result

src/mdio/segy/_workers.py

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from mdio.api.io import to_mdio
1515
from mdio.builder.schemas.dtype import ScalarType
16+
from mdio.segy._disaster_recovery_wrapper import get_header_raw_and_transformed
1617

1718
if TYPE_CHECKING:
1819
from segy.arrays import HeaderArray
@@ -81,7 +82,6 @@ def header_scan_worker(
8182

8283
return cast("HeaderArray", trace_header)
8384

84-
8585
def trace_worker( # noqa: PLR0913
8686
segy_kw: SegyFileArguments,
8787
output_path: UPath,
@@ -122,26 +122,30 @@ def trace_worker( # noqa: PLR0913
122122
zarr_config.set({"threading.max_workers": 1})
123123

124124
live_trace_indexes = local_grid_map[not_null].tolist()
125-
traces = segy_file.trace[live_trace_indexes]
126125

127126
header_key = "headers"
128127
raw_header_key = "raw_headers"
129128

129+
# Used to disable the reverse transforms if we aren't going to write the raw headers
130+
do_reverse_transforms = False
131+
130132
# Get subset of the dataset that has not yet been saved
131133
# The headers might not be present in the dataset
132134
worker_variables = [data_variable_name]
133135
if header_key in dataset.data_vars: # Keeping the `if` here to allow for more worker configurations
134136
worker_variables.append(header_key)
135137
if raw_header_key in dataset.data_vars:
138+
139+
do_reverse_transforms = True
136140
worker_variables.append(raw_header_key)
137141

142+
raw_headers, transformed_headers, traces = get_header_raw_and_transformed(segy_file, live_trace_indexes, do_reverse_transforms=do_reverse_transforms)
138143
ds_to_write = dataset[worker_variables]
139144

140145
if header_key in worker_variables:
141146
# Create temporary array for headers with the correct shape
142-
# TODO(BrianMichell): Implement this better so that we can enable fill values without changing the code. #noqa: TD003
143147
tmp_headers = np.zeros_like(dataset[header_key])
144-
tmp_headers[not_null] = traces.header
148+
tmp_headers[not_null] = transformed_headers
145149
# Create a new Variable object to avoid copying the temporary array
146150
# The ideal solution is to use `ds_to_write[header_key][:] = tmp_headers`
147151
# but Xarray appears to be copying memory instead of doing direct assignment.
@@ -152,41 +156,19 @@ def trace_worker( # noqa: PLR0913
152156
attrs=ds_to_write[header_key].attrs,
153157
encoding=ds_to_write[header_key].encoding, # Not strictly necessary, but safer than not doing it.
154158
)
159+
del transformed_headers # Manage memory
155160
if raw_header_key in worker_variables:
156161
tmp_raw_headers = np.zeros_like(dataset[raw_header_key])
157-
158-
# Get the indices where we need to place results
159-
live_mask = not_null
160-
live_positions = np.where(live_mask.ravel())[0]
161-
162-
if len(live_positions) > 0:
163-
# Calculate byte ranges for headers
164-
header_size = 240
165-
trace_offset = segy_file.spec.trace.offset
166-
trace_itemsize = segy_file.spec.trace.itemsize
167-
168-
starts = []
169-
ends = []
170-
for global_trace_idx in live_trace_indexes:
171-
header_start = trace_offset + global_trace_idx * trace_itemsize
172-
header_end = header_start + header_size
173-
starts.append(header_start)
174-
ends.append(header_end)
175-
176-
# Capture raw bytes
177-
raw_header_bytes = merge_cat_file(segy_file.fs, segy_file.url, starts, ends)
178-
179-
# Convert and place results
180-
raw_headers_array = np.frombuffer(bytes(raw_header_bytes), dtype="|V240")
181-
tmp_raw_headers.ravel()[live_positions] = raw_headers_array
162+
tmp_raw_headers[not_null] = raw_headers.view("|V240")
182163

183164
ds_to_write[raw_header_key] = Variable(
184165
ds_to_write[raw_header_key].dims,
185166
tmp_raw_headers,
186167
attrs=ds_to_write[raw_header_key].attrs,
187-
encoding=ds_to_write[raw_header_key].encoding,
188-
)
168+
encoding=ds_to_write[raw_header_key].encoding, # Not strictly necessary, but safer than not doing it.
169+
189170

171+
del raw_headers # Manage memory
190172
data_variable = ds_to_write[data_variable_name]
191173
fill_value = _get_fill_value(ScalarType(data_variable.dtype.name))
192174
tmp_samples = np.full_like(data_variable, fill_value=fill_value)

src/mdio/segy/blocked_io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,4 +280,4 @@ def to_segy(
280280

281281
non_consecutive_axes -= 1
282282

283-
return block_io_records
283+
return block_io_records

0 commit comments

Comments
 (0)