Skip to content

Commit 4b80bdd

Browse files
authored
Profiled disaster recovery array (#8)
- Avoids a duplicate-read regression issue
- Implements isolated, testable logic
1 parent 0df4785 commit 4b80bdd

File tree

4 files changed

+505
-32
lines changed

4 files changed

+505
-32
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""Consumer-side utility to get both raw and transformed header data with single filesystem read."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
if TYPE_CHECKING:
8+
import numpy as np
9+
from segy.file import SegyFile
10+
from segy.transforms import Transform, ByteSwapTransform, IbmFloatTransform
11+
from numpy.typing import NDArray
12+
13+
def _reverse_single_transform(data: NDArray, transform: Transform, endianness: Endianness) -> NDArray:
    """Undo the effect of a single forward transform on header data.

    Args:
        data: Header array produced by the forward transform.
        transform: The forward transform instance that was originally applied.
        endianness: Endianness of the source SEG-Y file.

    Returns:
        Array with the transform reversed; unrecognized transform types are
        passed through unchanged.
    """
    from segy.schema import Endianness
    from segy.transforms import ByteSwapTransform
    from segy.transforms import IbmFloatTransform

    if isinstance(transform, ByteSwapTransform):
        # A little-endian source required no byte swap on read, so there is
        # nothing to undo.
        if endianness == Endianness.LITTLE:
            return data
        return ByteSwapTransform(Endianness.BIG).apply(data)

    if isinstance(transform, IbmFloatTransform):  # TODO: This seems incorrect...
        # Flip the conversion direction to restore the original representation.
        flipped_direction = "to_ibm" if transform.direction == "to_ieee" else "to_ieee"
        return IbmFloatTransform(flipped_direction, transform.keys).apply(data)

    # Unknown transform type: return the data untouched.
    return data
def get_header_raw_and_transformed(
    segy_file: SegyFile,
    indices: int | list[int] | NDArray | slice,
    do_reverse_transforms: bool = True
) -> tuple[NDArray | None, NDArray, NDArray]:
    """Get both raw and transformed header data with one filesystem read.

    Args:
        segy_file: The SegyFile instance
        indices: Which headers to retrieve
        do_reverse_transforms: Whether to apply the reverse transform to get raw data

    Returns:
        Tuple of (raw_headers, transformed_headers, traces)
    """
    traces = segy_file.trace[indices]
    transformed_headers = traces.header

    raw_headers = None
    if do_reverse_transforms:
        # Undo the read-time transform pipeline to recover the on-disk values.
        raw_headers = _reverse_transforms(
            transformed_headers,
            segy_file.header.transform_pipeline,
            segy_file.spec.endianness,
        )

    return raw_headers, transformed_headers, traces
def _reverse_transforms(transformed_data: NDArray, transform_pipeline, endianness: Endianness) -> NDArray:
    """Walk the transform pipeline backwards to recover raw header data."""
    # Work on a copy when the input supports it, so the caller's array is
    # left intact.
    raw = transformed_data.copy() if hasattr(transformed_data, "copy") else transformed_data

    # Undo each transform in the opposite order from which it was applied.
    for step in reversed(transform_pipeline.transforms):
        raw = _reverse_single_transform(raw, step, endianness)

    return raw

src/mdio/segy/_workers.py

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from mdio.api.io import to_mdio
1515
from mdio.builder.schemas.dtype import ScalarType
16+
from mdio.segy._disaster_recovery_wrapper import get_header_raw_and_transformed
1617

1718
if TYPE_CHECKING:
1819
from segy.arrays import HeaderArray
@@ -81,7 +82,6 @@ def header_scan_worker(
8182

8283
return cast("HeaderArray", trace_header)
8384

84-
8585
def trace_worker( # noqa: PLR0913
8686
segy_kw: SegyFileArguments,
8787
output_path: UPath,
@@ -120,26 +120,30 @@ def trace_worker( # noqa: PLR0913
120120
zarr_config.set({"threading.max_workers": 1})
121121

122122
live_trace_indexes = local_grid_map[not_null].tolist()
123-
traces = segy_file.trace[live_trace_indexes]
124123

125124
header_key = "headers"
126125
raw_header_key = "raw_headers"
127126

127+
# Used to disable the reverse transforms if we aren't going to write the raw headers
128+
do_reverse_transforms = False
129+
128130
# Get subset of the dataset that has not yet been saved
129131
# The headers might not be present in the dataset
130132
worker_variables = [data_variable_name]
131133
if header_key in dataset.data_vars: # Keeping the `if` here to allow for more worker configurations
132134
worker_variables.append(header_key)
133135
if raw_header_key in dataset.data_vars:
136+
137+
do_reverse_transforms = True
134138
worker_variables.append(raw_header_key)
135139

140+
raw_headers, transformed_headers, traces = get_header_raw_and_transformed(segy_file, live_trace_indexes, do_reverse_transforms=do_reverse_transforms)
136141
ds_to_write = dataset[worker_variables]
137142

138143
if header_key in worker_variables:
139144
# Create temporary array for headers with the correct shape
140-
# TODO(BrianMichell): Implement this better so that we can enable fill values without changing the code. #noqa: TD003
141145
tmp_headers = np.zeros_like(dataset[header_key])
142-
tmp_headers[not_null] = traces.header
146+
tmp_headers[not_null] = transformed_headers
143147
# Create a new Variable object to avoid copying the temporary array
144148
# The ideal solution is to use `ds_to_write[header_key][:] = tmp_headers`
145149
# but Xarray appears to be copying memory instead of doing direct assignment.
@@ -150,41 +154,19 @@ def trace_worker( # noqa: PLR0913
150154
attrs=ds_to_write[header_key].attrs,
151155
encoding=ds_to_write[header_key].encoding, # Not strictly necessary, but safer than not doing it.
152156
)
157+
del transformed_headers # Manage memory
153158
if raw_header_key in worker_variables:
154159
tmp_raw_headers = np.zeros_like(dataset[raw_header_key])
155-
156-
# Get the indices where we need to place results
157-
live_mask = not_null
158-
live_positions = np.where(live_mask.ravel())[0]
159-
160-
if len(live_positions) > 0:
161-
# Calculate byte ranges for headers
162-
header_size = 240
163-
trace_offset = segy_file.spec.trace.offset
164-
trace_itemsize = segy_file.spec.trace.itemsize
165-
166-
starts = []
167-
ends = []
168-
for global_trace_idx in live_trace_indexes:
169-
header_start = trace_offset + global_trace_idx * trace_itemsize
170-
header_end = header_start + header_size
171-
starts.append(header_start)
172-
ends.append(header_end)
173-
174-
# Capture raw bytes
175-
raw_header_bytes = merge_cat_file(segy_file.fs, segy_file.url, starts, ends)
176-
177-
# Convert and place results
178-
raw_headers_array = np.frombuffer(bytes(raw_header_bytes), dtype="|V240")
179-
tmp_raw_headers.ravel()[live_positions] = raw_headers_array
160+
tmp_raw_headers[not_null] = raw_headers.view("|V240")
180161

181162
ds_to_write[raw_header_key] = Variable(
182163
ds_to_write[raw_header_key].dims,
183164
tmp_raw_headers,
184165
attrs=ds_to_write[raw_header_key].attrs,
185-
encoding=ds_to_write[raw_header_key].encoding,
186-
)
166+
encoding=ds_to_write[raw_header_key].encoding, # Not strictly necessary, but safer than not doing it.
167+
187168

169+
del raw_headers # Manage memory
188170
data_variable = ds_to_write[data_variable_name]
189171
fill_value = _get_fill_value(ScalarType(data_variable.dtype.name))
190172
tmp_samples = np.full_like(data_variable, fill_value=fill_value)

src/mdio/segy/blocked_io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,4 +280,4 @@ def to_segy(
280280

281281
non_consecutive_axes -= 1
282282

283-
return block_io_records
283+
return block_io_records

0 commit comments

Comments
 (0)