@@ -282,11 +282,16 @@ def generate_selection_mask(selection_conf: SelectionMaskConfig, grid_conf: Grid
282282
283283
@pytest.fixture
def export_masked_path(tmp_path_factory: pytest.TempPathFactory, raw_headers_env: None) -> Path:  # noqa: ARG001
    """Fixture that generates temp directory for export tests.

    The directory name carries a suffix derived from the ``MDIO__DO_RAW_HEADERS``
    environment variable, so runs with and without raw headers write to
    different locations.

    Args:
        tmp_path_factory: Pytest factory used to locate the session base temp directory.
        raw_headers_env: Unused value; requested only so that the environment
            variable is set before this fixture reads it.

    Returns:
        ``TMP/export_masked_<suffix>`` when ``DEBUG_MODE`` is truthy, otherwise
        ``<basetemp>/export_masked_<suffix>``.
    """
    # Read the flag after raw_headers_env has run; only the exact value "1" enables it.
    raw_headers_enabled = os.getenv("MDIO__DO_RAW_HEADERS") == "1"
    path_suffix = "with_raw_headers" if raw_headers_enabled else "without_raw_headers"
    dir_name = f"export_masked_{path_suffix}"

    if DEBUG_MODE:
        return Path("TMP") / dir_name
    return tmp_path_factory.getbasetemp() / dir_name
290295
291296
292297@pytest .fixture
@@ -300,9 +305,39 @@ def raw_headers_env(request: pytest.FixtureRequest) -> None:
300305
301306 yield
302307
303- # Cleanup after test
308+ # Cleanup after test - both environment variable and template state
304309 os .environ .pop ("MDIO__DO_RAW_HEADERS" , None )
305310
311+ # Clean up any template modifications to ensure test isolation
312+ registry = TemplateRegistry .get_instance ()
313+
314+ # Reset any templates that might have been modified with raw headers
315+ template_names = [
316+ "PostStack2DTime" ,
317+ "PostStack3DTime" ,
318+ "PreStackCdpOffsetGathers2DTime" ,
319+ "PreStackCdpOffsetGathers3DTime" ,
320+ "PreStackShotGathers2DTime" ,
321+ "PreStackShotGathers3DTime" ,
322+ "PreStackCocaGathers3DTime" ,
323+ ]
324+
325+ for template_name in template_names :
326+ try :
327+ template = registry .get (template_name )
328+ # Remove raw headers enhancement if present
329+ if hasattr (template , "_mdio_raw_headers_enhanced" ):
330+ delattr (template , "_mdio_raw_headers_enhanced" )
331+ # The enhancement is applied by monkey-patching _add_variables on the instance.
332+ # The pre-patch bound method is not saved anywhere, so instead of restoring it
333+ # we rebind the class-level implementation onto this same instance.
334+ template_class = type (template )
335+ if hasattr (template_class , "_add_variables" ):
336+ template ._add_variables = template_class ._add_variables .__get__ (template , template_class )
337+ except KeyError :
338+ # Template not found, skip
339+ continue
340+
306341
307342@pytest .mark .parametrize (
308343 "test_conf" ,
@@ -471,3 +506,69 @@ def test_export_masked(
471506 # https://github.com/TGSAI/mdio-python/issues/610
472507 assert_array_equal (actual_sgy .trace [:].header , expected_sgy .trace [expected_trc_idx ].header )
473508 assert_array_equal (actual_sgy .trace [:].sample , expected_sgy .trace [expected_trc_idx ].sample )
509+
def test_raw_headers_byte_preservation(
    self,
    test_conf: MaskedExportConfig,
    export_masked_path: Path,
    raw_headers_env: None,  # noqa: ARG002
) -> None:
    """Test that raw headers are preserved byte-for-byte when MDIO__DO_RAW_HEADERS=1.

    When the environment variable is not "1", the test only asserts that the
    dataset has no ``raw_headers`` variable and returns. Otherwise every live
    trace (``trace_mask`` truthy) has its stored raw header compared against the
    240 bytes read directly from the SEG-Y file at the matching trace position.

    Args:
        test_conf: Configuration tuple; the first two items are the grid and
            SEG-Y factory configurations, the remaining two are unused here.
        export_masked_path: Directory containing ``<grid name>.sgy`` and ``<grid name>.mdio``.
        raw_headers_env: Unused value; requested so the environment variable is set.
    """
    text_and_binary_header_size = 3600  # text (3200) + binary (400) file headers
    trace_header_size = 240
    sample_size = 4  # NOTE(review): assumes 4-byte samples in the generated SEG-Y - confirm

    grid_conf, segy_factory_conf, _, _ = test_conf
    segy_path = export_masked_path / f"{grid_conf.name}.sgy"
    mdio_path = export_masked_path / f"{grid_conf.name}.mdio"

    ds = open_mdio(mdio_path)

    # Presence of raw_headers must follow the environment variable exactly.
    has_raw_headers = "raw_headers" in ds.data_vars
    if os.getenv("MDIO__DO_RAW_HEADERS") != "1":
        assert not has_raw_headers, f"raw_headers should not be present when MDIO__DO_RAW_HEADERS is not set\n{ds}"
        return  # Nothing to compare when raw headers are not expected
    assert has_raw_headers, "raw_headers should be present when MDIO__DO_RAW_HEADERS=1"

    raw_headers_data = ds.raw_headers.values
    trace_mask = ds.trace_mask.values

    assert raw_headers_data.dtype.itemsize == trace_header_size, (
        f"Expected 240-byte headers, got {raw_headers_data.dtype.itemsize}"
    )

    # Each SEG-Y trace is a 240-byte header followed by num_samples samples.
    trace_size = trace_header_size + segy_factory_conf.num_samples * sample_size

    # Flatten to 1D so grid positions line up between the mask and the header records.
    flat_raw_headers = raw_headers_data.ravel()
    live_grid_indices = np.flatnonzero(trace_mask.ravel())

    # The SEG-Y file holds only live traces, stored in the same flattened order,
    # so the n-th live grid position corresponds to the n-th trace in the file.
    # Open the file once instead of once per trace.
    with segy_path.open("rb") as segy_file:
        for segy_trace_idx, grid_idx in enumerate(live_grid_indices):
            segy_file.seek(text_and_binary_header_size + segy_trace_idx * trace_size)
            segy_raw_header_bytes = segy_file.read(trace_header_size)
            assert len(segy_raw_header_bytes) == trace_header_size, (
                f"Short read for SEG-Y trace {segy_trace_idx}: got {len(segy_raw_header_bytes)} bytes"
            )

            mdio_header_bytes = np.frombuffer(flat_raw_headers[grid_idx].tobytes(), dtype=np.uint8)
            segy_header_bytes = np.frombuffer(segy_raw_header_bytes, dtype=np.uint8)

            assert_array_equal(mdio_header_bytes, segy_header_bytes)
0 commit comments