
Commit 8d6e517

Provide clean disaster recovery interface
1 parent 7dad895 commit 8d6e517

File tree: 2 files changed, +28 -69 lines changed


src/mdio/segy/_disaster_recovery_wrapper.py

Lines changed: 19 additions & 60 deletions
@@ -4,73 +4,32 @@
 
 from typing import TYPE_CHECKING
 
-from segy.schema import Endianness
-from segy.transforms import ByteSwapTransform
-from segy.transforms import IbmFloatTransform
+from copy import deepcopy
+import numpy as np
 
 if TYPE_CHECKING:
     from numpy.typing import NDArray
     from segy import SegyFile
-    from segy.transforms import Transform
-    from segy.transforms import TransformPipeline
 
+class SegyFileTraceDataWrapper:
 
-def _reverse_single_transform(data: NDArray, transform: Transform, endianness: Endianness) -> NDArray:
-    """Reverse a single transform operation."""
-    if isinstance(transform, ByteSwapTransform):
-        # Reverse the endianness conversion
-        if endianness == Endianness.LITTLE:
-            return data
+    def __init__(self, segy_file: SegyFile, indices: int | list[int] | NDArray | slice):
+        self.segy_file = segy_file
+        self.indices = indices
+        self._header_pipeline = deepcopy(segy_file.accessors.header_decode_pipeline)
+        segy_file.accessors.header_decode_pipeline.transforms = []
+        self.traces = segy_file.trace[indices]
 
-        reverse_transform = ByteSwapTransform(Endianness.BIG)
-        return reverse_transform.apply(data)
+    @property
+    def header(self):
+        # The copy is necessary to avoid applying the pipeline to the original header.
+        return self._header_pipeline.apply(self.traces.header.copy())
 
-    # TODO(BrianMichell): #0000 Do we actually need to worry about IBM/IEEE transforms here?
-    if isinstance(transform, IbmFloatTransform):
-        # Reverse IBM float conversion
-        reverse_direction = "to_ibm" if transform.direction == "to_ieee" else "to_ieee"
-        reverse_transform = IbmFloatTransform(reverse_direction, transform.keys)
-        return reverse_transform.apply(data)
 
-    # For unknown transforms, return data unchanged
-    return data
+    @property
+    def raw_header(self):
+        return np.ascontiguousarray(self.traces.header).view("|V240")
 
-
-def get_header_raw_and_transformed(
-    segy_file: SegyFile, indices: int | list[int] | NDArray | slice, do_reverse_transforms: bool = True
-) -> tuple[NDArray | None, NDArray, NDArray]:
-    """Get both raw and transformed header data.
-
-    Args:
-        segy_file: The SegyFile instance
-        indices: Which headers to retrieve
-        do_reverse_transforms: Whether to apply the reverse transform to get raw data
-
-    Returns:
-        Tuple of (raw_headers, transformed_headers, traces)
-    """
-    traces = segy_file.trace[indices]
-    transformed_headers = traces.header
-
-    # Reverse transforms to get raw data
-    if do_reverse_transforms:
-        raw_headers = _reverse_transforms(
-            transformed_headers, segy_file.header.transform_pipeline, segy_file.spec.endianness
-        )
-    else:
-        raw_headers = None
-
-    return raw_headers, transformed_headers, traces
-
-
-def _reverse_transforms(
-    transformed_data: NDArray, transform_pipeline: TransformPipeline, endianness: Endianness
-) -> NDArray:
-    """Reverse the transform pipeline to get raw data."""
-    raw_data = transformed_data.copy() if hasattr(transformed_data, "copy") else transformed_data
-
-    # Apply transforms in reverse order
-    for transform in reversed(transform_pipeline.transforms):
-        raw_data = _reverse_single_transform(raw_data, transform, endianness)
-
-    return raw_data
+    @property
+    def sample(self):
+        return self.traces.sample

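Note: the wrapper above is meant to stand in for `segy_file.trace[indices]` while also exposing the undecoded header bytes. A minimal usage sketch, not part of the commit; the file path and indices below are placeholders:

from segy import SegyFile

from mdio.segy._disaster_recovery_wrapper import SegyFileTraceDataWrapper

segy_file = SegyFile("data/survey.segy")  # placeholder path for illustration

# Constructing the wrapper deep-copies the header decode pipeline and then empties the
# live pipeline, so the trace read below returns headers as undecoded bytes.
traces = SegyFileTraceDataWrapper(segy_file, [10, 11, 12])

decoded = traces.header      # saved decode pipeline applied on demand to a copy
raw = traces.raw_header      # contiguous headers viewed as 240-byte void records
samples = traces.sample      # trace samples, passed through unchanged
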
src/mdio/segy/_workers.py

Lines changed: 9 additions & 9 deletions
@@ -12,7 +12,7 @@
 
 from mdio.api.io import to_mdio
 from mdio.builder.schemas.dtype import ScalarType
-from mdio.segy._disaster_recovery_wrapper import get_header_raw_and_transformed
+from mdio.segy._disaster_recovery_wrapper import SegyFileTraceDataWrapper
 
 if TYPE_CHECKING:
     from segy.arrays import HeaderArray

@@ -134,17 +134,18 @@ def trace_worker( # noqa: PLR0913
     if raw_header_key in dataset.data_vars:
         worker_variables.append(raw_header_key)
 
-    from copy import deepcopy  # TODO: Move to head if we need to copy
-    header_pipeline = deepcopy(segy_file.accessors.header_decode_pipeline)
-    segy_file.accessors.header_decode_pipeline.transforms = []
-    traces = segy_file.trace[live_trace_indexes]
+    # traces = segy_file.trace[live_trace_indexes]
+    # Raw headers are not intended to remain as a feature of the SEGY ingestion.
+    # For that reason, we have wrapped the accessors to provide an interface that can be removed
+    # and not require additional changes to the below code.
+    # NOTE: The `raw_header_key` code block should be removed in full as it will become dead code.
+    traces = SegyFileTraceDataWrapper(segy_file, live_trace_indexes)
     ds_to_write = dataset[worker_variables]
 
     if header_key in worker_variables:
         # Create temporary array for headers with the correct shape
         tmp_headers = np.zeros_like(dataset[header_key])
-        # tmp_headers[not_null] = transformed_headers
-        tmp_headers[not_null] = header_pipeline.apply(traces.header.copy())
+        tmp_headers[not_null] = traces.header
         # Create a new Variable object to avoid copying the temporary array
         # The ideal solution is to use `ds_to_write[header_key][:] = tmp_headers`
         # but Xarray appears to be copying memory instead of doing direct assignment.

@@ -155,10 +156,9 @@ def trace_worker( # noqa: PLR0913
             attrs=ds_to_write[header_key].attrs,
             encoding=ds_to_write[header_key].encoding,  # Not strictly necessary, but safer than not doing it.
         )
-        # del transformed_headers  # Manage memory
     if raw_header_key in worker_variables:
         tmp_raw_headers = np.zeros_like(dataset[raw_header_key])
-        tmp_raw_headers[not_null] = np.ascontiguousarray(traces.header).view("|V240")
+        tmp_raw_headers[not_null] = traces.raw_header
 
         ds_to_write[raw_header_key] = Variable(
             ds_to_write[raw_header_key].dims,

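Side note on the `|V240` view used for `raw_header` and `tmp_raw_headers`: NumPy can reinterpret a contiguous structured array whose records are 240 bytes wide as opaque void scalars without copying. A self-contained sketch, with an invented three-field dtype standing in for the real SEG-Y trace header layout:

import numpy as np

# Illustrative structured dtype padded out to the 240-byte SEG-Y trace header size;
# the real header has many more fields.
header_dtype = np.dtype(
    {"names": ["inline", "crossline", "cdp_x"], "formats": [">i4", ">i4", ">i4"], "itemsize": 240}
)

headers = np.zeros(5, dtype=header_dtype)
headers["inline"] = np.arange(5)

# ascontiguousarray guards against non-contiguous selections; the view then
# reinterprets each 240-byte record as a single opaque value without copying.
raw = np.ascontiguousarray(headers).view("|V240")

assert raw.shape == headers.shape
assert raw.dtype == np.dtype("V240")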