1111import zarr
1212from dask .array .core import normalize_chunks
1313from dask .array .rechunk import _balance_chunksizes
14+ from zarr import Array
1415
1516from mdio .constants import INT32_MAX
1617from mdio .constants import UINT32_MAX
@@ -35,6 +36,8 @@ class Grid:
3536 """
3637
3738 dims : list [Dimension ]
39+ map : Array | None = None
40+ live_mask : Array | None = None
3841
3942 def __post_init__ (self ):
4043 """Initialize convenience properties."""
@@ -94,27 +97,21 @@ def build_map(self, index_headers):
9497 dim_hdr = index_headers [dim .name ]
9598 live_dim_indices += (np .searchsorted (dim , dim_hdr ),)
9699
97- # There were cases where ingestion would overflow a signed int32.
98- # It's unlikely that we overflow the uint32_max, but this helps
99- # prevent any issues while keeping the memory footprint as low as possible.
100+ # Determine the appropriate data type for the map based on grid size
100101 grid_size = np .prod (self .shape [:- 1 ])
101- if grid_size > UINT32_MAX - 1 :
102- # We use UINT32_MAX-1 to ensure that the assumption below is not violated.
103- # "far away" is relative.
102+ use_uint64 = grid_size > UINT32_MAX - 1
103+ dtype = "uint64" if use_uint64 else "uint32"
104+ fill_value = UINT64_MAX if use_uint64 else UINT32_MAX
105+
106+ if use_uint64 :
104107 logging .warning (
105- f"Grid size { grid_size } exceeds UINT32_MAX ( { UINT32_MAX - 1 } ) . "
106- "Using uint64 for trace map which will use more memory."
108+ f"Grid size { grid_size } exceeds threshold { UINT32_MAX - 1 } . "
109+ "Using uint64 for trace map, which increases memory usage ."
107110 )
108- dtype = "uint64"
109- fill_value = UINT64_MAX
110- else :
111- dtype = "uint32"
112- fill_value = UINT32_MAX
113-
114- # We set dead traces to max uint32/uint64 value.
115- # Should be far away from actual trace counts.
111+
112+ # Create map of trace indices and a bool mask for live traces
116113 self .map = zarr .full (self .shape [:- 1 ], dtype = dtype , fill_value = fill_value )
117- self .map .vindex [live_dim_indices ] = range (len (live_dim_indices [0 ]))
114+ self .map .vindex [live_dim_indices ] = np . arange (len (live_dim_indices [0 ]))
118115
119116 self .live_mask = zarr .zeros (self .shape [:- 1 ], dtype = "bool" )
120117 self .live_mask .vindex [live_dim_indices ] = 1
0 commit comments