@@ -282,11 +282,16 @@ def generate_selection_mask(selection_conf: SelectionMaskConfig, grid_conf: Grid
282282
283283
@pytest.fixture
def export_masked_path(tmp_path_factory: pytest.TempPathFactory, raw_headers_env: None) -> Path:  # noqa: ARG001
    """Return the directory used by the masked export tests.

    The directory name records whether ``MDIO__DO_RAW_HEADERS`` is ``"1"`` at
    the time the fixture runs, so runs with and without raw headers use
    separate locations.

    ``raw_headers_env`` is requested only for ordering: pytest resolves it
    before this fixture, so the environment variable is read afterwards.
    """
    if os.getenv("MDIO__DO_RAW_HEADERS") == "1":
        dir_name = "export_masked_with_raw_headers"
    else:
        dir_name = "export_masked_without_raw_headers"

    # Debug runs use a fixed relative directory instead of pytest's base temp dir.
    parent_dir = Path("TMP") if DEBUG_MODE else tmp_path_factory.getbasetemp()
    return parent_dir / dir_name
290295
291296
292297@pytest .fixture
@@ -300,8 +305,33 @@ def raw_headers_env(request: pytest.FixtureRequest) -> None:
300305
301306 yield
302307
303- # Cleanup after test
308+ # Cleanup after test - both environment variable and template state
304309 os .environ .pop ("MDIO__DO_RAW_HEADERS" , None )
310+
311+ # Clean up any template modifications to ensure test isolation
312+ from mdio .builder .template_registry import TemplateRegistry
313+ registry = TemplateRegistry .get_instance ()
314+
315+ # Reset any templates that might have been modified with raw headers
316+ template_names = ["PostStack2DTime" , "PostStack3DTime" , "PreStackCdpOffsetGathers2DTime" ,
317+ "PreStackCdpOffsetGathers3DTime" , "PreStackShotGathers2DTime" ,
318+ "PreStackShotGathers3DTime" , "PreStackCocaGathers3DTime" ]
319+
320+ for template_name in template_names :
321+ try :
322+ template = registry .get (template_name )
323+ # Remove raw headers enhancement if present
324+ if hasattr (template , "_mdio_raw_headers_enhanced" ):
325+ delattr (template , "_mdio_raw_headers_enhanced" )
326+ # The enhancement is applied by monkey-patching _add_variables
327+ # We need to restore it to the original method from the class
328+ # Since we can't easily restore the exact original, we'll get a fresh instance
329+ template_class = type (template )
330+ if hasattr (template_class , '_add_variables' ):
331+ template ._add_variables = template_class ._add_variables .__get__ (template , template_class )
332+ except KeyError :
333+ # Template not found, skip
334+ continue
305335
306336
307337@pytest .mark .parametrize (
@@ -471,3 +501,64 @@ def test_export_masked(
471501 # https://github.com/TGSAI/mdio-python/issues/610
472502 assert_array_equal (actual_sgy .trace [:].header , expected_sgy .trace [expected_trc_idx ].header )
473503 assert_array_equal (actual_sgy .trace [:].sample , expected_sgy .trace [expected_trc_idx ].sample )
504+
def test_raw_headers_byte_preservation(
    self,
    test_conf: MaskedExportConfig,
    export_masked_path: Path,
    raw_headers_env: None,  # noqa: ARG002
) -> None:
    """Test that raw headers are preserved byte-for-byte when MDIO__DO_RAW_HEADERS=1.

    When the environment variable is not ``"1"`` the test only asserts that the
    dataset has no ``raw_headers`` variable. Otherwise every live trace's raw
    header stored in MDIO is compared against the 240 bytes read directly from
    the SEG-Y file at the matching trace position.

    Assumptions carried over from the original test (confirm against the
    SEG-Y factory): samples are 4 bytes wide, the file has no extended
    textual headers, and traces appear in the SEG-Y file in the same order as
    the live entries of the flattened trace mask.
    """
    grid_conf, segy_factory_conf, _, _ = test_conf
    segy_path = export_masked_path / f"{grid_conf.name}.sgy"
    mdio_path = export_masked_path / f"{grid_conf.name}.mdio"

    ds = open_mdio(mdio_path)

    # Presence of raw_headers must follow the environment variable.
    has_raw_headers = "raw_headers" in ds.data_vars
    if os.getenv("MDIO__DO_RAW_HEADERS") != "1":
        assert not has_raw_headers, f"raw_headers should not be present when MDIO__DO_RAW_HEADERS is not set\n{ds}"
        return  # Nothing further to compare without raw headers
    assert has_raw_headers, "raw_headers should be present when MDIO__DO_RAW_HEADERS=1"

    raw_headers_data = ds.raw_headers.values
    trace_mask = ds.trace_mask.values

    header_size = 240  # bytes in one SEG-Y trace header
    assert raw_headers_data.dtype.itemsize == header_size, (
        f"Expected 240-byte headers, got {raw_headers_data.dtype.itemsize}"
    )
    # A shared flat grid index is used for both arrays below, so they must
    # describe the same number of grid positions.
    assert raw_headers_data.size == trace_mask.size, (
        f"raw_headers has {raw_headers_data.size} entries but trace_mask has {trace_mask.size}"
    )

    # Each element is already one complete 240-byte record, so flatten to 1-D.
    # Reshaping to (-1, shape[-1]) would keep the last grid axis and make a flat
    # grid index select a whole row of headers instead of a single one.
    flat_raw_headers = raw_headers_data.ravel()
    live_grid_indices = np.flatnonzero(trace_mask)

    file_header_size = 3600  # textual (3200) + binary (400) file headers
    trace_size = header_size + segy_factory_conf.num_samples * 4  # header + 4-byte samples

    # Open the SEG-Y file once and seek to each trace header in turn.
    with segy_path.open("rb") as segy_file:
        for segy_trace_idx, grid_idx in enumerate(live_grid_indices):
            segy_file.seek(file_header_size + segy_trace_idx * trace_size)
            segy_header_bytes = np.frombuffer(segy_file.read(header_size), dtype=np.uint8)

            mdio_header_bytes = np.frombuffer(flat_raw_headers[grid_idx].tobytes(), dtype=np.uint8)

            # Compare byte-by-byte
            assert_array_equal(mdio_header_bytes, segy_header_bytes)
0 commit comments