Skip to content

Commit 90e7cdb

Browse files
committed
rename trace wrapper and do lazy decoding
1 parent 116f625 commit 90e7cdb

File tree

5 files changed

+63
-41
lines changed

5 files changed

+63
-41
lines changed

src/mdio/converters/segy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ def _add_raw_headers_to_template(mdio_template: AbstractDatasetTemplate) -> Abst
353353
- zstd compressor
354354
- No additional metadata
355355
- Chunked the same as the Headers variable
356+
356357
Args:
357358
mdio_template: The MDIO template to mutate
358359
Returns:

src/mdio/segy/_disaster_recovery_wrapper.py

Lines changed: 0 additions & 33 deletions
This file was deleted.
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""Consumer-side utility to get both raw and transformed header data with single filesystem read."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
import numpy as np
8+
9+
if TYPE_CHECKING:
10+
from numpy.typing import NDArray
11+
from segy import SegyFile
12+
13+
14+
class SegyFileRawTraceWrapper:
15+
def __init__(self, segy_file: SegyFile, indices: int | list[int] | NDArray | slice):
16+
self.segy_file = segy_file
17+
self.indices = indices
18+
19+
self.idx = self.segy_file.trace.normalize_and_validate_query(self.indices)
20+
self.trace_buffer_array = self.segy_file.trace.fetch(self.idx, raw=True)
21+
22+
self.trace_view = self.trace_buffer_array.view(self.segy_file.spec.trace.dtype)
23+
24+
self.trace_decode_pipeline = self.segy_file.accessors.trace_decode_pipeline
25+
self.decoded_traces = None # decode later when not-raw header/sample is called
26+
27+
def _ensure_decoded(self) -> None:
28+
"""Apply trace decoding pipeline if not already done."""
29+
if self.decoded_traces is not None: # already done
30+
return
31+
self.decoded_traces = self.trace_decode_pipeline.apply(self.trace_view.copy())
32+
33+
@property
34+
def raw_header(self) -> NDArray:
35+
"""Get byte array view of the raw headers."""
36+
header_itemsize = self.segy_file.spec.trace.header.itemsize # should be 240
37+
return self.trace_view.header.view(np.dtype((np.void, header_itemsize)))
38+
39+
@property
40+
def header(self) -> NDArray:
41+
"""Get decoded header."""
42+
self._ensure_decoded() # decode when needed in-place to avoid copy.
43+
return self.decoded_traces.header
44+
45+
@property
46+
def sample(self) -> NDArray:
47+
"""Get decoded trace samples."""
48+
self._ensure_decoded() # decode when needed in-place to avoid copy.
49+
return self.decoded_traces.sample

src/mdio/segy/_workers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from mdio.api.io import to_mdio
1414
from mdio.builder.schemas.dtype import ScalarType
15-
from mdio.segy._disaster_recovery_wrapper import SegyFileTraceDataWrapper
15+
from mdio.segy._raw_trace_wrapper import SegyFileRawTraceWrapper
1616

1717
if TYPE_CHECKING:
1818
from segy.arrays import HeaderArray
@@ -139,7 +139,7 @@ def trace_worker( # noqa: PLR0913
139139
# For that reason, we have wrapped the accessors to provide an interface that can be removed
140140
# and not require additional changes to the below code.
141141
# NOTE: The `raw_header_key` code block should be removed in full as it will become dead code.
142-
traces = SegyFileTraceDataWrapper(segy_file, live_trace_indexes)
142+
traces = SegyFileRawTraceWrapper(segy_file, live_trace_indexes)
143143

144144
ds_to_write = dataset[worker_variables]
145145

tests/unit/test_disaster_recovery_wrapper.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
import tempfile
1111
from pathlib import Path
12+
from typing import TYPE_CHECKING
13+
from typing import Any
1214

1315
import numpy as np
1416
import pytest
@@ -19,7 +21,10 @@
1921
from segy.schema import SegySpec
2022
from segy.standards import get_segy_standard
2123

22-
from mdio.segy._disaster_recovery_wrapper import SegyFileTraceDataWrapper
24+
from mdio.segy._raw_trace_wrapper import SegyFileRawTraceWrapper
25+
26+
if TYPE_CHECKING:
27+
from collections.abc import Generator
2328

2429
SAMPLES_PER_TRACE = 1501
2530

@@ -28,7 +33,7 @@ class TestDisasterRecoveryWrapper:
2833
"""Test cases for disaster recovery wrapper functionality."""
2934

3035
@pytest.fixture
31-
def temp_dir(self) -> Path:
36+
def temp_dir(self) -> Generator[Path, Any, None]:
3237
"""Create a temporary directory for test files."""
3338
with tempfile.TemporaryDirectory() as tmp_dir:
3439
yield Path(tmp_dir)
@@ -140,7 +145,7 @@ def test_wrapper_basic_functionality(self, temp_dir: Path, basic_segy_spec: Segy
140145

141146
# Test single trace
142147
trace_idx = 3
143-
wrapper = SegyFileTraceDataWrapper(segy_file, trace_idx)
148+
wrapper = SegyFileRawTraceWrapper(segy_file, trace_idx)
144149

145150
# Test that properties are accessible
146151
assert wrapper.header is not None
@@ -184,7 +189,7 @@ def test_wrapper_with_multiple_traces(self, temp_dir: Path, basic_segy_spec: Seg
184189

185190
# Test with list of indices
186191
trace_indices = [0, 2, 4]
187-
wrapper = SegyFileTraceDataWrapper(segy_file, trace_indices)
192+
wrapper = SegyFileRawTraceWrapper(segy_file, trace_indices)
188193

189194
# Test that properties work with multiple traces
190195
assert wrapper.header is not None
@@ -222,7 +227,7 @@ def test_wrapper_with_slice_indices(self, temp_dir: Path, basic_segy_spec: SegyS
222227
segy_file = SegyFile(segy_path, spec=spec)
223228

224229
# Test with slice
225-
wrapper = SegyFileTraceDataWrapper(segy_file, slice(5, 15))
230+
wrapper = SegyFileRawTraceWrapper(segy_file, slice(5, 15))
226231

227232
# Test that properties work with slice
228233
assert wrapper.header is not None
@@ -269,7 +274,7 @@ def test_different_index_types(
269274
segy_file = SegyFile(segy_path, spec=spec)
270275

271276
# Create wrapper with different index types
272-
wrapper = SegyFileTraceDataWrapper(segy_file, trace_indices)
277+
wrapper = SegyFileRawTraceWrapper(segy_file, trace_indices)
273278

274279
# Basic validation that we got results
275280
assert wrapper.header is not None

0 commit comments

Comments
 (0)