
Commit 8c7a346

Begin breaking changes for v1 ingestion

1 parent 5e91042 · commit 8c7a346

6 files changed: +214 −113 lines

src/mdio/converters/segy.py

Lines changed: 97 additions & 56 deletions

@@ -26,6 +26,9 @@
 from mdio.segy.compat import mdio_segy_spec
 from mdio.segy.utilities import get_grid_plan

+from mdio.core.v1.builder import MDIODatasetBuilder as MDIOBuilder
+from mdio.core.utils_write import get_live_mask_chunksize as live_chunks
+
 if TYPE_CHECKING:
     from collections.abc import Sequence
     from pathlib import Path
@@ -384,6 +387,11 @@ def segy_to_mdio(  # noqa: PLR0913, PLR0915, PLR0912
     mdio_spec_grid = mdio_spec.customize(trace_header_fields=index_fields)
     segy_grid = SegyFile(url=segy_path, spec=mdio_spec_grid, settings=segy_settings)

+    # print(mdio_spec_grid)
+    # print(index_bytes)
+    # print(index_names)
+    # print(index_types)
+
     dimensions, chunksize, index_headers = get_grid_plan(
         segy_file=segy_grid,
         return_headers=True,
@@ -394,6 +402,71 @@ def segy_to_mdio(  # noqa: PLR0913, PLR0915, PLR0912
     grid_density_qc(grid, num_traces)
     grid.build_map(index_headers)

+    print(f"Dimensions: {dimensions}")
+    print(f"Chunksize: {chunksize}")
+    print(f"Index headers: {index_headers}")
+
+    builder = MDIOBuilder(name=mdio_path_or_buffer, attributes={"Description": "PH"})
+
+    for dim in dimensions:
+        builder.add_dimension(dim.name, dim.size, data_type=str(dim.coords.dtype))
+
+    # TODO: This name is bad
+    if chunksize is not None:
+        lc = list(chunksize)[:-1]
+    else:
+        lc = list(grid.shape[:-1])
+
+    builder.add_variable(
+        name="live_mask",
+        data_type="bool",
+        coordinates=["live_mask"],
+        dimensions=[dim.name for dim in dimensions[:-1]],
+        metadata={
+            "chunkGrid": {
+                "name": "regular",
+                "configuration": {
+                    "chunkShape": live_chunks(lc)
+                }
+            }
+        }
+    )
+
+    print(f"Chunksize: {chunksize}")
+
+    if chunksize is not None:
+        metadata = {
+            "chunkGrid": {
+                "name": "regular",
+                "configuration": {
+                    "chunkShape": list(chunksize),
+                }
+            }
+        }
+    else:
+        metadata = None
+
+    builder.add_variable(
+        name="seismic",
+        data_type="float32",
+        coordinates=["live_mask"],
+        metadata=metadata,
+    )
+
+    ds = builder.to_mdio(store=mdio_path_or_buffer)
+
+    import json
+    contract = json.loads(builder.build().json())
+    from rich import print as rprint
+    oc = {
+        "metadata": contract["metadata"],
+        "variables": contract["variables"],
+    }
+    rprint(oc)
+    new_coords = {dim.name: dim.coords for dim in dimensions}
+    ds = ds.assign_coords(new_coords)
+    ds.to_mdio(store=mdio_path_or_buffer, mode="r+")
+
     # Check grid validity by ensuring every trace's header-index is within dimension bounds
     valid_mask = np.ones(grid.num_traces, dtype=bool)
     for d_idx in range(len(grid.header_index_arrays)):
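
Note: the chunk-grid metadata attached to both `live_mask` and `seismic` above is a plain nested dict following the v1 `chunkGrid` convention. A minimal, self-contained sketch of that shape (the helper name `regular_chunk_grid` is ours for illustration, not part of MDIO):

```python
from collections.abc import Sequence


def regular_chunk_grid(chunk_shape: Sequence[int]) -> dict:
    """Build the v1 'chunkGrid' metadata dict used for both variables above."""
    return {
        "chunkGrid": {
            "name": "regular",
            "configuration": {"chunkShape": list(chunk_shape)},
        }
    }


# e.g. a 3-D seismic volume chunked 64x64x64 and its 2-D live mask
seismic_meta = regular_chunk_grid((64, 64, 64))
live_mask_meta = regular_chunk_grid((512, 512))
```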
@@ -413,57 +486,18 @@ def segy_to_mdio(  # noqa: PLR0913, PLR0915, PLR0912
     del valid_mask
     gc.collect()

-    if chunksize is None:
-        dim_count = len(index_names) + 1
-        if dim_count == 2:  # noqa: PLR2004
-            chunksize = (512,) * 2
-
-        elif dim_count == 3:  # noqa: PLR2004
-            chunksize = (64,) * 3
-
-        else:
-            msg = (
-                f"Default chunking for {dim_count}-D seismic data is not implemented yet. "
-                "Please explicity define chunk sizes."
-            )
-            raise NotImplementedError(msg)
-
-        suffix = [str(x) for x in range(dim_count)]
-        suffix = "".join(suffix)
-    else:
-        suffix = [dim_chunks if dim_chunks > 0 else None for dim_chunks in chunksize]
-        suffix = [str(idx) for idx, value in enumerate(suffix) if value is not None]
-        suffix = "".join(suffix)
-
-    compressors = get_compressor(lossless, compression_tolerance)
-    header_dtype = segy.spec.trace.header.dtype.newbyteorder("=")
-    var_conf = MDIOVariableConfig(
-        name=f"chunked_{suffix}",
-        dtype="float32",
-        chunks=chunksize,
-        compressors=compressors,
-        header_dtype=header_dtype,
-    )
-    config = MDIOCreateConfig(path=mdio_path_or_buffer, grid=grid, variables=[var_conf])
-
-    root_group = create_empty(
-        config,
-        overwrite=overwrite,
-        storage_options=storage_options_output,
-        consolidate_meta=False,
-    )
-    data_group = root_group["data"]
-    meta_group = root_group["metadata"]
-    data_array = data_group[f"chunked_{suffix}"]
-    header_array = meta_group[f"chunked_{suffix}_trace_headers"]
+    live_mask_array = ds.live_mask
+    # Cast to MDIODataArray to access the to_mdio method
+    from mdio.core.v1._overloads import MDIODataArray
+    live_mask_array.__class__ = MDIODataArray

-    live_mask_array = meta_group["live_mask"]
-    # 'live_mask_array' has the same first N–1 dims as 'grid.shape[:-1]'
     # Build a ChunkIterator over the live_mask (no sample axis)
     from mdio.core.indexing import ChunkIterator

-    chunker = ChunkIterator(live_mask_array, chunk_samples=True)
+    # chunker = ChunkIterator(live_mask_array, chunk_samples=True)
+    chunker = ChunkIterator(ds.live_mask, chunk_samples=True)
     for chunk_indices in chunker:
+        print(f"chunk_indices: {chunk_indices}")
         # chunk_indices is a tuple of N–1 slice objects
         trace_ids = grid.get_traces_for_chunk(chunk_indices)
         if trace_ids.size == 0:
@@ -502,35 +536,42 @@ def segy_to_mdio(  # noqa: PLR0913, PLR0915, PLR0912
         del local_coords

         # Write the entire block to Zarr at once
-        live_mask_array.set_basic_selection(selection=chunk_indices, value=block)
+        live_mask_array.loc[chunk_indices] = block

         # Free block immediately after writing
         del block

         # Force garbage collection periodically to free memory aggressively
         gc.collect()

+    live_mask_array.to_mdio(store=mdio_path_or_buffer, mode="r+")
+
     # Final cleanup
     del live_mask_array
     del chunker
     gc.collect()

-    nonzero_count = grid.num_traces
+    # nonzero_count = grid.num_traces
+
+    # write_attribute(name="trace_count", zarr_group=root_group, attribute=nonzero_count)
+    # write_attribute(name="text_header", zarr_group=meta_group, attribute=text_header.split("\n"))
+    # write_attribute(name="binary_header", zarr_group=meta_group, attribute=binary_header.to_dict())

-    write_attribute(name="trace_count", zarr_group=root_group, attribute=nonzero_count)
-    write_attribute(name="text_header", zarr_group=meta_group, attribute=text_header.split("\n"))
-    write_attribute(name="binary_header", zarr_group=meta_group, attribute=binary_header.to_dict())
+    da = ds.seismic
+    da.__class__ = MDIODataArray

     # Write traces
     stats = blocked_io.to_zarr(
         segy_file=segy,
         grid=grid,
-        data_array=data_array,
-        header_array=header_array,
+        # data_array=ds.seismic,
+        data_array=da,
+        # header_array=header_array,
+        mdio_path_or_buffer=mdio_path_or_buffer,
     )

     # Write actual stats
-    for key, value in stats.items():
-        write_attribute(name=key, zarr_group=root_group, attribute=value)
+    # for key, value in stats.items():
+    #     write_attribute(name=key, zarr_group=root_group, attribute=value)

-    zarr.consolidate_metadata(root_group.store)
+    # zarr.consolidate_metadata(root_group.store)
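
Note: the `__class__` assignments above graft `MDIODataArray`'s methods (notably `to_mdio`) onto an existing `xarray.DataArray` without copying the underlying data. A self-contained illustration of the same trick with a toy subclass (plain xarray; `DescribableArray` is hypothetical):

```python
import numpy as np
import xarray as xr


class DescribableArray(xr.DataArray):
    """Toy stand-in for MDIODataArray: a DataArray subclass adding one method."""

    __slots__ = ()  # xarray subclasses may not add instance attributes

    def describe(self) -> str:
        return f"dims={self.dims}, shape={self.shape}, dtype={self.dtype}"


da = xr.DataArray(np.zeros((2, 3), dtype="float32"), dims=("inline", "crossline"))
da.__class__ = DescribableArray  # in-place "cast": no data is copied
print(da.describe())
```

Reassigning `__class__` works here because the subclass adds no state of its own; a constructor-based wrap would re-validate or copy the array instead.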

src/mdio/core/indexing.py

Lines changed: 17 additions & 3 deletions

@@ -6,6 +6,8 @@
 import numpy as np
 from zarr import Array

+import xarray as xr
+

 class ChunkIterator:
     """Iterator for traversing a Zarr array in chunks.
@@ -36,9 +38,21 @@ class ChunkIterator:
     ...
     """

-    def __init__(self, array: Array, chunk_samples: bool = True):
-        self.arr_shape = array.shape
-        self.len_chunks = array.chunks
+    def __init__(self, array: Array | xr.DataArray, chunk_samples: bool = True):
+
+        if isinstance(array, xr.DataArray):
+            self.arr_shape = array.shape
+            self.len_chunks = array.encoding.get("chunks", self.arr_shape)  # TODO: Chunks don't appear to be present in the encoding. array.chunks is related to dask chunks.
+
+            print(f"arr_shape: {self.arr_shape}")
+            print(f"len_chunks: {self.len_chunks}")
+
+            print(f"array.encoding: {array.encoding}")
+            print(f"array.chunksizes: {array.chunksizes}")
+
+        else:
+            self.arr_shape = array.shape
+            self.len_chunks = array.chunks

         # If chunk_samples is False, set the last dimension's chunk size to its full extent
         if not chunk_samples:
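
Note: as the inline TODO observes, `encoding["chunks"]` is only populated on a DataArray that has been round-tripped through a chunked store, so an in-memory array falls back to a single chunk spanning the full shape. A runnable sketch of the traversal the iterator performs under that rule (`iter_chunk_slices` is an illustrative stand-in, not the MDIO implementation):

```python
import itertools
import math

import numpy as np
import xarray as xr


def iter_chunk_slices(array: xr.DataArray):
    """Yield tuples of slices covering the array chunk by chunk."""
    shape = array.shape
    # Fall back to one full-shape chunk when the store declared no chunks.
    chunks = array.encoding.get("chunks", shape)
    counts = [math.ceil(s / c) for s, c in zip(shape, chunks)]
    for idx in itertools.product(*(range(n) for n in counts)):
        yield tuple(
            slice(i * c, min((i + 1) * c, s))
            for i, c, s in zip(idx, chunks, shape)
        )


da = xr.DataArray(np.zeros((5, 4)), dims=("inline", "crossline"))
da.encoding["chunks"] = (3, 2)  # pretend the store declared 3x2 chunks
print(list(iter_chunk_slices(da)))
```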

src/mdio/core/v1/builder.py

Lines changed: 12 additions & 0 deletions

@@ -268,6 +268,14 @@ def _generate_encodings() -> dict:
     """
     # TODO(Anybody, #10274): Re-enable chunk_key_encoding when supported by xarray
     # dimension_separator_encoding = V2ChunkKeyEncoding(separator="/").to_dict()
+
+    # Collect dimension sizes (same approach as _construct_mdio_dataset)
+    dims: dict[str, int] = {}
+    for var in mdio_ds.variables:
+        for d in var.dimensions:
+            if isinstance(d, NamedDimension):
+                dims[d.name] = d.size
+
     global_encodings = {}
     for var in mdio_ds.variables:
         fill_value = 0
@@ -276,6 +284,10 @@ def _generate_encodings() -> dict:
         chunks = None
         if var.metadata is not None and var.metadata.chunk_grid is not None:
            chunks = var.metadata.chunk_grid.configuration.chunk_shape
+        else:
+            # When no chunk_grid is provided, set chunks to shape to avoid chunking
+            dim_names = [d.name if isinstance(d, NamedDimension) else d for d in var.dimensions]
+            chunks = tuple(dims[name] for name in dim_names)
         global_encodings[var.name] = {
             "chunks": chunks,
             # TODO(Anybody, #10274): Re-enable chunk_key_encoding when supported by xarray
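
Note: the new `else` branch means a variable without a declared `chunkGrid` is written as a single Zarr chunk, since its encoding gets `chunks` equal to the full shape. The resolution logic in isolation (a sketch with plain dicts standing in for the `NamedDimension` model):

```python
def resolve_chunks(
    dim_sizes: dict[str, int],
    dim_names: list[str],
    declared: tuple[int, ...] | None,
) -> tuple[int, ...]:
    """Use the declared chunk shape, else the full shape (one big chunk)."""
    if declared is not None:
        return tuple(declared)
    return tuple(dim_sizes[name] for name in dim_names)


dims = {"inline": 200, "crossline": 300, "sample": 1500}
print(resolve_chunks(dims, ["inline", "crossline"], None))  # (200, 300)
print(resolve_chunks(dims, ["inline", "crossline", "sample"], (64, 64, 64)))
```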

src/mdio/segy/_workers.py

Lines changed: 15 additions & 6 deletions

@@ -54,9 +54,10 @@ def header_scan_worker(segy_file: SegyFile, trace_range: tuple[int, int]) -> Hea
 def trace_worker(
     segy_file: SegyFile,
     data_array: Array,
-    metadata_array: Array,
+    # metadata_array: Array,
     grid: Grid,
     chunk_indices: tuple[slice, ...],
+    mdio_path_or_buffer: str,
 ) -> tuple[Any, ...] | None:
     """Worker function for multi-process enabled blocked SEG-Y I/O.

@@ -89,8 +90,8 @@ def trace_worker(
     # Build a temporary buffer for data and metadata for this chunk
     chunk_shape = tuple(sli.stop - sli.start for sli in chunk_indices[:-1]) + (grid.shape[-1],)
     tmp_data = np.zeros(chunk_shape, dtype=data_array.dtype)
-    meta_shape = tuple(sli.stop - sli.start for sli in chunk_indices[:-1])
-    tmp_metadata = np.zeros(meta_shape, dtype=metadata_array.dtype)
+    # meta_shape = tuple(sli.stop - sli.start for sli in chunk_indices[:-1])
+    # tmp_metadata = np.zeros(meta_shape, dtype=metadata_array.dtype)

     # Compute local coordinates within the chunk for each trace
     local_coords: list[np.ndarray] = []
@@ -109,10 +110,10 @@ def trace_worker(

     # Populate the temporary buffers
     tmp_data[full_idx] = samples
-    tmp_metadata[tuple(local_coords)] = headers.view(tmp_metadata.dtype)
+    # tmp_metadata[tuple(local_coords)] = headers.view(tmp_metadata.dtype)

     # Flush metadata to Zarr
-    metadata_array.set_basic_selection(selection=chunk_indices[:-1], value=tmp_metadata)
+    # metadata_array.set_basic_selection(selection=chunk_indices[:-1], value=tmp_metadata)

     # Determine nonzero samples and early-exit if none
     nonzero_mask = samples != 0
@@ -121,7 +122,15 @@ def trace_worker(
         return None

     # Flush data to Zarr
-    data_array.set_basic_selection(selection=chunk_indices, value=tmp_data)
+    # data_array.set_basic_selection(selection=chunk_indices, value=tmp_data)
+
+    print(f"Writing data to the underlying array...")
+    print(f"Chunk indices: {chunk_indices}")
+    print(f"tmp_data shape: {tmp_data.shape}")
+
+    data_array.loc[chunk_indices] = tmp_data
+    data_array.to_mdio(store=mdio_path_or_buffer, mode="r+")
+

     # Calculate statistics
     flattened_nonzero = samples[nonzero_mask]
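
Note: after flushing a chunk, the worker computes statistics over the nonzero samples (the hunk ends just as that begins). A sketch of per-chunk statistics that are cheap to merge across worker processes; the exact fields MDIO returns are not shown in this diff:

```python
import numpy as np


def chunk_stats(samples: np.ndarray) -> tuple[int, float, float, float, float] | None:
    """(count, sum, sum of squares, min, max) over nonzero samples."""
    nonzero_mask = samples != 0
    if not nonzero_mask.any():
        return None  # mirrors the worker's early exit for all-zero chunks
    flattened_nonzero = samples[nonzero_mask]
    return (
        int(flattened_nonzero.size),
        float(flattened_nonzero.sum()),
        float(np.square(flattened_nonzero, dtype="float64").sum()),
        float(flattened_nonzero.min()),
        float(flattened_nonzero.max()),
    )


print(chunk_stats(np.array([0.0, 1.0, 2.0, 0.0])))  # (2, 3.0, 5.0, 1.0, 2.0)
```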
