Skip to content

Commit 7c4563e

Browse files
committed
Ingestion for v1 seismic, dims, and headers appear to be working. Begin on coords.
1 parent 8f7e9ed commit 7c4563e

File tree

4 files changed

+134
-52
lines changed

4 files changed

+134
-52
lines changed

src/mdio/converters/segy.py

Lines changed: 114 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from typing import Any
99

1010
import numpy as np
11-
import zarr
1211
from numcodecs import Blosc
1312
from segy import SegyFile
1413
from segy.config import SegySettings
@@ -18,17 +17,12 @@
1817
from mdio.converters.exceptions import GridTraceCountError
1918
from mdio.converters.exceptions import GridTraceSparsityError
2019
from mdio.core import Grid
21-
from mdio.core.factory import MDIOCreateConfig
22-
from mdio.core.factory import MDIOVariableConfig
23-
from mdio.core.factory import create_empty
24-
from mdio.core.utils_write import write_attribute
20+
from mdio.core.utils_write import get_live_mask_chunksize as live_chunks
21+
from mdio.core.v1.builder import MDIODatasetBuilder as MDIOBuilder
2522
from mdio.segy import blocked_io
2623
from mdio.segy.compat import mdio_segy_spec
2724
from mdio.segy.utilities import get_grid_plan
2825

29-
from mdio.core.v1.builder import MDIODatasetBuilder as MDIOBuilder
30-
from mdio.core.utils_write import get_live_mask_chunksize as live_chunks
31-
3226
if TYPE_CHECKING:
3327
from collections.abc import Sequence
3428
from pathlib import Path
@@ -423,24 +417,19 @@ def segy_to_mdio( # noqa: PLR0913, PLR0915, PLR0912
423417
coordinates=["live_mask"],
424418
dimensions=[dim.name for dim in dimensions[:-1]],
425419
metadata={
426-
"chunkGrid": {
427-
"name": "regular",
428-
"configuration": {
429-
"chunkShape": live_chunks(lc)
430-
}
431-
}
432-
}
420+
"chunkGrid": {"name": "regular", "configuration": {"chunkShape": live_chunks(lc)}}
421+
},
433422
)
434423

435-
print(f"Chunksize: {chunksize}")
424+
# print(f"Chunksize: {chunksize}")
436425

