Skip to content

Commit d27879f

Browse files
committed
Begin testing for lazy compute of grid
1 parent 43a618d commit d27879f

File tree

4 files changed

+251
-112
lines changed

4 files changed

+251
-112
lines changed

src/mdio/converters/segy.py

Lines changed: 65 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ def segy_to_mdio( # noqa: PLR0913, PLR0915
354354
... grid_overrides={"HasDuplicates": True},
355355
... )
356356
"""
357+
print("Entering segy_to_mdio")
357358
index_names = index_names or [f"dim_{i}" for i in range(len(index_bytes))]
358359
index_types = index_types or ["int32"] * len(index_bytes)
359360

@@ -368,6 +369,7 @@ def segy_to_mdio( # noqa: PLR0913, PLR0915
368369
storage_options_input = storage_options_input or {}
369370
storage_options_output = storage_options_output or {}
370371

372+
print("pre-setup")
371373
# Open SEG-Y with MDIO's SegySpec. Endianness will be inferred.
372374
mdio_spec = mdio_segy_spec()
373375
segy_settings = SegySettings(storage_options=storage_options_input)
@@ -377,31 +379,43 @@ def segy_to_mdio( # noqa: PLR0913, PLR0915
377379
binary_header = segy.binary_header
378380
num_traces = segy.num_traces
379381

382+
print("pre-index")
380383
# Index the dataset using a spec that interprets the user provided index headers.
381-
index_fields = []
384+
index_fields: list[HeaderField] = []
382385
for name, byte, format_ in zip(index_names, index_bytes, index_types, strict=True):
383386
index_fields.append(HeaderField(name=name, byte=byte, format=format_))
384387
mdio_spec_grid = mdio_spec.customize(trace_header_fields=index_fields)
385388
segy_grid = SegyFile(url=segy_path, spec=mdio_spec_grid, settings=segy_settings)
386389

390+
print("pre-get_grid_plan")
387391
dimensions, chunksize, index_headers = get_grid_plan(
388392
segy_file=segy_grid,
389393
return_headers=True,
390394
chunksize=chunksize,
391395
grid_overrides=grid_overrides,
392396
)
393397
grid = Grid(dims=dimensions)
398+
print("pre-grid_density_qc")
394399
grid_density_qc(grid, num_traces)
400+
print("pre-build_map")
395401
grid.build_map(index_headers)
396402

397-
# Check grid validity by comparing trace numbers
398-
if np.sum(grid.live_mask) != num_traces:
403+
print("pre-valid_mask")
404+
# Check grid validity by ensuring every trace's header-index is within dimension bounds
405+
valid_mask = np.ones(grid.num_traces, dtype=bool)
406+
for d_idx in range(len(grid.header_index_arrays)):
407+
coords = grid.header_index_arrays[d_idx]
408+
valid_mask &= (coords < grid.shape[d_idx])
409+
valid_count = int(np.count_nonzero(valid_mask))
410+
if valid_count != num_traces:
399411
for dim_name in grid.dim_names:
400-
dim_min, dim_max = grid.get_min(dim_name), grid.get_max(dim_name)
412+
dim_min = grid.get_min(dim_name)
413+
dim_max = grid.get_max(dim_name)
401414
logger.warning("%s min: %s max: %s", dim_name, dim_min, dim_max)
402415
logger.warning("Ingestion grid shape: %s.", grid.shape)
403-
raise GridTraceCountError(np.sum(grid.live_mask), num_traces)
416+
raise GridTraceCountError(valid_count, num_traces)
404417

418+
print("pre-chunksize")
405419
if chunksize is None:
406420
dim_count = len(index_names) + 1
407421
if dim_count == 2: # noqa: PLR2004
@@ -424,6 +438,7 @@ def segy_to_mdio( # noqa: PLR0913, PLR0915
424438
suffix = [str(idx) for idx, value in enumerate(suffix) if value is not None]
425439
suffix = "".join(suffix)
426440

441+
print("pre-compressors")
427442
compressors = get_compressor(lossless, compression_tolerance)
428443
header_dtype = segy.spec.trace.header.dtype.newbyteorder("=")
429444
var_conf = MDIOVariableConfig(
@@ -435,6 +450,7 @@ def segy_to_mdio( # noqa: PLR0913, PLR0915
435450
)
436451
config = MDIOCreateConfig(path=mdio_path_or_buffer, grid=grid, variables=[var_conf])
437452

453+
print("pre-create_empty")
438454
root_group = create_empty(
439455
config,
440456
overwrite=overwrite,
@@ -446,23 +462,61 @@ def segy_to_mdio( # noqa: PLR0913, PLR0915
446462
data_array = data_group[f"chunked_{suffix}"]
447463
header_array = meta_group[f"chunked_{suffix}_trace_headers"]
448464

449-
# Write actual live mask and metadata to empty MDIO
450-
meta_group["live_mask"][:] = grid.live_mask[:]
451-
nonzero_count = np.count_nonzero(grid.live_mask)
465+
print("pre-live_mask")
466+
live_mask_array = meta_group["live_mask"]
467+
# 'live_mask_array' has the same first N–1 dims as 'grid.shape[:-1]'
468+
# Build a ChunkIterator over the live_mask (no sample axis)
469+
from mdio.core.indexing import ChunkIterator
470+
471+
chunker = ChunkIterator(live_mask_array, chunk_samples=False)
472+
for chunk_indices in chunker:
473+
# chunk_indices is a tuple of N–1 slice objects
474+
trace_ids = grid.get_traces_for_chunk(chunk_indices)
475+
if trace_ids.size == 0:
476+
continue
477+
478+
# Build a temporary boolean block of shape = chunk shape
479+
block_shape = tuple(sl.stop - sl.start for sl in chunk_indices)
480+
block = np.zeros(block_shape, dtype=bool)
481+
482+
# Compute local coords within this block for each trace_id
483+
local_coords: list[np.ndarray] = []
484+
for dim_idx, sl in enumerate(chunk_indices):
485+
hdr_arr = grid.header_index_arrays[dim_idx]
486+
local_idx = (hdr_arr[trace_ids] - sl.start).astype(int)
487+
local_coords.append(local_idx)
488+
489+
# Mark live cells in the temporary block
490+
block[tuple(local_coords)] = True
491+
492+
# Write the entire block to Zarr at once
493+
live_mask_array.set_basic_selection(selection=chunk_indices, value=block)
494+
495+
nonzero_count = grid.num_traces
496+
497+
print("pre-write_attribute")
452498
write_attribute(name="trace_count", zarr_group=root_group, attribute=nonzero_count)
453499
write_attribute(name="text_header", zarr_group=meta_group, attribute=text_header.split("\n"))
454500
write_attribute(name="binary_header", zarr_group=meta_group, attribute=binary_header.to_dict())
455501

502+
print("pre-to_zarr")
456503
# Write traces
504+
zarr_root = mdio_path_or_buffer # the same path you passed earlier to create_empty
505+
data_var = f"data/chunked_{suffix}"
506+
header_var = f"metadata/chunked_{suffix}_trace_headers"
507+
457508
stats = blocked_io.to_zarr(
458509
segy_file=segy,
459510
grid=grid,
460-
data_array=data_array,
461-
header_array=header_array,
511+
zarr_root_path=zarr_root,
512+
data_var_path=data_var,
513+
header_var_path=header_var,
462514
)
463515

516+
print("pre-write_attribute")
464517
# Write actual stats
465518
for key, value in stats.items():
466519
write_attribute(name=key, zarr_group=root_group, attribute=value)
467520

468-
zarr.consolidate_metadata(root_group.store)
521+
print("pre-consolidate_metadata")
522+
zarr.consolidate_metadata(root_group.store)

src/mdio/core/grid.py

Lines changed: 58 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ def __post_init__(self) -> None:
6565
self.dim_names = tuple(dim.name for dim in self.dims)
6666
self.shape = tuple(dim.size for dim in self.dims)
6767
self.ndim = len(self.dims)
68+
# Prepare attributes for lazy mapping; they will be set in build_map
69+
self.header_index_arrays: tuple[np.ndarray, ...] = ()
70+
self.num_traces: int = 0
6871

6972
def __getitem__(self, item: int) -> Dimension:
7073
"""Get a dimension by index."""
@@ -106,47 +109,64 @@ def from_zarr(cls, zarr_root: zarr.Group) -> Grid:
106109
return cls(dims_list)
107110

108111
def build_map(self, index_headers: HeaderArray) -> None:
109-
"""Build trace mapping and live mask from header indices.
112+
"""Compute per-trace grid coordinates (lazy map).
113+
114+
Instead of allocating a full `self.map` and `self.live_mask`, this computes, for each trace,
115+
its integer index along each dimension (excluding the final sample dimension) and stores them in
116+
`self.header_index_arrays`. The full mapping can then be derived chunk-by-chunk when writing.
110117
111118
Args:
112-
index_headers: Header array containing dimension indices.
119+
index_headers: Header array containing dimension indices (length = number of traces).
120+
"""
121+
# Number of traces in the SEG-Y
122+
self.num_traces = int(index_headers.shape[0])
123+
124+
# For each dimension except the final sample dimension, compute a 1D array of length
125+
# `num_traces` giving each trace's integer coordinate along that axis (via np.searchsorted).
126+
# Cast to uint32.
127+
idx_arrays: list[np.ndarray] = []
128+
for dim in self.dims[:-1]:
129+
hdr_vals = index_headers[dim.name] # shape: (num_traces,)
130+
coords = np.searchsorted(dim, hdr_vals) # integer indices
131+
coords = coords.astype(np.uint32)
132+
idx_arrays.append(coords)
133+
134+
# Store as a tuple so that header_index_arrays[d][i] is "trace i's index along axis d"
135+
self.header_index_arrays = tuple(idx_arrays)
136+
137+
# We no longer allocate `self.map` or `self.live_mask` here.
138+
# The full grid shape is `self.shape`, but mapping is done lazily per chunk.
139+
return
140+
141+
def get_traces_for_chunk(self, chunk_slices: tuple[slice, ...]) -> np.ndarray:
142+
"""Return all trace IDs whose grid-coordinates fall inside the given chunk slices.
143+
144+
Args:
145+
chunk_slices: Tuple of slice objects, one per grid dimension. For example,
146+
(slice(i0, i1), slice(j0, j1), ...) corresponds to a single Zarr chunk
147+
in index space (excluding the sample axis).
148+
149+
Returns:
150+
A 1D NumPy array of trace indices (0-based) that lie within the hyper-rectangle defined
151+
by `chunk_slices`. If no traces fall in this chunk, returns an empty array.
113152
"""
114-
# Determine data type for map based on grid size
115-
grid_size = np.prod(self.shape[:-1], dtype=np.uint64)
116-
map_dtype = np.uint64 if grid_size > UINT32_MAX else np.uint32
117-
fill_value = np.iinfo(map_dtype).max
118-
119-
# Initialize Zarr arrays
120-
live_shape = self.shape[:-1]
121-
chunks = get_constrained_chunksize(
122-
shape=live_shape,
123-
dtype=map_dtype,
124-
max_bytes=self._INTERNAL_CHUNK_SIZE_TARGET,
125-
)
126-
self.map = zarr.full(live_shape, fill_value, dtype=map_dtype, chunks=chunks)
127-
self.live_mask = zarr.zeros(live_shape, dtype=bool, chunks=chunks)
128-
129-
# Calculate batch size
130-
memory_per_trace_index = index_headers.itemsize
131-
batch_size = max(1, int(self._TARGET_MEMORY_PER_BATCH / memory_per_trace_index))
132-
total_live_traces = index_headers.size
133-
134-
# Process headers in batches
135-
for start in range(0, total_live_traces, batch_size):
136-
end = min(start + batch_size, total_live_traces)
137-
live_dim_indices = []
138-
139-
# Compute indices for the batch
140-
for dim in self.dims[:-1]:
141-
dim_hdr = index_headers[dim.name][start:end]
142-
indices = np.searchsorted(dim, dim_hdr).astype(np.uint32)
143-
live_dim_indices.append(indices)
144-
live_dim_indices = tuple(live_dim_indices)
145-
146-
# Assign trace indices
147-
trace_indices = np.arange(start, end, dtype=np.uint64)
148-
self.map.vindex[live_dim_indices] = trace_indices
149-
self.live_mask.vindex[live_dim_indices] = True
153+
# Initialize a boolean mask over all traces (shape: (num_traces,))
154+
mask = np.ones((self.num_traces,), dtype=bool)
155+
156+
for dim_idx, sl in enumerate(chunk_slices):
157+
arr = self.header_index_arrays[dim_idx] # shape: (num_traces,)
158+
start, stop = sl.start, sl.stop
159+
if start is not None:
160+
mask &= (arr >= start)
161+
if stop is not None:
162+
mask &= (arr < stop)
163+
if not mask.any():
164+
# No traces remain after this dimension's filtering
165+
return np.empty((0,), dtype=np.uint32)
166+
167+
# Gather the trace IDs that survived all dimension tests
168+
trace_ids = np.nonzero(mask)[0].astype(np.uint32)
169+
return trace_ids
150170

