Skip to content

Commit 266a3ba

Browse files
committed
Begin work on tests
1 parent f433849 commit 266a3ba

File tree

1 file changed

+347
-0
lines changed

1 file changed

+347
-0
lines changed
Lines changed: 347 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,347 @@
1+
"""Test harness for disaster recovery wrapper.
2+
3+
This module tests the _disaster_recovery_wrapper.py functionality by creating
4+
test SEGY files with different configurations and validating that the raw headers
5+
from get_header_raw_and_transformed match the bytes on disk.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import tempfile
11+
from pathlib import Path
12+
from typing import TYPE_CHECKING
13+
14+
import numpy as np
15+
import pytest
16+
from segy import SegyFile
17+
from segy.factory import SegyFactory
18+
from segy.schema import Endianness
19+
from segy.schema import HeaderField
20+
from segy.schema import SegySpec
21+
from segy.standards import get_segy_standard
22+
23+
from mdio.segy._disaster_recovery_wrapper import get_header_raw_and_transformed
24+
25+
if TYPE_CHECKING:
26+
from numpy.typing import NDArray
27+
28+
29+
class TestDisasterRecoveryWrapper:
30+
"""Test cases for disaster recovery wrapper functionality."""
31+
32+
@pytest.fixture
33+
def temp_dir(self) -> Path:
34+
"""Create a temporary directory for test files."""
35+
with tempfile.TemporaryDirectory() as tmp_dir:
36+
yield Path(tmp_dir)
37+
38+
@pytest.fixture
39+
def basic_segy_spec(self) -> SegySpec:
40+
"""Create a basic SEGY specification for testing."""
41+
spec = get_segy_standard(1.0)
42+
43+
# Add basic header fields for inline/crossline
44+
header_fields = [
45+
HeaderField(name="inline", byte=189, format="int32"),
46+
HeaderField(name="crossline", byte=193, format="int32"),
47+
HeaderField(name="cdp_x", byte=181, format="int32"),
48+
HeaderField(name="cdp_y", byte=185, format="int32"),
49+
]
50+
51+
return spec.customize(trace_header_fields=header_fields)
52+
53+
@pytest.fixture(params=[
54+
{"endianness": Endianness.BIG, "data_format": 1, "name": "big_endian_ibm"},
55+
{"endianness": Endianness.BIG, "data_format": 5, "name": "big_endian_ieee"},
56+
{"endianness": Endianness.LITTLE, "data_format": 1, "name": "little_endian_ibm"},
57+
{"endianness": Endianness.LITTLE, "data_format": 5, "name": "little_endian_ieee"},
58+
])
59+
def segy_config(self, request) -> dict:
60+
"""Parameterized fixture for different SEGY configurations."""
61+
return request.param
62+
63+
def create_test_segy_file(
64+
self,
65+
spec: SegySpec,
66+
num_traces: int,
67+
samples_per_trace: int,
68+
output_path: Path,
69+
endianness: Endianness = Endianness.BIG,
70+
data_format: int = 1, # 1=IBM float, 5=IEEE float
71+
inline_range: tuple[int, int] = (1, 5),
72+
crossline_range: tuple[int, int] = (1, 5),
73+
) -> SegySpec:
74+
"""Create a test SEGY file with synthetic data."""
75+
# Update spec with desired endianness
76+
spec = spec.model_copy(update={"endianness": endianness})
77+
78+
factory = SegyFactory(spec=spec, samples_per_trace=samples_per_trace)
79+
80+
# Create synthetic header data
81+
headers = factory.create_trace_header_template(num_traces)
82+
samples = factory.create_trace_sample_template(num_traces)
83+
84+
# Set inline/crossline values
85+
inline_start, inline_end = inline_range
86+
crossline_start, crossline_end = crossline_range
87+
88+
# Create a simple grid
89+
inlines = np.arange(inline_start, inline_end + 1)
90+
crosslines = np.arange(crossline_start, crossline_end + 1)
91+
92+
trace_idx = 0
93+
for inline in inlines:
94+
for crossline in crosslines:
95+
if trace_idx >= num_traces:
96+
break
97+
98+
headers["inline"][trace_idx] = inline
99+
headers["crossline"][trace_idx] = crossline
100+
headers["cdp_x"][trace_idx] = inline * 100 # Simple coordinate calculation
101+
headers["cdp_y"][trace_idx] = crossline * 100
102+
103+
# Create simple synthetic trace data
104+
samples[trace_idx] = np.linspace(0, 1, samples_per_trace)
105+
106+
trace_idx += 1
107+
108+
# Write the SEGY file with custom binary header
109+
binary_header_updates = {"data_sample_format": data_format}
110+
with output_path.open("wb") as f:
111+
f.write(factory.create_textual_header())
112+
f.write(factory.create_binary_header(update=binary_header_updates))
113+
f.write(factory.create_traces(headers, samples))
114+
115+
return spec
116+
117+
def extract_header_bytes_from_file(
118+
self, segy_path: Path, trace_index: int, byte_start: int, byte_length: int
119+
) -> NDArray:
120+
"""Extract specific bytes from a trace header in the SEGY file."""
121+
with open(segy_path, "rb") as f:
122+
# Skip text header (3200 bytes) + binary header (400 bytes)
123+
header_offset = 3600
124+
125+
# Each trace: 240 byte header + samples
126+
trace_size = 240 + 1501 * 4 # Assuming 1501 samples, 4 bytes each
127+
trace_offset = header_offset + trace_index * trace_size
128+
129+
f.seek(trace_offset + byte_start - 1) # SEGY is 1-based
130+
header_bytes = f.read(byte_length)
131+
132+
return np.frombuffer(header_bytes, dtype=np.uint8)
133+
134+
def test_header_validation_configurations(
135+
self, temp_dir: Path, basic_segy_spec: SegySpec, segy_config: dict
136+
) -> None:
137+
"""Test header validation with different SEGY configurations."""
138+
config_name = segy_config["name"]
139+
endianness = segy_config["endianness"]
140+
data_format = segy_config["data_format"]
141+
142+
segy_path = temp_dir / f"test_{config_name}.segy"
143+
144+
# Create test SEGY file
145+
num_traces = 10
146+
samples_per_trace = 1501
147+
148+
spec = self.create_test_segy_file(
149+
spec=basic_segy_spec,
150+
num_traces=num_traces,
151+
samples_per_trace=samples_per_trace,
152+
output_path=segy_path,
153+
endianness=endianness,
154+
data_format=data_format,
155+
)
156+
157+
# Load the SEGY file
158+
segy_file = SegyFile(segy_path, spec=spec)
159+
160+
# Test a few traces
161+
test_indices = [0, 3, 7]
162+
163+
for trace_idx in test_indices:
164+
# Get raw and transformed headers
165+
raw_headers, transformed_headers, traces = get_header_raw_and_transformed(
166+
segy_file=segy_file,
167+
indices=trace_idx,
168+
do_reverse_transforms=True
169+
)
170+
171+
# Extract bytes from disk for inline (bytes 189-192) and crossline (bytes 193-196)
172+
inline_bytes_disk = self.extract_header_bytes_from_file(
173+
segy_path, trace_idx, 189, 4
174+
)
175+
crossline_bytes_disk = self.extract_header_bytes_from_file(
176+
segy_path, trace_idx, 193, 4
177+
)
178+
179+
# Convert raw headers to bytes for comparison
180+
if raw_headers is not None:
181+
# Extract inline and crossline from raw headers
182+
raw_inline_bytes = np.frombuffer(
183+
raw_headers["inline"].tobytes(), dtype=np.uint8
184+
)[:4]
185+
raw_crossline_bytes = np.frombuffer(
186+
raw_headers["crossline"].tobytes(), dtype=np.uint8
187+
)[:4]
188+
189+
# Compare bytes
190+
assert np.array_equal(raw_inline_bytes, inline_bytes_disk), \
191+
f"Inline bytes mismatch for trace {trace_idx} in {config_name}"
192+
assert np.array_equal(raw_crossline_bytes, crossline_bytes_disk), \
193+
f"Crossline bytes mismatch for trace {trace_idx} in {config_name}"
194+
195+
def test_header_validation_no_transforms(
196+
self, temp_dir: Path, basic_segy_spec: SegySpec, segy_config: dict
197+
) -> None:
198+
"""Test header validation when transforms are disabled."""
199+
config_name = segy_config["name"]
200+
endianness = segy_config["endianness"]
201+
data_format = segy_config["data_format"]
202+
203+
segy_path = temp_dir / f"test_no_transforms_{config_name}.segy"
204+
205+
# Create test SEGY file
206+
num_traces = 5
207+
samples_per_trace = 1501
208+
209+
spec = self.create_test_segy_file(
210+
spec=basic_segy_spec,
211+
num_traces=num_traces,
212+
samples_per_trace=samples_per_trace,
213+
output_path=segy_path,
214+
endianness=endianness,
215+
data_format=data_format,
216+
)
217+
218+
# Load the SEGY file
219+
segy_file = SegyFile(segy_path, spec=spec)
220+
221+
# Get headers without reverse transforms
222+
raw_headers, transformed_headers, traces = get_header_raw_and_transformed(
223+
segy_file=segy_file,
224+
indices=slice(None), # All traces
225+
do_reverse_transforms=False
226+
)
227+
228+
# When transforms are disabled, raw_headers should be None
229+
assert raw_headers is None
230+
231+
# Transformed headers should still be available
232+
assert transformed_headers is not None
233+
assert len(transformed_headers) == num_traces
234+
235+
def test_multiple_traces_validation(
236+
self, temp_dir: Path, basic_segy_spec: SegySpec, segy_config: dict
237+
) -> None:
238+
"""Test validation with multiple traces at once."""
239+
config_name = segy_config["name"]
240+
endianness = segy_config["endianness"]
241+
data_format = segy_config["data_format"]
242+
243+
segy_path = temp_dir / f"test_multiple_traces_{config_name}.segy"
244+
245+
# Create test SEGY file with more traces
246+
num_traces = 25 # 5x5 grid
247+
samples_per_trace = 1501
248+
249+
spec = self.create_test_segy_file(
250+
spec=basic_segy_spec,
251+
num_traces=num_traces,
252+
samples_per_trace=samples_per_trace,
253+
output_path=segy_path,
254+
endianness=endianness,
255+
data_format=data_format,
256+
inline_range=(1, 5),
257+
crossline_range=(1, 5),
258+
)
259+
260+
# Load the SEGY file
261+
segy_file = SegyFile(segy_path, spec=spec)
262+
263+
# Get all traces
264+
raw_headers, transformed_headers, traces = get_header_raw_and_transformed(
265+
segy_file=segy_file,
266+
indices=slice(None), # All traces
267+
do_reverse_transforms=True
268+
)
269+
270+
# Validate each trace
271+
for trace_idx in range(num_traces):
272+
# Extract bytes from disk
273+
inline_bytes_disk = self.extract_header_bytes_from_file(
274+
segy_path, trace_idx, 189, 4
275+
)
276+
crossline_bytes_disk = self.extract_header_bytes_from_file(
277+
segy_path, trace_idx, 193, 4
278+
)
279+
280+
# Extract from raw headers
281+
raw_inline_bytes = np.frombuffer(
282+
raw_headers["inline"][trace_idx].tobytes(), dtype=np.uint8
283+
)[:4]
284+
raw_crossline_bytes = np.frombuffer(
285+
raw_headers["crossline"][trace_idx].tobytes(), dtype=np.uint8
286+
)[:4]
287+
288+
# Compare
289+
assert np.array_equal(raw_inline_bytes, inline_bytes_disk), \
290+
f"Inline bytes mismatch for trace {trace_idx} in {config_name}"
291+
assert np.array_equal(raw_crossline_bytes, crossline_bytes_disk), \
292+
f"Crossline bytes mismatch for trace {trace_idx} in {config_name}"
293+
294+
@pytest.mark.parametrize("trace_indices", [
295+
0, # Single trace
296+
[0, 2, 4], # Multiple specific traces
297+
slice(1, 4), # Range of traces
298+
])
299+
def test_different_index_types(
300+
self, temp_dir: Path, basic_segy_spec: SegySpec, segy_config: dict, trace_indices
301+
) -> None:
302+
"""Test with different types of trace indices."""
303+
config_name = segy_config["name"]
304+
endianness = segy_config["endianness"]
305+
data_format = segy_config["data_format"]
306+
307+
segy_path = temp_dir / f"test_index_types_{config_name}.segy"
308+
309+
# Create test SEGY file
310+
num_traces = 10
311+
samples_per_trace = 1501
312+
313+
spec = self.create_test_segy_file(
314+
spec=basic_segy_spec,
315+
num_traces=num_traces,
316+
samples_per_trace=samples_per_trace,
317+
output_path=segy_path,
318+
endianness=endianness,
319+
data_format=data_format,
320+
)
321+
322+
# Load the SEGY file
323+
segy_file = SegyFile(segy_path, spec=spec)
324+
325+
# Get headers with different index types
326+
raw_headers, transformed_headers, traces = get_header_raw_and_transformed(
327+
segy_file=segy_file,
328+
indices=trace_indices,
329+
do_reverse_transforms=True
330+
)
331+
332+
# Basic validation that we got results
333+
assert raw_headers is not None
334+
assert transformed_headers is not None
335+
assert traces is not None
336+
337+
# Check that the number of results matches expectation
338+
if isinstance(trace_indices, int):
339+
expected_count = 1
340+
elif isinstance(trace_indices, list):
341+
expected_count = len(trace_indices)
342+
elif isinstance(trace_indices, slice):
343+
expected_count = len(range(*trace_indices.indices(num_traces)))
344+
else:
345+
expected_count = 1
346+
347+
assert len(transformed_headers) == expected_count

0 commit comments

Comments
 (0)