437426
if chunksize is not None:
438427
metadata = {
439428
"chunkGrid": {
440429
"name": "regular",
441430
"configuration": {
442431
"chunkShape": list(chunksize),
443-
}
432+
},
444433
}
445434
}
446435
else:
@@ -456,8 +445,10 @@ def segy_to_mdio( # noqa: PLR0913, PLR0915, PLR0912
456445
ds = builder.to_mdio(store=mdio_path_or_buffer)
457446

458447
import json
459-
contract=json.loads(builder.build().json())
448+
449+
contract = json.loads(builder.build().json())
460450
from rich import print as rprint
451+
461452
oc = {
462453
"metadata": contract["metadata"],
463454
"variables": contract["variables"],
@@ -489,6 +480,7 @@ def segy_to_mdio( # noqa: PLR0913, PLR0915, PLR0912
489480
live_mask_array = ds.live_mask
490481
# Cast to MDIODataArray to access the to_mdio method
491482
from mdio.core.v1._overloads import MDIODataArray
483+
492484
live_mask_array.__class__ = MDIODataArray
493485

494486
# Build a ChunkIterator over the live_mask (no sample axis)
@@ -576,40 +568,113 @@ def segy_to_mdio( # noqa: PLR0913, PLR0915, PLR0912
576568

577569
# zarr.consolidate_metadata(root_group.store)
578570

571+
def validate_segy_schema(segy_schema: dict[str, Any]) -> None:
572+
"""Validate the SEG-Y schema.
573+
574+
Args:
575+
segy_schema: SEG-Y schema
576+
577+
Raises:
578+
ValueError: If schema is missing required fields or has invalid structure
579+
"""
580+
if "trace" not in segy_schema:
581+
raise ValueError("SEG-Y schema must contain 'trace' field")
582+
583+
if "header_entries" not in segy_schema["trace"]:
584+
raise ValueError("SEG-Y schema trace must contain 'header_entries' field")
585+
586+
if not isinstance(segy_schema["trace"]["header_entries"], list):
587+
raise ValueError("SEG-Y schema trace header_entries must be a list")
588+
589+
590+
def get_dims(ds: MDIO, segy_schema: dict[str, Any]) -> dict[str, Any]:
591+
"""Get the dimensions of the MDIO dataset from the SEG-Y schema.
592+
593+
Args:
594+
ds: MDIO dataset
595+
segy_schema: SEG-Y schema
596+
"""
597+
target_dims = ds.seismic.dims[:-1]
598+
599+
try:
600+
validate_segy_schema(segy_schema)
601+
except ValueError as e:
602+
raise ValueError(f"Unable to parse SEG-Y schema: {e}")
603+
604+
trace_headers = segy_schema["trace"]["header_entries"]
605+
ret = {}
606+
607+
for header in trace_headers:
608+
if header["name"] in target_dims:
609+
ret[header["name"]] = {
610+
"index_name": header["name"],
611+
"index_type": header["format"],
612+
"index_byte": header["byte_start"],
613+
}
614+
615+
if len(ret) != len(target_dims):
616+
raise ValueError(f"Not all dimensions were found in the SEG-Y schema. Missing: {target_dims - ret.keys()}")
617+
618+
return ret
619+
620+
def get_sample_name(ds: MDIO, grid_dims) -> str:
621+
"""Get the name of the sample dimension from the dataset.
622+
623+
Args:
624+
ds: MDIO dataset
625+
grid_dims: List of grid dimensions
626+
627+
Returns:
628+
str: Name of the sample dimension
629+
"""
630+
ds_dims = list(ds.seismic.dims)
631+
for dim in grid_dims:
632+
try:
633+
ds_dims.remove(dim.name)
634+
except ValueError:
635+
pass
636+
return ds_dims[0] # Should only be one left
637+
579638

580639
def segy_to_mdio_schematized(
581640
segy_schema: dict[str, Any],
582-
mdio_schema: dict[str, Any],
641+
mdio_schema: dict[str, Any],
583642
mdio_path_or_buffer: str | Path,
584643
storage_options_input: dict[str, Any] | None = None,
585644
storage_options_output: dict[str, Any] | None = None,
586645
) -> None:
587646
"""Create MDIO dataset from a schema specification using Pydantic v1 models.
588-
647+
589648
Args:
590649
segy_schema: Dictionary containing SEG-Y related schema (currently unused)
591650
mdio_schema: Dictionary containing the MDIO schema specification
592651
mdio_path_or_buffer: Output path for the MDIO file
593652
"""
594-
595653
grid_overrides = None # TODO: Implement this maybe?
596654

597-
from mdio.core.v1.factory import from_contract
598655
from mdio.core.v1._overloads import MDIO
656+
from mdio.core.v1.factory import from_contract
657+
599658
serialized_mdio = from_contract(mdio_path_or_buffer, mdio_schema)
600659

601660
ds = MDIO.open(mdio_path_or_buffer) # Reopen because we needed to do some weird stuff (hacky)
602661

603-
index_names = segy_schema["index_names"]
604-
index_types = segy_schema["index_types"]
605-
index_bytes = segy_schema["index_bytes"]
662+
try:
663+
dims = get_dims(ds, segy_schema)
664+
except ValueError as e:
665+
raise ValueError(f"Unable to parse SEG-Y schema into MDIO schema: {e}")
666+
667+
index_names = [dims[dim]["index_name"] for dim in dims]
668+
index_types = [dims[dim]["index_type"] for dim in dims]
669+
index_bytes = [dims[dim]["index_byte"] for dim in dims]
670+
606671

607672
chunksize = None
608673
live_mask_valid = False
609674
for variable in mdio_schema["variables"]:
610675
if variable["name"] == "seismic":
611676
chunksize = variable["metadata"]["chunkGrid"]["configuration"]["chunkShape"]
612-
elif variable["name"] == "live_mask":
677+
elif variable["name"] == "live_mask" or variable["name"] == "trace_mask":
613678
live_mask_valid = True
614679

615680
if chunksize is None:
@@ -620,16 +685,18 @@ def segy_to_mdio_schematized(
620685

621686
storage_options_input = storage_options_input or {}
622687
storage_options_output = storage_options_output or {}
623-
624-
mdio_spec = mdio_segy_spec() # TODO: I think this may need to be updated to work with our new input schemas
688+
689+
mdio_spec = (
690+
mdio_segy_spec()
691+
) # TODO: I think this may need to be updated to work with our new input schemas
625692
segy_settings = SegySettings(storage_options=storage_options_input)
626693
# segy = SegyFile(url=segy_path, spec=mdio_spec, settings=segy_settings)
627694
segy = SegyFile(url=segy_schema["path"], spec=mdio_spec, settings=segy_settings)
628695

629696
text_header = segy.text_header
630697
binary_header = segy.binary_header
631698
num_traces = segy.num_traces
632-
699+
633700
# Index the dataset using a spec that interprets the user provided index headers.
634701
index_fields = []
635702
for name, byte, format_ in zip(index_names, index_bytes, index_types, strict=True):
@@ -643,18 +710,21 @@ def segy_to_mdio_schematized(
643710
chunksize=chunksize,
644711
grid_overrides=grid_overrides,
645712
)
713+
dimensions[-1].name = get_sample_name(ds, dimensions)
714+
# print(dimensions)
646715
grid = Grid(dims=dimensions)
647716
grid_density_qc(grid, num_traces)
648717
grid.build_map(index_headers)
649718

719+
# Override the "sample" dimension name
720+
650721
# Set dimension coordinates
651722
new_coords = {dim.name: dim.coords for dim in dimensions}
652723
ds = ds.assign_coords(new_coords)
653724
ds.to_mdio(store=mdio_path_or_buffer, mode="r+")
654725

655726
# Set all coordinates which are not dimensions, root Variables, or live_mask
656727

657-
658728
# Check grid validity by ensuring every trace's header-index is within dimension bounds
659729
valid_mask = np.ones(grid.num_traces, dtype=bool)
660730
for d_idx in range(len(grid.header_index_arrays)):
@@ -674,15 +744,19 @@ def segy_to_mdio_schematized(
674744
del valid_mask
675745
gc.collect()
676746

677-
live_mask_array = ds.live_mask # TODO: Make this more robust
747+
coords = ds.seismic.coords # TODO: We also need to iterate over the coords and assign their values in parallel with the live_mask
748+
749+
# live_mask_array = ds.live_mask # TODO: Make this more robust
750+
live_mask_array = ds.trace_mask
678751
from mdio.core.v1._overloads import MDIODataArray
752+
679753
live_mask_array.__class__ = MDIODataArray
680754

681755
from mdio.core.indexing import ChunkIterator
682756

683757
chunker = ChunkIterator(live_mask_array, chunk_samples=True)
684758
for chunk_indices in chunker:
685-
print(f"chunk_indices: {chunk_indices}")
759+
# print(f"chunk_indices: {chunk_indices}")
686760
# chunk_indices is a tuple of N–1 slice objects
687761
trace_ids = grid.get_traces_for_chunk(chunk_indices)
688762
if trace_ids.size == 0:
@@ -720,16 +794,17 @@ def segy_to_mdio_schematized(
720794
del local_coords
721795

722796
# Write the entire block to Zarr at once
723-
# live_mask_array.loc[chunk_indices] = block
724-
live_mask_array.isel(isel_dict).values[:] = block
797+
# live_mask_array.isel(chunk_indices).values[:] = block
798+
live_mask_array[chunk_indices] = block
725799

726800
# Free block immediately after writing
727801
del block
728802

729803
# Force garbage collection periodically to free memory aggressively
730804
gc.collect()
731805

732-
live_mask_array.to_mdio(store=mdio_path_or_buffer, mode="r+")
806+
# Save the entire dataset to persist the live_mask changes
807+
ds.to_mdio(store=mdio_path_or_buffer, mode="r+")
733808

734809
# Final cleanup
735810
del live_mask_array
@@ -739,9 +814,13 @@ def segy_to_mdio_schematized(
739814
da = ds.seismic # TODO: Yolo the seismic Variable
740815
da.__class__ = MDIODataArray
741816

817+
header_array = ds.headers
818+
header_array.__class__ = MDIODataArray
819+
742820
stats = blocked_io.to_zarr(
743821
segy_file=segy,
744822
grid=grid,
745823
data_array=da,
824+
header_array=header_array,
746825
mdio_path_or_buffer=mdio_path_or_buffer,
747826
)

src/mdio/core/indexing.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
from math import ceil
55

66
import numpy as np
7-
from zarr import Array
8-
97
import xarray as xr
8+
from zarr import Array
109

1110

1211
class ChunkIterator:
@@ -39,10 +38,11 @@ class ChunkIterator:
3938
"""
4039

4140
def __init__(self, array: Array | xr.DataArray, chunk_samples: bool = True):
42-
4341
if isinstance(array, xr.DataArray):
4442
self.arr_shape = array.shape
45-
self.len_chunks = array.encoding.get("chunks", self.arr_shape) # TODO: Chunks don't appear to be present in the encoding. array.chunks is related to dask chunks.
43+
self.len_chunks = array.encoding.get(
44+
"chunks", self.arr_shape
45+
) # TODO: Chunks don't appear to be present in the encoding. array.chunks is related to dask chunks.
4646

4747
print(f"arr_shape: {self.arr_shape}")
4848
print(f"len_chunks: {self.len_chunks}")
@@ -95,7 +95,8 @@ def __next__(self) -> tuple[slice, ...]:
9595
)
9696

9797
stop_indices = tuple(
98-
(dim + 1) * chunk for dim, chunk in zip(current_start, self.len_chunks, strict=True)
98+
min((dim + 1) * chunk, array_size)
99+
for dim, chunk, array_size in zip(current_start, self.len_chunks, self.arr_shape, strict=True)
99100
)
100101

101102
slices = tuple(

src/mdio/segy/_workers.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
if TYPE_CHECKING:
1313
from segy import SegyFile
1414
from segy.arrays import HeaderArray
15-
from zarr import Array
1615

1716
from mdio.core import Grid
1817

@@ -55,6 +54,7 @@ def trace_worker(
5554
segy_file: SegyFile,
5655
# data_array: Array,
5756
data_array: mdio.DataArray,
57+
metadata_array: mdio.DataArray,
5858
# metadata_array: Array,
5959
grid: Grid,
6060
chunk_indices: tuple[slice, ...],
@@ -94,8 +94,8 @@ def trace_worker(
9494
print(f"Chunk shape from trace_worker: {chunk_shape}")
9595

9696
tmp_data = np.zeros(chunk_shape, dtype=data_array.dtype)
97-
# meta_shape = tuple(sli.stop - sli.start for sli in chunk_indices[:-1])
98-
# tmp_metadata = np.zeros(meta_shape, dtype=metadata_array.dtype)
97+
meta_shape = tuple(sli.stop - sli.start for sli in chunk_indices[:-1])
98+
tmp_metadata = np.zeros(meta_shape, dtype=metadata_array.dtype)
9999

100100
# Compute local coordinates within the chunk for each trace
101101
local_coords: list[np.ndarray] = []
@@ -114,10 +114,12 @@ def trace_worker(
114114

115115
# Populate the temporary buffers
116116
tmp_data[full_idx] = samples
117-
# tmp_metadata[tuple(local_coords)] = headers.view(tmp_metadata.dtype)
117+
tmp_metadata[tuple(local_coords)] = headers.view(tmp_metadata.dtype)
118118

119119
# Flush metadata to Zarr
120120
# metadata_array.set_basic_selection(selection=chunk_indices[:-1], value=tmp_metadata)
121+
metadata_array.data[chunk_indices[:-1]] = tmp_metadata
122+
metadata_array.to_mdio(store=mdio_path_or_buffer, mode="r+")
121123

122124
# Determine nonzero samples and early-exit if none
123125
nonzero_mask = samples != 0
@@ -128,10 +130,10 @@ def trace_worker(
128130
# Flush data to Zarr
129131
# data_array.set_basic_selection(selection=chunk_indices, value=tmp_data)
130132

131-
print(f"Writing data to the underlying array...")
132-
print(f"Chunk indices: {chunk_indices}")
133-
print(f"tmp_data shape: {tmp_data.shape}")
134-
print(f"data_array shape: {data_array.shape}")
133+
# print("Writing data to the underlying array...")
134+
# print(f"Chunk indices: {chunk_indices}")
135+
# print(f"tmp_data shape: {tmp_data.shape}")
136+
# print(f"data_array shape: {data_array.shape}")
135137

136138
# Direct assignment to underlying data array
137139
data_array.data[chunk_indices] = tmp_data

0 commit comments

Comments
 (0)