88from typing import Any
99
1010import numpy as np
11- import zarr
1211from numcodecs import Blosc
1312from segy import SegyFile
1413from segy .config import SegySettings
1817from mdio .converters .exceptions import GridTraceCountError
1918from mdio .converters .exceptions import GridTraceSparsityError
2019from mdio .core import Grid
21- from mdio .core .factory import MDIOCreateConfig
22- from mdio .core .factory import MDIOVariableConfig
23- from mdio .core .factory import create_empty
24- from mdio .core .utils_write import write_attribute
20+ from mdio .core .utils_write import get_live_mask_chunksize as live_chunks
21+ from mdio .core .v1 .builder import MDIODatasetBuilder as MDIOBuilder
2522from mdio .segy import blocked_io
2623from mdio .segy .compat import mdio_segy_spec
2724from mdio .segy .utilities import get_grid_plan
2825
29- from mdio .core .v1 .builder import MDIODatasetBuilder as MDIOBuilder
30- from mdio .core .utils_write import get_live_mask_chunksize as live_chunks
31-
3226if TYPE_CHECKING :
3327 from collections .abc import Sequence
3428 from pathlib import Path
@@ -423,24 +417,19 @@ def segy_to_mdio( # noqa: PLR0913, PLR0915, PLR0912
423417 coordinates = ["live_mask" ],
424418 dimensions = [dim .name for dim in dimensions [:- 1 ]],
425419 metadata = {
426- "chunkGrid" : {
427- "name" : "regular" ,
428- "configuration" : {
429- "chunkShape" : live_chunks (lc )
430- }
431- }
432- }
420+ "chunkGrid" : {"name" : "regular" , "configuration" : {"chunkShape" : live_chunks (lc )}}
421+ },
433422 )
434423
435- print (f"Chunksize: { chunksize } " )
424+ # print(f"Chunksize: {chunksize}")
436425
437426 if chunksize is not None :
438427 metadata = {
439428 "chunkGrid" : {
440429 "name" : "regular" ,
441430 "configuration" : {
442431 "chunkShape" : list (chunksize ),
443- }
432+ },
444433 }
445434 }
446435 else :
@@ -456,8 +445,10 @@ def segy_to_mdio( # noqa: PLR0913, PLR0915, PLR0912
456445 ds = builder .to_mdio (store = mdio_path_or_buffer )
457446
458447 import json
459- contract = json .loads (builder .build ().json ())
448+
449+ contract = json .loads (builder .build ().json ())
460450 from rich import print as rprint
451+
461452 oc = {
462453 "metadata" : contract ["metadata" ],
463454 "variables" : contract ["variables" ],
@@ -489,6 +480,7 @@ def segy_to_mdio( # noqa: PLR0913, PLR0915, PLR0912
489480 live_mask_array = ds .live_mask
490481 # Cast to MDIODataArray to access the to_mdio method
491482 from mdio .core .v1 ._overloads import MDIODataArray
483+
492484 live_mask_array .__class__ = MDIODataArray
493485
494486 # Build a ChunkIterator over the live_mask (no sample axis)
@@ -576,40 +568,113 @@ def segy_to_mdio( # noqa: PLR0913, PLR0915, PLR0912
576568
577569 # zarr.consolidate_metadata(root_group.store)
578570
def validate_segy_schema(segy_schema: dict[str, Any]) -> None:
    """Check that a SEG-Y schema mapping has the expected structure.

    Args:
        segy_schema: SEG-Y schema dictionary to validate.

    Raises:
        ValueError: If the schema lacks a 'trace' section, the trace
            section lacks 'header_entries', or 'header_entries' is not
            a list.
    """
    if "trace" not in segy_schema:
        raise ValueError("SEG-Y schema must contain 'trace' field")

    trace_section = segy_schema["trace"]
    if "header_entries" not in trace_section:
        raise ValueError("SEG-Y schema trace must contain 'header_entries' field")

    if not isinstance(trace_section["header_entries"], list):
        raise ValueError("SEG-Y schema trace header_entries must be a list")
589+
def get_dims(ds: "MDIO", segy_schema: dict[str, Any]) -> dict[str, Any]:
    """Map the dataset's grid dimensions to SEG-Y trace-header index specs.

    Args:
        ds: MDIO dataset; ``ds.seismic.dims[:-1]`` is assumed to be the
            grid dimensions (the trailing dim is the sample axis —
            TODO confirm this invariant holds for all builders).
        segy_schema: SEG-Y schema containing ``trace.header_entries``.

    Returns:
        Mapping from dimension name to a dict with ``index_name``,
        ``index_type``, and ``index_byte`` entries.

    Raises:
        ValueError: If the schema fails validation or any grid dimension
            has no matching trace-header entry.
    """
    target_dims = ds.seismic.dims[:-1]

    try:
        validate_segy_schema(segy_schema)
    except ValueError as e:
        # Chain the cause so the underlying validation failure is preserved.
        raise ValueError(f"Unable to parse SEG-Y schema: {e}") from e

    trace_headers = segy_schema["trace"]["header_entries"]
    ret = {}

    for header in trace_headers:
        if header["name"] in target_dims:
            ret[header["name"]] = {
                "index_name": header["name"],
                "index_type": header["format"],
                "index_byte": header["byte_start"],
            }

    if len(ret) != len(target_dims):
        # `dims` is typically a tuple, which does not support the `-`
        # operator; convert to a set before differencing against the
        # dict-view of found keys (the original expression raised TypeError).
        missing = set(target_dims) - ret.keys()
        raise ValueError(f"Not all dimensions were found in the SEG-Y schema. Missing: {missing}")

    return ret
619+
def get_sample_name(ds: "MDIO", grid_dims) -> str:
    """Get the name of the sample dimension from the dataset.

    Removes every grid-dimension name from the dataset's dimension list;
    the first remaining name is taken to be the sample axis.

    Args:
        ds: MDIO dataset.
        grid_dims: Iterable of grid dimension objects, each exposing a
            ``.name`` attribute.

    Returns:
        str: Name of the sample dimension.

    Raises:
        ValueError: If every dataset dimension is consumed by grid_dims,
            so no candidate sample dimension remains (previously this
            surfaced as an opaque IndexError).
    """
    remaining = list(ds.seismic.dims)
    for grid_dim in grid_dims:
        # Grid dims absent from the dataset are silently skipped, matching
        # the tolerant behavior callers rely on.
        if grid_dim.name in remaining:
            remaining.remove(grid_dim.name)
    if not remaining:
        raise ValueError("No sample dimension left after removing grid dimensions")
    return remaining[0]  # Exactly one dimension should be left.
637+
579638
580639def segy_to_mdio_schematized (
581640 segy_schema : dict [str , Any ],
582- mdio_schema : dict [str , Any ],
641+ mdio_schema : dict [str , Any ],
583642 mdio_path_or_buffer : str | Path ,
584643 storage_options_input : dict [str , Any ] | None = None ,
585644 storage_options_output : dict [str , Any ] | None = None ,
586645) -> None :
587646 """Create MDIO dataset from a schema specification using Pydantic v1 models.
588-
647+
589648 Args:
590649 segy_schema: Dictionary containing SEG-Y related schema (currently unused)
591650 mdio_schema: Dictionary containing the MDIO schema specification
592651 mdio_path_or_buffer: Output path for the MDIO file
593652 """
594-
595653 grid_overrides = None # TODO: Implement this maybe?
596654
597- from mdio .core .v1 .factory import from_contract
598655 from mdio .core .v1 ._overloads import MDIO
656+ from mdio .core .v1 .factory import from_contract
657+
599658 serialized_mdio = from_contract (mdio_path_or_buffer , mdio_schema )
600659
601660 ds = MDIO .open (mdio_path_or_buffer ) # Reopen because we needed to do some weird stuff (hacky)
602661
603- index_names = segy_schema ["index_names" ]
604- index_types = segy_schema ["index_types" ]
605- index_bytes = segy_schema ["index_bytes" ]
662+ try :
663+ dims = get_dims (ds , segy_schema )
664+ except ValueError as e :
665+ raise ValueError (f"Unable to parse SEG-Y schema into MDIO schema: { e } " )
666+
667+ index_names = [dims [dim ]["index_name" ] for dim in dims ]
668+ index_types = [dims [dim ]["index_type" ] for dim in dims ]
669+ index_bytes = [dims [dim ]["index_byte" ] for dim in dims ]
670+
606671
607672 chunksize = None
608673 live_mask_valid = False
609674 for variable in mdio_schema ["variables" ]:
610675 if variable ["name" ] == "seismic" :
611676 chunksize = variable ["metadata" ]["chunkGrid" ]["configuration" ]["chunkShape" ]
612- elif variable ["name" ] == "live_mask" :
677+ elif variable ["name" ] == "live_mask" or variable [ "name" ] == "trace_mask" :
613678 live_mask_valid = True
614679
615680 if chunksize is None :
@@ -620,16 +685,18 @@ def segy_to_mdio_schematized(
620685
621686 storage_options_input = storage_options_input or {}
622687 storage_options_output = storage_options_output or {}
623-
624- mdio_spec = mdio_segy_spec () # TODO: I think this may need to be updated to work with our new input schemas
688+
689+ mdio_spec = (
690+ mdio_segy_spec ()
691+ ) # TODO: I think this may need to be updated to work with our new input schemas
625692 segy_settings = SegySettings (storage_options = storage_options_input )
626693 # segy = SegyFile(url=segy_path, spec=mdio_spec, settings=segy_settings)
627694 segy = SegyFile (url = segy_schema ["path" ], spec = mdio_spec , settings = segy_settings )
628695
629696 text_header = segy .text_header
630697 binary_header = segy .binary_header
631698 num_traces = segy .num_traces
632-
699+
633700 # Index the dataset using a spec that interprets the user provided index headers.
634701 index_fields = []
635702 for name , byte , format_ in zip (index_names , index_bytes , index_types , strict = True ):
@@ -643,18 +710,21 @@ def segy_to_mdio_schematized(
643710 chunksize = chunksize ,
644711 grid_overrides = grid_overrides ,
645712 )
713+ dimensions [- 1 ].name = get_sample_name (ds , dimensions )
714+ # print(dimensions)
646715 grid = Grid (dims = dimensions )
647716 grid_density_qc (grid , num_traces )
648717 grid .build_map (index_headers )
649718
719+ # Override the "sample" dimension name
720+
650721 # Set dimension coordinates
651722 new_coords = {dim .name : dim .coords for dim in dimensions }
652723 ds = ds .assign_coords (new_coords )
653724 ds .to_mdio (store = mdio_path_or_buffer , mode = "r+" )
654725
655726 # Set all coordinates which are not dimensions, root Variables, or live_mask
656727
657-
658728 # Check grid validity by ensuring every trace's header-index is within dimension bounds
659729 valid_mask = np .ones (grid .num_traces , dtype = bool )
660730 for d_idx in range (len (grid .header_index_arrays )):
@@ -674,15 +744,19 @@ def segy_to_mdio_schematized(
674744 del valid_mask
675745 gc .collect ()
676746
677- live_mask_array = ds .live_mask # TODO: Make this more robust
747+ coords = ds .seismic .coords # TODO: We also need to iterate over the coords and assign their values in parallel with the live_mask
748+
749+ # live_mask_array = ds.live_mask # TODO: Make this more robust
750+ live_mask_array = ds .trace_mask
678751 from mdio .core .v1 ._overloads import MDIODataArray
752+
679753 live_mask_array .__class__ = MDIODataArray
680754
681755 from mdio .core .indexing import ChunkIterator
682756
683757 chunker = ChunkIterator (live_mask_array , chunk_samples = True )
684758 for chunk_indices in chunker :
685- print (f"chunk_indices: { chunk_indices } " )
759+ # print(f"chunk_indices: {chunk_indices}")
686760 # chunk_indices is a tuple of N–1 slice objects
687761 trace_ids = grid .get_traces_for_chunk (chunk_indices )
688762 if trace_ids .size == 0 :
@@ -720,16 +794,17 @@ def segy_to_mdio_schematized(
720794 del local_coords
721795
722796 # Write the entire block to Zarr at once
723- # live_mask_array.loc[ chunk_indices] = block
724- live_mask_array . isel ( isel_dict ). values [: ] = block
797+ # live_mask_array.isel( chunk_indices).values[: ] = block
798+ live_mask_array [ chunk_indices ] = block
725799
726800 # Free block immediately after writing
727801 del block
728802
729803 # Force garbage collection periodically to free memory aggressively
730804 gc .collect ()
731805
732- live_mask_array .to_mdio (store = mdio_path_or_buffer , mode = "r+" )
806+ # Save the entire dataset to persist the live_mask changes
807+ ds .to_mdio (store = mdio_path_or_buffer , mode = "r+" )
733808
734809 # Final cleanup
735810 del live_mask_array
@@ -739,9 +814,13 @@ def segy_to_mdio_schematized(
739814 da = ds .seismic # TODO: Yolo the seismic Variable
740815 da .__class__ = MDIODataArray
741816
817+ header_array = ds .headers
818+ header_array .__class__ = MDIODataArray
819+
742820 stats = blocked_io .to_zarr (
743821 segy_file = segy ,
744822 grid = grid ,
745823 data_array = da ,
824+ header_array = header_array ,
746825 mdio_path_or_buffer = mdio_path_or_buffer ,
747826 )
0 commit comments