151171

152172
class GridSerializer(Serializer):

src/mdio/segy/_workers.py

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -77,45 +77,53 @@ def trace_worker(
7777
Returns:
7878
Partial statistics for chunk, or None
7979
"""
80-
# Special case where there are no traces inside chunk.
81-
live_subset = grid.live_mask[chunk_indices[:-1]]
82-
83-
if np.count_nonzero(live_subset) == 0:
80+
from time import time
81+
start_time = time()
82+
# Determine which trace IDs fall into this chunk
83+
trace_ids = grid.get_traces_for_chunk(chunk_indices[:-1])
84+
if trace_ids.size == 0:
8485
return None
8586

86-
# Let's get trace numbers from grid map using the chunk indices.
87-
seq_trace_indices = grid.map[chunk_indices[:-1]]
88-
89-
tmp_data = np.zeros(seq_trace_indices.shape + (grid.shape[-1],), dtype=data_array.dtype)
90-
tmp_metadata = np.zeros(seq_trace_indices.shape, dtype=metadata_array.dtype)
91-
92-
del grid # To save some memory
93-
94-
# Read headers and traces for block
95-
valid_indices = seq_trace_indices[live_subset]
96-
97-
traces = segy_file.trace[valid_indices.tolist()]
87+
# Read headers and traces for the selected trace IDs
88+
traces = segy_file.trace[trace_ids.tolist()]
9889
headers, samples = traces["header"], traces["data"]
9990

100-
tmp_metadata[live_subset] = headers.view(tmp_metadata.dtype)
101-
tmp_data[live_subset] = samples
102-
103-
# Flush metadata to zarr
91+
# Build a temporary buffer for data and metadata for this chunk
92+
chunk_shape = tuple(sli.stop - sli.start for sli in chunk_indices[:-1]) + (grid.shape[-1],)
93+
tmp_data = np.zeros(chunk_shape, dtype=data_array.dtype)
94+
meta_shape = tuple(sli.stop - sli.start for sli in chunk_indices[:-1])
95+
tmp_metadata = np.zeros(meta_shape, dtype=metadata_array.dtype)
96+
97+
# Compute local coordinates within the chunk for each trace
98+
local_coords: list[np.ndarray] = []
99+
for dim_idx, sl in enumerate(chunk_indices[:-1]):
100+
hdr_arr = grid.header_index_arrays[dim_idx]
101+
local_idx = (hdr_arr[trace_ids] - sl.start).astype(int)
102+
local_coords.append(local_idx)
103+
full_idx = tuple(local_coords) + (slice(None),)
104+
105+
# Populate the temporary buffers
106+
tmp_data[full_idx] = samples
107+
tmp_metadata[tuple(local_coords)] = headers.view(tmp_metadata.dtype)
108+
109+
# Flush metadata to Zarr
104110
metadata_array.set_basic_selection(selection=chunk_indices[:-1], value=tmp_metadata)
105111

112+
# Determine nonzero samples and early-exit if none
106113
nonzero_mask = samples != 0
107-
nonzero_count = nonzero_mask.sum(dtype="uint32")
108-
114+
nonzero_count = int(nonzero_mask.sum())
109115
if nonzero_count == 0:
110116
return None
111117

118+
# Flush data to Zarr
112119
data_array.set_basic_selection(selection=chunk_indices, value=tmp_data)
113120

114121
# Calculate statistics
115-
tmp_data = samples[nonzero_mask]
116-
chunk_sum = tmp_data.sum(dtype="float64")
117-
chunk_sum_squares = np.square(tmp_data, dtype="float64").sum()
118-
min_val = tmp_data.min()
119-
max_val = tmp_data.max()
120-
121-
return nonzero_count, chunk_sum, chunk_sum_squares, min_val, max_val
122+
flattened_nonzero = samples[nonzero_mask]
123+
chunk_sum = float(flattened_nonzero.sum(dtype="float64"))
124+
chunk_sum_squares = float(np.square(flattened_nonzero, dtype="float64").sum())
125+
min_val = float(flattened_nonzero.min())
126+
max_val = float(flattened_nonzero.max())
127+
128+
print(f"Time taken: {time() - start_time} seconds")
129+
return (nonzero_count, chunk_sum, chunk_sum_squares, min_val, max_val)

0 commit comments

Comments
 (0)