
Commit bd5cb6c

Dr with modifications (#9)
* Attempt to use view
* Add hex-dump and MDIO output reproducer
* Fixes
* Cleanup
* Provide clean disaster recovery interface
* Begin work on tests
* Fix flattening issue
* Push for debugging
* Numpy updates
* Testing
* Working end-to-end examples
* Cleanup
* Bandaid fix
* Linting pass 1
* Fix logic issue
* Use wrapper class
* Pre-commit
* Remove external debugging code
* Remove debug code
* Remove errant numpy addition to pyproject.toml
* Fix uv lock to mainline
* Pre-commit
* Remove raw field additions. Depends on segy >= 0.5.1
* Removed raw byte inserts (#10)
* Update Xarray API access (TGSAI#688)
* Reimplement disaster recovery logic
* Ensure getting true raw bytes for DR array
* Linting
* Add v2 issue check
* Fix pre-commit
* Profiled disaster recovery array (#8)
  - Avoids duplicate read regression issue
  - Implements isolated and testable logic
* Fix unclosed parenthesis
* Linting
* Test DR compatibility with all tested schemas
* Fix missing test fixture error
* Suppress unused linting error
* Attempt to use view
* Add hex-dump and MDIO output reproducer
* Fixes
* Cleanup
* Provide clean disaster recovery interface
* Begin work on tests
* Fix flattening issue
* Push for debugging
* Numpy updates
* Testing
* Working end-to-end examples
* Cleanup
* Bandaid fix
* Linting pass 1
* Fix logic issue
* Use wrapper class
* Pre-commit
* Remove external debugging code
* Remove debug code
* Remove errant numpy addition to pyproject.toml
* Fix uv lock to mainline
* Pre-commit
* Removed raw byte inserts: removed the insertions of raw bytes into the raw bytes Variable. This issue will be addressed in tgsai/segy release >0.5.1
* Use new segy API calls
* Updates to get working
* Use released version
* Linting
1 parent 9b42fbc commit bd5cb6c

File tree: 8 files changed, +251 -281 lines changed


pyproject.toml
Lines changed: 2 additions & 2 deletions

@@ -26,10 +26,10 @@ dependencies = [
     "psutil>=7.0.0",
     "pydantic>=2.11.9",
     "rich>=14.1.0",
-    "segy>=0.5.0",
+    "segy>=0.5.1.post1",
     "tqdm>=4.67.1",
     "universal-pathlib>=0.2.6",
-    "xarray>=2025.9.0",
+    "xarray>=2025.9.1",
     "zarr>=3.1.3",
 ]

src/mdio/api/io.py
Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@
 from upath import UPath
 from xarray import Dataset as xr_Dataset
 from xarray import open_zarr as xr_open_zarr
-from xarray.backends.api import to_zarr as xr_to_zarr
+from xarray.backends.writers import to_zarr as xr_to_zarr

 from mdio.constants import ZarrFormat
 from mdio.core.zarr_io import zarr_warnings_suppress_unstable_structs_v3
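
The import path of `to_zarr` moved in newer xarray releases, which is why the xarray floor is raised to 2025.9.1 in pyproject.toml above. If a codebase ever needed to tolerate both layouts, one option is a guarded import; this is a hedged sketch and not part of this commit:

    # Sketch only: prefer the new location, fall back to the old one on older xarray.
    try:
        from xarray.backends.writers import to_zarr as xr_to_zarr
    except ImportError:
        from xarray.backends.api import to_zarr as xr_to_zarr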

src/mdio/segy/_disaster_recovery_wrapper.py
Lines changed: 17 additions & 59 deletions

@@ -4,73 +4,31 @@

 from typing import TYPE_CHECKING

-from segy.schema import Endianness
-from segy.transforms import ByteSwapTransform
-from segy.transforms import IbmFloatTransform

 if TYPE_CHECKING:
     from numpy.typing import NDArray
     from segy import SegyFile
-    from segy.transforms import Transform
-    from segy.transforms import TransformPipeline


-def _reverse_single_transform(data: NDArray, transform: Transform, endianness: Endianness) -> NDArray:
-    """Reverse a single transform operation."""
-    if isinstance(transform, ByteSwapTransform):
-        # Reverse the endianness conversion
-        if endianness == Endianness.LITTLE:
-            return data
+class SegyFileTraceDataWrapper:
+    def __init__(self, segy_file: SegyFile, indices: int | list[int] | NDArray | slice):
+        self.segy_file = segy_file
+        self.indices = indices

-        reverse_transform = ByteSwapTransform(Endianness.BIG)
-        return reverse_transform.apply(data)
+        self.idx = self.segy_file.trace.normalize_and_validate_query(self.indices)
+        self.traces = self.segy_file.trace.fetch(self.idx, raw=True)

-    # TODO(BrianMichell): #0000 Do we actually need to worry about IBM/IEEE transforms here?
-    if isinstance(transform, IbmFloatTransform):
-        # Reverse IBM float conversion
-        reverse_direction = "to_ibm" if transform.direction == "to_ieee" else "to_ieee"
-        reverse_transform = IbmFloatTransform(reverse_direction, transform.keys)
-        return reverse_transform.apply(data)
+        self.raw_view = self.traces.view(self.segy_file.spec.trace.dtype)
+        self.decoded_traces = self.segy_file.accessors.trace_decode_pipeline.apply(self.raw_view.copy())

-    # For unknown transforms, return data unchanged
-    return data
+    @property
+    def raw_header(self) -> NDArray:
+        return self.raw_view.header.view("|V240")

+    @property
+    def header(self) -> NDArray:
+        return self.decoded_traces.header

-def get_header_raw_and_transformed(
-    segy_file: SegyFile, indices: int | list[int] | NDArray | slice, do_reverse_transforms: bool = True
-) -> tuple[NDArray | None, NDArray, NDArray]:
-    """Get both raw and transformed header data.
-
-    Args:
-        segy_file: The SegyFile instance
-        indices: Which headers to retrieve
-        do_reverse_transforms: Whether to apply the reverse transform to get raw data
-
-    Returns:
-        Tuple of (raw_headers, transformed_headers, traces)
-    """
-    traces = segy_file.trace[indices]
-    transformed_headers = traces.header
-
-    # Reverse transforms to get raw data
-    if do_reverse_transforms:
-        raw_headers = _reverse_transforms(
-            transformed_headers, segy_file.header.transform_pipeline, segy_file.spec.endianness
-        )
-    else:
-        raw_headers = None
-
-    return raw_headers, transformed_headers, traces
-
-
-def _reverse_transforms(
-    transformed_data: NDArray, transform_pipeline: TransformPipeline, endianness: Endianness
-) -> NDArray:
-    """Reverse the transform pipeline to get raw data."""
-    raw_data = transformed_data.copy() if hasattr(transformed_data, "copy") else transformed_data
-
-    # Apply transforms in reverse order
-    for transform in reversed(transform_pipeline.transforms):
-        raw_data = _reverse_single_transform(raw_data, transform, endianness)
-
-    return raw_data
+    @property
+    def sample(self) -> NDArray:
+        return self.decoded_traces.sample
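
The wrapper fetches the raw trace buffer once, keeps a view of the on-disk bytes, and decodes a copy through the file's own decode pipeline, so raw and decoded headers come from a single read instead of reversing transforms after the fact. A minimal usage sketch; the file path and indices are hypothetical, and in this commit the wrapper is only consumed internally by the trace worker:

    from segy import SegyFile

    segy_file = SegyFile("example.sgy")  # hypothetical input file
    traces = SegyFileTraceDataWrapper(segy_file, [0, 1, 2])

    raw = traces.raw_header   # 240-byte records exactly as stored on disk ("|V240")
    hdr = traces.header       # headers after the decode pipeline (byte swap, IBM->IEEE, ...)
    data = traces.sample      # decoded trace samples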

src/mdio/segy/_workers.py
Lines changed: 20 additions & 21 deletions

@@ -12,7 +12,7 @@
 from mdio.api.io import to_mdio
 from mdio.builder.schemas.dtype import ScalarType
-from mdio.segy._disaster_recovery_wrapper import get_header_raw_and_transformed
+from mdio.segy._disaster_recovery_wrapper import SegyFileTraceDataWrapper

 if TYPE_CHECKING:
     from segy.arrays import HeaderArray

@@ -126,28 +126,39 @@ def trace_worker( # noqa: PLR0913
     header_key = "headers"
     raw_header_key = "raw_headers"

-    # Used to disable the reverse transforms if we aren't going to write the raw headers
-    do_reverse_transforms = False
-
     # Get subset of the dataset that has not yet been saved
     # The headers might not be present in the dataset
     worker_variables = [data_variable_name]
     if header_key in dataset.data_vars:  # Keeping the `if` here to allow for more worker configurations
         worker_variables.append(header_key)
     if raw_header_key in dataset.data_vars:
-        do_reverse_transforms = True
         worker_variables.append(raw_header_key)

-    raw_headers, transformed_headers, traces = get_header_raw_and_transformed(
-        segy_file, live_trace_indexes, do_reverse_transforms=do_reverse_transforms
-    )
+    # traces = segy_file.trace[live_trace_indexes]
+    # Raw headers are not intended to remain as a feature of the SEGY ingestion.
+    # For that reason, we have wrapped the accessors to provide an interface that can be removed
+    # and not require additional changes to the below code.
+    # NOTE: The `raw_header_key` code block should be removed in full as it will become dead code.
+    traces = SegyFileTraceDataWrapper(segy_file, live_trace_indexes)
+
     ds_to_write = dataset[worker_variables]

+    if raw_header_key in worker_variables:
+        tmp_raw_headers = np.zeros_like(dataset[raw_header_key])
+        tmp_raw_headers[not_null] = traces.raw_header
+
+        ds_to_write[raw_header_key] = Variable(
+            ds_to_write[raw_header_key].dims,
+            tmp_raw_headers,
+            attrs=ds_to_write[raw_header_key].attrs,
+            encoding=ds_to_write[raw_header_key].encoding,  # Not strictly necessary, but safer than not doing it.
+        )
+
     if header_key in worker_variables:
         # TODO(BrianMichell): Implement this better so that we can enable fill values without changing the code
         # https://github.com/TGSAI/mdio-python/issues/584
         tmp_headers = np.zeros_like(dataset[header_key])
-        tmp_headers[not_null] = transformed_headers
+        tmp_headers[not_null] = traces.header
         # Create a new Variable object to avoid copying the temporary array
         # The ideal solution is to use `ds_to_write[header_key][:] = tmp_headers`
         # but Xarray appears to be copying memory instead of doing direct assignment.

@@ -159,19 +170,7 @@ def trace_worker( # noqa: PLR0913
             attrs=ds_to_write[header_key].attrs,
             encoding=ds_to_write[header_key].encoding,  # Not strictly necessary, but safer than not doing it.
         )
-        del transformed_headers  # Manage memory
-    if raw_header_key in worker_variables:
-        tmp_raw_headers = np.zeros_like(dataset[raw_header_key])
-        tmp_raw_headers[not_null] = raw_headers.view("|V240")
-
-        ds_to_write[raw_header_key] = Variable(
-            ds_to_write[raw_header_key].dims,
-            tmp_raw_headers,
-            attrs=ds_to_write[raw_header_key].attrs,
-            encoding=ds_to_write[raw_header_key].encoding,  # Not strictly necessary, but safer than not doing it.
-        )

-    del raw_headers  # Manage memory
     data_variable = ds_to_write[data_variable_name]
     fill_value = _get_fill_value(ScalarType(data_variable.dtype.name))
     tmp_samples = np.full_like(data_variable, fill_value=fill_value)
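
The `"|V240"` view used for the raw headers reinterprets each structured header record as a single opaque 240-byte value without copying, which is what makes byte-for-byte preservation cheap. A small self-contained numpy illustration; the field names and sizes here are made up, not the real SEG-Y header layout:

    import numpy as np

    # Hypothetical 240-byte structured header dtype (real headers define many more fields).
    hdr_dtype = np.dtype([("inline", ">i4"), ("crossline", ">i4"), ("pad", "V232")])
    assert hdr_dtype.itemsize == 240

    headers = np.zeros(3, dtype=hdr_dtype)
    raw = headers.view("|V240")   # same memory, one opaque 240-byte record per trace
    print(raw.dtype.itemsize)     # 240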

src/mdio/segy/creation.py
Lines changed: 38 additions & 2 deletions

@@ -28,6 +28,38 @@
 logger = logging.getLogger(__name__)


+def _filter_raw_unspecified_fields(headers: NDArray) -> NDArray:
+    """Filter out __MDIO_RAW_UNSPECIFIED_Field_* fields from headers array.
+
+    These fields are added during SEGY import to preserve raw header bytes,
+    but they cause dtype mismatches during export. This function removes them.
+
+    Args:
+        headers: Header array that may contain raw unspecified fields.
+
+    Returns:
+        Header array with raw unspecified fields removed.
+    """
+    if headers.dtype.names is None:
+        return headers
+
+    # Find field names that don't start with __MDIO_RAW_UNSPECIFIED_
+    field_names = [name for name in headers.dtype.names if not name.startswith("__MDIO_RAW_UNSPECIFIED_")]
+
+    if len(field_names) == len(headers.dtype.names):
+        # No raw unspecified fields found, return as-is
+        return headers
+
+    # Create new structured array with only the non-raw fields
+    new_dtype = [(name, headers.dtype.fields[name][0]) for name in field_names]
+    filtered_headers = np.empty(headers.shape, dtype=new_dtype)
+
+    for name in field_names:
+        filtered_headers[name] = headers[name]
+
+    return filtered_headers
+
+
 def make_segy_factory(spec: SegySpec, binary_header: dict[str, int]) -> SegyFactory:
     """Generate SEG-Y factory from MDIO metadata."""
     sample_interval = binary_header["sample_interval"]

@@ -167,7 +199,9 @@ def serialize_to_segy_stack( # noqa: PLR0913
         samples = samples[live_mask]
         headers = headers[live_mask]

-        buffer = segy_factory.create_traces(headers, samples)
+        # Filter out raw unspecified fields that cause dtype mismatches
+        filtered_headers = _filter_raw_unspecified_fields(headers)
+        buffer = segy_factory.create_traces(filtered_headers, samples)

         global_index = block_start[0]
         record_id_str = str(global_index)

@@ -199,7 +233,9 @@ def serialize_to_segy_stack( # noqa: PLR0913
             rec_samples = samples[rec_index][rec_live_mask]
             rec_headers = headers[rec_index][rec_live_mask]

-            buffer = segy_factory.create_traces(rec_headers, rec_samples)
+            # Filter out raw unspecified fields that cause dtype mismatches
+            filtered_headers = _filter_raw_unspecified_fields(rec_headers)
+            buffer = segy_factory.create_traces(filtered_headers, rec_samples)

             global_index = tuple(block_start[i] + rec_index[i] for i in range(record_ndim))
             record_id_str = "/".join(map(str, global_index))
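
The helper drops only the placeholder fields and leaves the real header fields untouched, so export sees the dtype the SEG-Y factory expects. A small sketch of the behavior; apart from the __MDIO_RAW_UNSPECIFIED_ prefix, the field names below are hypothetical:

    import numpy as np

    dtype_in = np.dtype([("inline", "i4"), ("__MDIO_RAW_UNSPECIFIED_Field_1", "V8")])
    headers = np.zeros(5, dtype=dtype_in)

    filtered = _filter_raw_unspecified_fields(headers)
    print(filtered.dtype.names)  # ('inline',) -- placeholder field removed, values preserved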

tests/integration/test_segy_import_export_masked.py
Lines changed: 105 additions & 4 deletions

@@ -282,11 +282,16 @@ def generate_selection_mask(selection_conf: SelectionMaskConfig, grid_conf: Grid
 @pytest.fixture
-def export_masked_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
+def export_masked_path(tmp_path_factory: pytest.TempPathFactory, raw_headers_env: None) -> Path:  # noqa: ARG001
     """Fixture that generates temp directory for export tests."""
+    # Create path suffix based on current raw headers environment variable
+    # raw_headers_env dependency ensures the environment variable is set before this runs
+    raw_headers_enabled = os.getenv("MDIO__DO_RAW_HEADERS") == "1"
+    path_suffix = "with_raw_headers" if raw_headers_enabled else "without_raw_headers"
+
     if DEBUG_MODE:
-        return Path("TMP/export_masked")
-    return tmp_path_factory.getbasetemp() / "export_masked"
+        return Path(f"TMP/export_masked_{path_suffix}")
+    return tmp_path_factory.getbasetemp() / f"export_masked_{path_suffix}"


 @pytest.fixture

@@ -300,9 +305,39 @@ def raw_headers_env(request: pytest.FixtureRequest) -> None:
     yield

-    # Cleanup after test
+    # Cleanup after test - both environment variable and template state
     os.environ.pop("MDIO__DO_RAW_HEADERS", None)

+    # Clean up any template modifications to ensure test isolation
+    registry = TemplateRegistry.get_instance()
+
+    # Reset any templates that might have been modified with raw headers
+    template_names = [
+        "PostStack2DTime",
+        "PostStack3DTime",
+        "PreStackCdpOffsetGathers2DTime",
+        "PreStackCdpOffsetGathers3DTime",
+        "PreStackShotGathers2DTime",
+        "PreStackShotGathers3DTime",
+        "PreStackCocaGathers3DTime",
+    ]
+
+    for template_name in template_names:
+        try:
+            template = registry.get(template_name)
+            # Remove raw headers enhancement if present
+            if hasattr(template, "_mdio_raw_headers_enhanced"):
+                delattr(template, "_mdio_raw_headers_enhanced")
+            # The enhancement is applied by monkey-patching _add_variables
+            # We need to restore it to the original method from the class
+            # Since we can't easily restore the exact original, we'll get a fresh instance
+            template_class = type(template)
+            if hasattr(template_class, "_add_variables"):
+                template._add_variables = template_class._add_variables.__get__(template, template_class)
+        except KeyError:
+            # Template not found, skip
+            continue
+

 @pytest.mark.parametrize(
     "test_conf",

@@ -471,3 +506,69 @@ def test_export_masked(
         # https://github.com/TGSAI/mdio-python/issues/610
         assert_array_equal(actual_sgy.trace[:].header, expected_sgy.trace[expected_trc_idx].header)
         assert_array_equal(actual_sgy.trace[:].sample, expected_sgy.trace[expected_trc_idx].sample)
+
+    def test_raw_headers_byte_preservation(
+        self,
+        test_conf: MaskedExportConfig,
+        export_masked_path: Path,
+        raw_headers_env: None,  # noqa: ARG002
+    ) -> None:
+        """Test that raw headers are preserved byte-for-byte when MDIO__DO_RAW_HEADERS=1."""
+        grid_conf, segy_factory_conf, _, _ = test_conf
+        segy_path = export_masked_path / f"{grid_conf.name}.sgy"
+        mdio_path = export_masked_path / f"{grid_conf.name}.mdio"
+
+        # Open MDIO dataset
+        ds = open_mdio(mdio_path)
+
+        # Check if raw_headers should exist based on environment variable
+        has_raw_headers = "raw_headers" in ds.data_vars
+        if os.getenv("MDIO__DO_RAW_HEADERS") == "1":
+            assert has_raw_headers, "raw_headers should be present when MDIO__DO_RAW_HEADERS=1"
+        else:
+            assert not has_raw_headers, f"raw_headers should not be present when MDIO__DO_RAW_HEADERS is not set\n {ds}"
+            return  # Exit early if raw_headers are not expected
+
+        # Get data (only if raw_headers exist)
+        raw_headers_data = ds.raw_headers.values
+        trace_mask = ds.trace_mask.values
+
+        # Verify 240-byte headers
+        assert raw_headers_data.dtype.itemsize == 240, (
+            f"Expected 240-byte headers, got {raw_headers_data.dtype.itemsize}"
+        )
+
+        # Read raw bytes directly from SEG-Y file
+        def read_segy_trace_header(trace_index: int) -> bytes:
+            """Read 240-byte trace header directly from SEG-Y file."""
+            # with open(segy_path, "rb") as f:
+            with Path.open(segy_path, "rb") as f:
+                # Skip text (3200) + binary (400) headers = 3600 bytes
+                f.seek(3600)
+                # Each trace: 240 byte header + (num_samples * 4) byte samples
+                trace_size = 240 + (segy_factory_conf.num_samples * 4)
+                trace_offset = trace_index * trace_size
+                f.seek(trace_offset, 1)  # Seek relative to current position
+                return f.read(240)
+
+        # Compare all valid traces byte-by-byte
+        segy_trace_idx = 0
+        flat_mask = trace_mask.ravel()
+        flat_raw_headers = raw_headers_data.ravel()  # Flatten to 1D array of 240-byte header records
+
+        for grid_idx in range(flat_mask.size):
+            if not flat_mask[grid_idx]:
+                print(f"Skipping trace {grid_idx} because it is masked")
+                continue
+
+            # Get MDIO header as bytes - convert single header record to bytes
+            header_record = flat_raw_headers[grid_idx]
+            mdio_header_bytes = np.frombuffer(header_record.tobytes(), dtype=np.uint8)
+
+            # Get SEG-Y header as raw bytes directly from file
+            segy_raw_header_bytes = read_segy_trace_header(segy_trace_idx)
+            segy_header_bytes = np.frombuffer(segy_raw_header_bytes, dtype=np.uint8)
+
+            assert_array_equal(mdio_header_bytes, segy_header_bytes)
+
+            segy_trace_idx += 1
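
For orientation, the seek logic in the test follows the classic SEG-Y layout: a 3200-byte textual header plus a 400-byte binary header, then fixed-length traces of 240 header bytes plus 4 bytes per sample. A quick check of the arithmetic, with a hypothetical sample count:

    num_samples = 100                       # hypothetical trace length
    trace_size = 240 + 4 * num_samples      # 640 bytes per trace
    header_offset = 3600 + 2 * trace_size   # header of trace index 2 starts at byte 4880
    print(trace_size, header_offset)        # 640 4880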
