Skip to content

Commit d34b795

Browse files
committed
Begin work on tests
1 parent 8d6e517 commit d34b795

File tree

2 files changed

+190
-176
lines changed

2 files changed

+190
-176
lines changed

tests/integration/test_segy_import_export_masked.py

Lines changed: 95 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,11 +282,16 @@ def generate_selection_mask(selection_conf: SelectionMaskConfig, grid_conf: Grid
282282

283283

284284
@pytest.fixture
285-
def export_masked_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
285+
def export_masked_path(tmp_path_factory: pytest.TempPathFactory, raw_headers_env: None) -> Path: # noqa: ARG001
286286
"""Fixture that generates temp directory for export tests."""
287+
# Create path suffix based on current raw headers environment variable
288+
# raw_headers_env dependency ensures the environment variable is set before this runs
289+
raw_headers_enabled = os.getenv("MDIO__DO_RAW_HEADERS") == "1"
290+
path_suffix = "with_raw_headers" if raw_headers_enabled else "without_raw_headers"
291+
287292
if DEBUG_MODE:
288-
return Path("TMP/export_masked")
289-
return tmp_path_factory.getbasetemp() / "export_masked"
293+
return Path(f"TMP/export_masked_{path_suffix}")
294+
return tmp_path_factory.getbasetemp() / f"export_masked_{path_suffix}"
290295

291296

292297
@pytest.fixture
@@ -300,8 +305,33 @@ def raw_headers_env(request: pytest.FixtureRequest) -> None:
300305

301306
yield
302307

303-
# Cleanup after test
308+
# Cleanup after test - both environment variable and template state
304309
os.environ.pop("MDIO__DO_RAW_HEADERS", None)
310+
311+
# Clean up any template modifications to ensure test isolation
312+
from mdio.builder.template_registry import TemplateRegistry
313+
registry = TemplateRegistry.get_instance()
314+
315+
# Reset any templates that might have been modified with raw headers
316+
template_names = ["PostStack2DTime", "PostStack3DTime", "PreStackCdpOffsetGathers2DTime",
317+
"PreStackCdpOffsetGathers3DTime", "PreStackShotGathers2DTime",
318+
"PreStackShotGathers3DTime", "PreStackCocaGathers3DTime"]
319+
320+
for template_name in template_names:
321+
try:
322+
template = registry.get(template_name)
323+
# Remove raw headers enhancement if present
324+
if hasattr(template, "_mdio_raw_headers_enhanced"):
325+
delattr(template, "_mdio_raw_headers_enhanced")
326+
# The enhancement is applied by monkey-patching _add_variables
327+
# We need to restore it to the original method from the class
328+
# Since we can't easily restore the exact original, we'll get a fresh instance
329+
template_class = type(template)
330+
if hasattr(template_class, '_add_variables'):
331+
template._add_variables = template_class._add_variables.__get__(template, template_class)
332+
except KeyError:
333+
# Template not found, skip
334+
continue
305335

306336

307337
@pytest.mark.parametrize(
@@ -471,3 +501,64 @@ def test_export_masked(
471501
# https://github.com/TGSAI/mdio-python/issues/610
472502
assert_array_equal(actual_sgy.trace[:].header, expected_sgy.trace[expected_trc_idx].header)
473503
assert_array_equal(actual_sgy.trace[:].sample, expected_sgy.trace[expected_trc_idx].sample)
504+
505+
def test_raw_headers_byte_preservation(
506+
self,
507+
test_conf: MaskedExportConfig,
508+
export_masked_path: Path,
509+
raw_headers_env: None, # noqa: ARG002
510+
) -> None:
511+
"""Test that raw headers are preserved byte-for-byte when MDIO__DO_RAW_HEADERS=1."""
512+
grid_conf, segy_factory_conf, _, _ = test_conf
513+
segy_path = export_masked_path / f"{grid_conf.name}.sgy"
514+
mdio_path = export_masked_path / f"{grid_conf.name}.mdio"
515+
516+
# Open MDIO dataset
517+
ds = open_mdio(mdio_path)
518+
519+
# Check if raw_headers should exist based on environment variable
520+
has_raw_headers = "raw_headers" in ds.data_vars
521+
if os.getenv("MDIO__DO_RAW_HEADERS") == "1":
522+
assert has_raw_headers, "raw_headers should be present when MDIO__DO_RAW_HEADERS=1"
523+
else:
524+
assert not has_raw_headers, f"raw_headers should not be present when MDIO__DO_RAW_HEADERS is not set\n {ds}"
525+
return # Exit early if raw_headers are not expected
526+
527+
# Get data (only if raw_headers exist)
528+
raw_headers_data = ds.raw_headers.values
529+
trace_mask = ds.trace_mask.values
530+
531+
# Verify 240-byte headers
532+
assert raw_headers_data.dtype.itemsize == 240, f"Expected 240-byte headers, got {raw_headers_data.dtype.itemsize}"
533+
534+
# Read raw bytes directly from SEG-Y file
535+
def read_segy_trace_header(trace_index: int) -> bytes:
536+
"""Read 240-byte trace header directly from SEG-Y file."""
537+
with open(segy_path, 'rb') as f:
538+
# Skip text (3200) + binary (400) headers = 3600 bytes
539+
f.seek(3600)
540+
# Each trace: 240 byte header + (num_samples * 4) byte samples
541+
trace_size = 240 + (segy_factory_conf.num_samples * 4)
542+
trace_offset = trace_index * trace_size
543+
f.seek(trace_offset, 1) # Seek relative to current position
544+
return f.read(240)
545+
546+
# Compare all valid traces byte-by-byte
547+
segy_trace_idx = 0
548+
flat_mask = trace_mask.ravel()
549+
flat_raw_headers = raw_headers_data.reshape(-1, raw_headers_data.shape[-1])
550+
551+
for grid_idx in range(flat_mask.size):
552+
if not flat_mask[grid_idx]:
553+
continue
554+
555+
# Get MDIO header as bytes
556+
mdio_header_bytes = np.frombuffer(flat_raw_headers[grid_idx].tobytes(), dtype=np.uint8)
557+
558+
# Get SEG-Y header as raw bytes directly from file
559+
segy_raw_header_bytes = read_segy_trace_header(segy_trace_idx)
560+
segy_header_bytes = np.frombuffer(segy_raw_header_bytes, dtype=np.uint8)
561+
562+
# Compare byte-by-byte
563+
assert_array_equal(mdio_header_bytes, segy_header_bytes)
564+
segy_trace_idx += 1

0 commit comments

Comments
 (0)