from __future__ import annotations

import inspect
-import logging
-from collections.abc import Sequence
from dataclasses import dataclass

import numpy as np
import zarr
-from dask.array.core import normalize_chunks
-from dask.array.rechunk import _balance_chunksizes
-from zarr import Array
+from segy.arrays import HeaderArray
+from zarr import Array as ZarrArray

-from mdio.constants import INT32_MAX
from mdio.constants import UINT32_MAX
-from mdio.constants import UINT64_MAX
from mdio.core import Dimension
from mdio.core.serialization import Serializer
+from mdio.core.utils_write import get_constrained_chunksize


@dataclass
@@ -32,12 +28,13 @@ class Grid:

    Args:
        dims: List of dimension instances.
-
    """

    dims: list[Dimension]
-    map: Array | None = None
-    live_mask: Array | None = None
+    map: ZarrArray | None = None
+    live_mask: ZarrArray | None = None
+
+    _TARGET_MEMORY_PER_BATCH = 1 * 1024**3  # 1GB limit for each processing batch

    def __post_init__(self):
        """Initialize convenience properties."""
@@ -86,35 +83,46 @@ def from_zarr(cls, zarr_root: zarr.Group):

        return cls(dims_list)

-    def build_map(self, index_headers):
+    def build_map(self, index_headers: HeaderArray) -> None:
        """Build a map for live traces based on `index_headers`.

        Args:
            index_headers: Headers to be normalized (indexed)
        """
-        live_dim_indices = tuple()
-        for dim in self.dims[:-1]:
-            dim_hdr = index_headers[dim.name]
-            live_dim_indices += (np.searchsorted(dim, dim_hdr),)
-
-        # Determine the appropriate data type for the map based on grid size
+        # Determine data type for the map based on grid size
        grid_size = np.prod(self.shape[:-1])
-        use_uint64 = grid_size > UINT32_MAX - 1
-        dtype = "uint64" if use_uint64 else "uint32"
-        fill_value = UINT64_MAX if use_uint64 else UINT32_MAX
+        map_dtype = "uint64" if grid_size > UINT32_MAX else "uint32"
+        fill_value = np.iinfo(map_dtype).max
+
+        # Initialize Zarr arrays for the map and live mask
+        live_shape = self.shape[:-1]
+        chunks = get_constrained_chunksize(live_shape, map_dtype, 10 * 1024**2)
+        self.map = zarr.full(live_shape, fill_value, dtype=map_dtype, chunks=chunks)
+        self.live_mask = zarr.zeros(live_shape, dtype="bool", chunks=chunks)
+
+        # Calculate batch size for processing
+        memory_per_trace_index = index_headers.itemsize
+        batch_size = int(self._TARGET_MEMORY_PER_BATCH / memory_per_trace_index)
+        total_live_traces = index_headers.size  # Total live traces
+
+        # Process live traces in batches
+        for start in range(0, total_live_traces, batch_size):
+            end = min(start + batch_size, total_live_traces)

-        if use_uint64:
-            logging.warning(
-                f"Grid size {grid_size} exceeds threshold {UINT32_MAX - 1}. "
-                "Using uint64 for trace map, which increases memory usage."
-            )
+            # Compute indices for the current batch
+            live_dim_indices = []
+            for dim in self.dims[:-1]:
+                dim_hdr = index_headers[dim.name][start:end]
+                indices = np.searchsorted(dim, dim_hdr).astype(np.uint32)  # Use uint32
+                live_dim_indices.append(indices)
+            live_dim_indices = tuple(live_dim_indices)

-        # Create map of trace indices and a bool mask for live traces
-        self.map = zarr.full(self.shape[:-1], dtype=dtype, fill_value=fill_value)
-        self.map.vindex[live_dim_indices] = np.arange(len(live_dim_indices[0]))
+            # Generate trace indices for the batch
+            trace_indices = np.arange(start, end, dtype=np.uint64)

-        self.live_mask = zarr.zeros(self.shape[:-1], dtype="bool")
-        self.live_mask.vindex[live_dim_indices] = 1
+            # Update Zarr arrays for the batch
+            self.map.vindex[live_dim_indices] = trace_indices
+            self.live_mask.vindex[live_dim_indices] = True


class GridSerializer(Serializer):
@@ -135,56 +143,3 @@ def deserialize(self, stream: str) -> Grid:
        payload = self.validate_payload(payload, signature)

        return Grid(**payload)
-
-
-class _EmptyGrid:
-    """Empty volume for Grid mocking."""
-
-    def __init__(self, shape: Sequence[int], dtype: np.dtype = np.bool):
-        """Initialize the empty grid."""
-        self.shape = shape
-        self.dtype = dtype
-
-    def __getitem__(self, item):
-        """Get item from the empty grid."""
-        return self.dtype.type(0)
-
-
-def _calculate_live_mask_chunksize(grid: Grid) -> Sequence[int]:
-    """Calculate the optimal chunksize for the live mask.
-
-    Args:
-        grid: The grid to calculate the chunksize for.
-
-    Returns:
-        A sequence of integers representing the optimal chunk size for each dimension
-        of the grid.
-    """
-    try:
-        return _calculate_optimal_chunksize(grid.live_mask, INT32_MAX // 4)
-    except AttributeError:
-        # Create an empty array with the same shape and dtype as the live mask would have
-        return _calculate_optimal_chunksize(_EmptyGrid(grid.shape[:-1]), INT32_MAX // 4)
-
-
-def _calculate_optimal_chunksize(  # noqa: C901
-    volume: np.ndarray | zarr.Array, max_bytes: int
-) -> Sequence[int]:
-    """Calculate the optimal chunksize for an N-dimensional data volume.
-
-    Args:
-        volume: The volume to calculate the chunksize for.
-        max_bytes: The maximum allowed number of bytes per chunk.
-
-    Returns:
-        A sequence of integers representing the optimal chunk size for each dimension
-        of the grid.
-    """
-    shape = volume.shape
-    chunks = normalize_chunks(
-        "auto",
-        shape,
-        dtype=volume.dtype,
-        limit=max_bytes,
-    )
-    return tuple(_balance_chunksizes(chunk)[0] for chunk in chunks)
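
Note on the chunking change: the dask-based _calculate_optimal_chunksize removed above is superseded by get_constrained_chunksize from mdio.core.utils_write, which this diff only shows at its call site with a shape, a dtype, and a byte budget. As a rough sketch of what a byte-limited chunk picker can look like, assuming that three-argument call shape and using a hypothetical name and logic (not the actual mdio implementation):

import math

import numpy as np


def constrained_chunksize_sketch(shape, dtype, max_bytes):
    """Shrink chunks from the full array shape until they fit under max_bytes.

    Rough stand-in for illustration only; the real mdio helper may differ.
    """
    itemsize = np.dtype(dtype).itemsize
    chunks = list(shape)
    while math.prod(chunks) * itemsize > max_bytes:
        largest = max(range(len(chunks)), key=chunks.__getitem__)
        if chunks[largest] == 1:
            break  # cannot shrink any further
        chunks[largest] = -(-chunks[largest] // 2)  # halve, rounding up
    return tuple(chunks)


# Example: a 4000 x 6000 uint32 trace map under the same 10 MiB budget used above.
print(constrained_chunksize_sketch((4000, 6000), "uint32", 10 * 1024**2))  # (1000, 1500)

The removed helper reached the same goal through dask's normalize_chunks with a byte limit plus _balance_chunksizes; moving the logic into mdio.core.utils_write is consistent with dropping the dask imports at the top of this diff.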
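On the new batching in build_map: the batch size is simply the 1 GiB per-batch target divided by the header record size reported by index_headers.itemsize. A small, self-contained arithmetic check (the 12-byte record layout and the trace count below are assumptions for illustration, not values from this change):

import numpy as np

# Illustrative only: assume the indexed headers carry three 4-byte fields,
# i.e. a 12-byte record, and use the same 1 GiB per-batch target as the class constant.
header_dtype = np.dtype([("inline", "i4"), ("crossline", "i4"), ("offset", "i4")])
target_memory_per_batch = 1 * 1024**3

batch_size = int(target_memory_per_batch / header_dtype.itemsize)
print(batch_size)  # -> 89478485 traces per batch

# A survey with 250 million live traces would then be walked in 3 batches.
total_live_traces = 250_000_000
n_batches = -(-total_live_traces // batch_size)  # ceiling division
print(n_batches)  # -> 3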