Skip to content

Commit 5e91042

Browse files
committed
Merge branch '497_ingestion_memory' into v1_ingestion_YOLO
2 parents 006f614 + 61fd99d commit 5e91042

File tree

4 files changed

+208
-106
lines changed

4 files changed

+208
-106
lines changed

src/mdio/converters/segy.py

Lines changed: 76 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def get_compressor(lossless: bool, compression_tolerance: float = -1) -> Blosc |
131131
return compressor
132132

133133

134-
def segy_to_mdio( # noqa: PLR0913, PLR0915
134+
def segy_to_mdio( # noqa: PLR0913, PLR0915, PLR0912
135135
segy_path: str | Path,
136136
mdio_path_or_buffer: str | Path,
137137
index_bytes: Sequence[int],
@@ -394,13 +394,24 @@ def segy_to_mdio( # noqa: PLR0913, PLR0915
394394
grid_density_qc(grid, num_traces)
395395
grid.build_map(index_headers)
396396

397-
# Check grid validity by comparing trace numbers
398-
if np.sum(grid.live_mask) != num_traces:
397+
# Check grid validity by ensuring every trace's header-index is within dimension bounds
398+
valid_mask = np.ones(grid.num_traces, dtype=bool)
399+
for d_idx in range(len(grid.header_index_arrays)):
400+
coords = grid.header_index_arrays[d_idx]
401+
valid_mask &= coords < grid.shape[d_idx]
402+
valid_count = int(np.count_nonzero(valid_mask))
403+
if valid_count != num_traces:
399404
for dim_name in grid.dim_names:
400-
dim_min, dim_max = grid.get_min(dim_name), grid.get_max(dim_name)
405+
dim_min = grid.get_min(dim_name)
406+
dim_max = grid.get_max(dim_name)
401407
logger.warning("%s min: %s max: %s", dim_name, dim_min, dim_max)
402408
logger.warning("Ingestion grid shape: %s.", grid.shape)
403-
raise GridTraceCountError(np.sum(grid.live_mask), num_traces)
409+
raise GridTraceCountError(valid_count, num_traces)
410+
411+
import gc
412+
413+
del valid_mask
414+
gc.collect()
404415

405416
if chunksize is None:
406417
dim_count = len(index_names) + 1
@@ -446,9 +457,66 @@ def segy_to_mdio( # noqa: PLR0913, PLR0915
446457
data_array = data_group[f"chunked_{suffix}"]
447458
header_array = meta_group[f"chunked_{suffix}_trace_headers"]
448459

449-
# Write actual live mask and metadata to empty MDIO
450-
meta_group["live_mask"][:] = grid.live_mask[:]
451-
nonzero_count = np.count_nonzero(grid.live_mask)
460+
live_mask_array = meta_group["live_mask"]
461+
# 'live_mask_array' has the same first N–1 dims as 'grid.shape[:-1]'
462+
# Build a ChunkIterator over the live_mask (no sample axis)
463+
from mdio.core.indexing import ChunkIterator
464+
465+
chunker = ChunkIterator(live_mask_array, chunk_samples=True)
466+
for chunk_indices in chunker:
467+
# chunk_indices is a tuple of N–1 slice objects
468+
trace_ids = grid.get_traces_for_chunk(chunk_indices)
469+
if trace_ids.size == 0:
470+
# Free memory immediately for empty chunks
471+
del trace_ids
472+
continue
473+
474+
# Build a temporary boolean block of shape = chunk shape
475+
block = np.zeros(tuple(sl.stop - sl.start for sl in chunk_indices), dtype=bool)
476+
477+
# Compute local coords within this block for each trace_id
478+
local_coords: list[np.ndarray] = []
479+
for dim_idx, sl in enumerate(chunk_indices):
480+
hdr_arr = grid.header_index_arrays[dim_idx]
481+
# Optimize memory usage: hdr_arr and trace_ids are already uint32,
482+
# sl.start is int, so result should naturally be int32/uint32.
483+
# Avoid unnecessary astype conversion to int64.
484+
indexed_coords = hdr_arr[trace_ids] # uint32 array
485+
local_idx = indexed_coords - sl.start # remains uint32
486+
# Free indexed_coords immediately
487+
del indexed_coords
488+
489+
# Only convert dtype if necessary for indexing (numpy requires int for indexing)
490+
if local_idx.dtype != np.intp:
491+
local_idx = local_idx.astype(np.intp)
492+
local_coords.append(local_idx)
493+
# local_idx is now owned by local_coords list, safe to continue
494+
495+
# Free trace_ids as soon as we're done with it
496+
del trace_ids
497+
498+
# Mark live cells in the temporary block
499+
block[tuple(local_coords)] = True
500+
501+
# Free local_coords immediately after use
502+
del local_coords
503+
504+
# Write the entire block to Zarr at once
505+
live_mask_array.set_basic_selection(selection=chunk_indices, value=block)
506+
507+
# Free block immediately after writing
508+
del block
509+
510+
# Force garbage collection periodically to free memory aggressively
511+
gc.collect()
512+
513+
# Final cleanup
514+
del live_mask_array
515+
del chunker
516+
gc.collect()
517+
518+
nonzero_count = grid.num_traces
519+
452520
write_attribute(name="trace_count", zarr_group=root_group, attribute=nonzero_count)
453521
write_attribute(name="text_header", zarr_group=meta_group, attribute=text_header.split("\n"))
454522
write_attribute(name="binary_header", zarr_group=meta_group, attribute=binary_header.to_dict())

src/mdio/core/grid.py

Lines changed: 57 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,12 @@
77
from typing import TYPE_CHECKING
88

99
import numpy as np
10-
import zarr
1110

12-
from mdio.constants import UINT32_MAX
1311
from mdio.core import Dimension
1412
from mdio.core.serialization import Serializer
15-
from mdio.core.utils_write import get_constrained_chunksize
1613

1714
if TYPE_CHECKING:
15+
import zarr
1816
from segy.arrays import HeaderArray
1917
from zarr import Array as ZarrArray
2018

@@ -65,6 +63,9 @@ def __post_init__(self) -> None:
6563
self.dim_names = tuple(dim.name for dim in self.dims)
6664
self.shape = tuple(dim.size for dim in self.dims)
6765
self.ndim = len(self.dims)
66+
# Prepare attributes for lazy mapping; they will be set in build_map
67+
self.header_index_arrays: tuple[np.ndarray, ...] = ()
68+
self.num_traces: int = 0
6869

6970
def __getitem__(self, item: int) -> Dimension:
7071
"""Get a dimension by index."""
@@ -106,47 +107,62 @@ def from_zarr(cls, zarr_root: zarr.Group) -> Grid:
106107
return cls(dims_list)
107108

108109
def build_map(self, index_headers: HeaderArray) -> None:
109-
"""Build trace mapping and live mask from header indices.
110+
"""Compute per-trace grid coordinates (lazy map).
111+
112+
Instead of allocating a full `self.map` and `self.live_mask`, this computes, for each trace,
113+
its integer index along each dimension (excluding the sample dimension) and stores them in
114+
`self.header_index_arrays`. The full mapping can then be derived chunkwise when writing.
110115
111116
Args:
112-
index_headers: Header array containing dimension indices.
117+
index_headers: Header array containing dimension indices (length = number of traces).
118+
"""
119+
# Number of traces in the SEG-Y
120+
self.num_traces = int(index_headers.shape[0])
121+
122+
# For each dimension except the final sample dimension, compute a 1D array of length
123+
# `num_traces` giving each trace's integer coordinate along that axis (via np.searchsorted).
124+
# Cast to uint32.
125+
idx_arrays: list[np.ndarray] = []
126+
for dim in self.dims[:-1]:
127+
hdr_vals = index_headers[dim.name] # shape: (num_traces,)
128+
coords = np.searchsorted(dim, hdr_vals) # integer indices
129+
coords = coords.astype(np.uint32)
130+
idx_arrays.append(coords)
131+
132+
# Store as a tuple so that header_index_arrays[d][i] is "trace i's index along axis d"
133+
self.header_index_arrays = tuple(idx_arrays)
134+
135+
# We no longer allocate `self.map` or `self.live_mask` here.
136+
# The full grid shape is `self.shape`, but mapping is done lazily per chunk.
137+
138+
def get_traces_for_chunk(self, chunk_slices: tuple[slice, ...]) -> np.ndarray:
139+
"""Return all trace IDs whose grid-coordinates fall inside the given chunk slices.
140+
141+
Args:
142+
chunk_slices: Tuple of slice objects, one per grid dimension. For example,
143+
(slice(i0, i1), slice(j0, j1), ...) corresponds to a single Zarr chunk
144+
in index space (excluding the sample axis).
145+
146+
Returns:
147+
A 1D NumPy array of trace indices (0-based) that lie within the hyper-rectangle defined
148+
by `chunk_slices`. If no traces fall in this chunk, returns an empty array.
113149
"""
114-
# Determine data type for map based on grid size
115-
grid_size = np.prod(self.shape[:-1], dtype=np.uint64)
116-
map_dtype = np.uint64 if grid_size > UINT32_MAX else np.uint32
117-
fill_value = np.iinfo(map_dtype).max
118-
119-
# Initialize Zarr arrays
120-
live_shape = self.shape[:-1]
121-
chunks = get_constrained_chunksize(
122-
shape=live_shape,
123-
dtype=map_dtype,
124-
max_bytes=self._INTERNAL_CHUNK_SIZE_TARGET,
125-
)
126-
self.map = zarr.full(live_shape, fill_value, dtype=map_dtype, chunks=chunks)
127-
self.live_mask = zarr.zeros(live_shape, dtype=bool, chunks=chunks)
128-
129-
# Calculate batch size
130-
memory_per_trace_index = index_headers.itemsize
131-
batch_size = max(1, int(self._TARGET_MEMORY_PER_BATCH / memory_per_trace_index))
132-
total_live_traces = index_headers.size
133-
134-
# Process headers in batches
135-
for start in range(0, total_live_traces, batch_size):
136-
end = min(start + batch_size, total_live_traces)
137-
live_dim_indices = []
138-
139-
# Compute indices for the batch
140-
for dim in self.dims[:-1]:
141-
dim_hdr = index_headers[dim.name][start:end]
142-
indices = np.searchsorted(dim, dim_hdr).astype(np.uint32)
143-
live_dim_indices.append(indices)
144-
live_dim_indices = tuple(live_dim_indices)
145-
146-
# Assign trace indices
147-
trace_indices = np.arange(start, end, dtype=np.uint64)
148-
self.map.vindex[live_dim_indices] = trace_indices
149-
self.live_mask.vindex[live_dim_indices] = True
150+
# Initialize a boolean mask over all traces (shape: (num_traces,))
151+
mask = np.ones((self.num_traces,), dtype=bool)
152+
153+
for dim_idx, sl in enumerate(chunk_slices):
154+
arr = self.header_index_arrays[dim_idx] # shape: (num_traces,)
155+
start, stop = sl.start, sl.stop
156+
if start is not None:
157+
mask &= arr >= start
158+
if stop is not None:
159+
mask &= arr < stop
160+
if not mask.any():
161+
# No traces remain after this dimension's filtering
162+
return np.empty((0,), dtype=np.uint32)
163+
164+
# Gather the trace IDs that survived all dimension tests
165+
return np.nonzero(mask)[0].astype(np.uint32)
150166

151167

152168
class GridSerializer(Serializer):

src/mdio/segy/_workers.py

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -77,45 +77,57 @@ def trace_worker(
7777
Returns:
7878
Partial statistics for chunk, or None
7979
"""
80-
# Special case where there are no traces inside chunk.
81-
live_subset = grid.live_mask[chunk_indices[:-1]]
82-
83-
if np.count_nonzero(live_subset) == 0:
80+
# Determine which trace IDs fall into this chunk
81+
trace_ids = grid.get_traces_for_chunk(chunk_indices[:-1])
82+
if trace_ids.size == 0:
8483
return None
8584

86-
# Let's get trace numbers from grid map using the chunk indices.
87-
seq_trace_indices = grid.map[chunk_indices[:-1]]
88-
89-
tmp_data = np.zeros(seq_trace_indices.shape + (grid.shape[-1],), dtype=data_array.dtype)
90-
tmp_metadata = np.zeros(seq_trace_indices.shape, dtype=metadata_array.dtype)
91-
92-
del grid # To save some memory
93-
94-
# Read headers and traces for block
95-
valid_indices = seq_trace_indices[live_subset]
96-
97-
traces = segy_file.trace[valid_indices.tolist()]
85+
# Read headers and traces for the selected trace IDs
86+
traces = segy_file.trace[trace_ids.tolist()]
9887
headers, samples = traces["header"], traces["data"]
9988

100-
tmp_metadata[live_subset] = headers.view(tmp_metadata.dtype)
101-
tmp_data[live_subset] = samples
102-
103-
# Flush metadata to zarr
89+
# Build a temporary buffer for data and metadata for this chunk
90+
chunk_shape = tuple(sli.stop - sli.start for sli in chunk_indices[:-1]) + (grid.shape[-1],)
91+
tmp_data = np.zeros(chunk_shape, dtype=data_array.dtype)
92+
meta_shape = tuple(sli.stop - sli.start for sli in chunk_indices[:-1])
93+
tmp_metadata = np.zeros(meta_shape, dtype=metadata_array.dtype)
94+
95+
# Compute local coordinates within the chunk for each trace
96+
local_coords: list[np.ndarray] = []
97+
for dim_idx, sl in enumerate(chunk_indices[:-1]):
98+
hdr_arr = grid.header_index_arrays[dim_idx]
99+
# Optimize memory usage: hdr_arr and trace_ids are already uint32,
100+
# sl.start is int, so result should naturally be int32/uint32.
101+
# Avoid unnecessary astype conversion to int64.
102+
indexed_coords = hdr_arr[trace_ids] # uint32 array
103+
local_idx = indexed_coords - sl.start # remains uint32
104+
# Only convert dtype if necessary for indexing (numpy requires int for indexing)
105+
if local_idx.dtype != np.intp:
106+
local_idx = local_idx.astype(np.intp)
107+
local_coords.append(local_idx)
108+
full_idx = tuple(local_coords) + (slice(None),)
109+
110+
# Populate the temporary buffers
111+
tmp_data[full_idx] = samples
112+
tmp_metadata[tuple(local_coords)] = headers.view(tmp_metadata.dtype)
113+
114+
# Flush metadata to Zarr
104115
metadata_array.set_basic_selection(selection=chunk_indices[:-1], value=tmp_metadata)
105116

117+
# Determine nonzero samples and early-exit if none
106118
nonzero_mask = samples != 0
107-
nonzero_count = nonzero_mask.sum(dtype="uint32")
108-
119+
nonzero_count = int(nonzero_mask.sum())
109120
if nonzero_count == 0:
110121
return None
111122

123+
# Flush data to Zarr
112124
data_array.set_basic_selection(selection=chunk_indices, value=tmp_data)
113125

114126
# Calculate statistics
115-
tmp_data = samples[nonzero_mask]
116-
chunk_sum = tmp_data.sum(dtype="float64")
117-
chunk_sum_squares = np.square(tmp_data, dtype="float64").sum()
118-
min_val = tmp_data.min()
119-
max_val = tmp_data.max()
127+
flattened_nonzero = samples[nonzero_mask]
128+
chunk_sum = float(flattened_nonzero.sum(dtype="float64"))
129+
chunk_sum_squares = float(np.square(flattened_nonzero, dtype="float64").sum())
130+
min_val = float(flattened_nonzero.min())
131+
max_val = float(flattened_nonzero.max())
120132

121-
return nonzero_count, chunk_sum, chunk_sum_squares, min_val, max_val
133+
return (nonzero_count, chunk_sum, chunk_sum_squares, min_val, max_val)

0 commit comments

Comments
 (0)