7 | 7 | from typing import TYPE_CHECKING |
8 | 8 |
9 | 9 | import numpy as np |
10 | | -import zarr |
11 | 10 |
12 | | -from mdio.constants import UINT32_MAX |
13 | 11 | from mdio.core import Dimension |
14 | 12 | from mdio.core.serialization import Serializer |
15 | | -from mdio.core.utils_write import get_constrained_chunksize |
16 | 13 |
17 | 14 | if TYPE_CHECKING: |
| 15 | + import zarr |
18 | 16 | from segy.arrays import HeaderArray |
19 | 17 | from zarr import Array as ZarrArray |
20 | 18 |
@@ -65,6 +63,9 @@ def __post_init__(self) -> None: |
65 | 63 | self.dim_names = tuple(dim.name for dim in self.dims) |
66 | 64 | self.shape = tuple(dim.size for dim in self.dims) |
67 | 65 | self.ndim = len(self.dims) |
| 66 | + # Prepare attributes for lazy mapping; they will be set in build_map |
| 67 | + self.header_index_arrays: tuple[np.ndarray, ...] = () |
| 68 | + self.num_traces: int = 0 |
68 | 69 |
69 | 70 | def __getitem__(self, item: int) -> Dimension: |
70 | 71 | """Get a dimension by index.""" |
@@ -106,47 +107,62 @@ def from_zarr(cls, zarr_root: zarr.Group) -> Grid: |
106 | 107 | return cls(dims_list) |
107 | 108 |
108 | 109 | def build_map(self, index_headers: HeaderArray) -> None: |
109 | | - """Build trace mapping and live mask from header indices. |
| 110 | + """Compute per-trace grid coordinates (lazy map). |
| 111 | +
| 112 | + Instead of allocating a full `self.map` and `self.live_mask`, this computes, for each trace, |
| 113 | + its integer index along each dimension (excluding the sample dimension) and stores them in |
| 114 | + `self.header_index_arrays`. The full mapping can then be derived chunkwise when writing. |
110 | 115 |
111 | 116 | Args: |
112 | | - index_headers: Header array containing dimension indices. |
| 117 | + index_headers: Header array containing dimension indices (length = number of traces). |
| 118 | + """ |
| 119 | + # Number of traces in the SEG-Y |
| 120 | + self.num_traces = int(index_headers.shape[0]) |
| 121 | + |
| 122 | + # For each dimension except the final sample dimension, compute a 1D array of length |
| 123 | + # `num_traces` giving each trace's integer coordinate along that axis (via np.searchsorted). |
| 124 | + # Cast to uint32. |
| 125 | + idx_arrays: list[np.ndarray] = [] |
| 126 | + for dim in self.dims[:-1]: |
| 127 | + hdr_vals = index_headers[dim.name] # shape: (num_traces,) |
| 128 | + coords = np.searchsorted(dim, hdr_vals) # integer indices |
| 129 | + coords = coords.astype(np.uint32) |
| 130 | + idx_arrays.append(coords) |
| 131 | + |
| 132 | + # Store as a tuple so that header_index_arrays[d][i] is "trace i's index along axis d" |
| 133 | + self.header_index_arrays = tuple(idx_arrays) |
| 134 | + |
| 135 | + # We no longer allocate `self.map` or `self.live_mask` here. |
| 136 | + # The full grid shape is `self.shape`, but mapping is done lazily per chunk. |
| 137 | + |
| 138 | + def get_traces_for_chunk(self, chunk_slices: tuple[slice, ...]) -> np.ndarray: |
| 139 | + """Return all trace IDs whose grid coordinates fall inside the given chunk slices.
| 140 | +
| 141 | + Args: |
| 142 | + chunk_slices: Tuple of slice objects, one per grid dimension excluding the
| 143 | + sample axis. For example, (slice(i0, i1), slice(j0, j1), ...) corresponds
| 144 | + to a single Zarr chunk in index space.
| 145 | +
| 146 | + Returns: |
| 147 | + A 1D NumPy array of trace indices (0-based) that lie within the hyper-rectangle defined |
| 148 | + by `chunk_slices`. If no traces fall in this chunk, returns an empty array. |
113 | 149 | """ |
114 | | - # Determine data type for map based on grid size |
115 | | - grid_size = np.prod(self.shape[:-1], dtype=np.uint64) |
116 | | - map_dtype = np.uint64 if grid_size > UINT32_MAX else np.uint32 |
117 | | - fill_value = np.iinfo(map_dtype).max |
118 | | - |
119 | | - # Initialize Zarr arrays |
120 | | - live_shape = self.shape[:-1] |
121 | | - chunks = get_constrained_chunksize( |
122 | | - shape=live_shape, |
123 | | - dtype=map_dtype, |
124 | | - max_bytes=self._INTERNAL_CHUNK_SIZE_TARGET, |
125 | | - ) |
126 | | - self.map = zarr.full(live_shape, fill_value, dtype=map_dtype, chunks=chunks) |
127 | | - self.live_mask = zarr.zeros(live_shape, dtype=bool, chunks=chunks) |
128 | | - |
129 | | - # Calculate batch size |
130 | | - memory_per_trace_index = index_headers.itemsize |
131 | | - batch_size = max(1, int(self._TARGET_MEMORY_PER_BATCH / memory_per_trace_index)) |
132 | | - total_live_traces = index_headers.size |
133 | | - |
134 | | - # Process headers in batches |
135 | | - for start in range(0, total_live_traces, batch_size): |
136 | | - end = min(start + batch_size, total_live_traces) |
137 | | - live_dim_indices = [] |
138 | | - |
139 | | - # Compute indices for the batch |
140 | | - for dim in self.dims[:-1]: |
141 | | - dim_hdr = index_headers[dim.name][start:end] |
142 | | - indices = np.searchsorted(dim, dim_hdr).astype(np.uint32) |
143 | | - live_dim_indices.append(indices) |
144 | | - live_dim_indices = tuple(live_dim_indices) |
145 | | - |
146 | | - # Assign trace indices |
147 | | - trace_indices = np.arange(start, end, dtype=np.uint64) |
148 | | - self.map.vindex[live_dim_indices] = trace_indices |
149 | | - self.live_mask.vindex[live_dim_indices] = True |
| 150 | + # Initialize a boolean mask over all traces (shape: (num_traces,)) |
| 151 | + mask = np.ones((self.num_traces,), dtype=bool) |
| 152 | + |
| 153 | + for dim_idx, sl in enumerate(chunk_slices): |
| 154 | + arr = self.header_index_arrays[dim_idx] # shape: (num_traces,) |
| 155 | + start, stop = sl.start, sl.stop |
| 156 | + if start is not None: |
| 157 | + mask &= arr >= start |
| 158 | + if stop is not None: |
| 159 | + mask &= arr < stop |
| 160 | + if not mask.any(): |
| 161 | + # No traces remain after this dimension's filtering |
| 162 | + return np.empty((0,), dtype=np.uint64)
| 163 | + |
| 164 | + # Gather surviving trace IDs; uint64 avoids overflow for very large trace counts
| 165 | + return np.nonzero(mask)[0].astype(np.uint64)
150 | 166 |
151 | 167 |
152 | 168 | class GridSerializer(Serializer): |
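
For context, here is a minimal sketch of what the lazy map stores, assuming a 3-D grid (inline, crossline, sample); plain sorted NumPy arrays stand in for the `Dimension` objects, which `np.searchsorted` sees as the same coordinate values:

```python
import numpy as np

# Stand-ins for the two spatial Dimension coordinate vectors (sorted, as in MDIO).
inline_coords = np.array([10, 20, 30])        # axis 0
crossline_coords = np.array([100, 101, 102])  # axis 1

# Header values for four traces, as build_map reads them from index_headers.
inline_hdr = np.array([20, 10, 30, 20])
crossline_hdr = np.array([100, 102, 101, 101])

# One uint32 array per non-sample axis: header_index_arrays[d][i] is
# "trace i's grid index along axis d".
header_index_arrays = (
    np.searchsorted(inline_coords, inline_hdr).astype(np.uint32),        # [1, 0, 2, 1]
    np.searchsorted(crossline_coords, crossline_hdr).astype(np.uint32),  # [0, 2, 1, 1]
)

# Trace 3 lands at grid cell (inline=1, crossline=1).
assert (header_index_arrays[0][3], header_index_arrays[1][3]) == (1, 1)
```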
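And a hedged sketch of the chunkwise writer this change enables. Only `get_traces_for_chunk` and `header_index_arrays` come from the diff above; `grid`, `trace_map` (a pre-created integer Zarr array chunked over the non-sample axes), and the driver loop itself are hypothetical scaffolding:

```python
import itertools

import numpy as np


def write_map_chunkwise(grid, trace_map, chunks):
    """Derive the trace map chunk by chunk instead of materializing it up front."""
    live_shape = grid.shape[:-1]  # grid shape without the sample axis
    fill_value = np.iinfo(trace_map.dtype).max  # sentinel for dead cells

    # Visit every chunk origin in index space.
    ranges = [range(0, size, step) for size, step in zip(live_shape, chunks)]
    for origin in itertools.product(*ranges):
        slices = tuple(
            slice(start, min(start + step, size))
            for start, step, size in zip(origin, chunks, live_shape)
        )

        # Ask the grid which traces land in this hyper-rectangle.
        trace_ids = grid.get_traces_for_chunk(slices)
        if trace_ids.size == 0:
            continue  # all-dead chunk: leave it at fill_value

        # Scatter the surviving trace IDs into a dense local buffer, then write it.
        buffer = np.full(
            [sl.stop - sl.start for sl in slices], fill_value, dtype=trace_map.dtype
        )
        local_coords = tuple(
            grid.header_index_arrays[d][trace_ids] - slices[d].start
            for d in range(len(slices))
        )
        buffer[local_coords] = trace_ids
        trace_map[slices] = buffer
```

A live mask falls out of the same buffers for free (`buffer != fill_value` per chunk), so neither the full map nor the mask ever needs to be resident in memory at once.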