Skip to content

Commit 6de69b0

Browse files
committed
simplify dtype determination logic
1 parent 0807f82 commit 6de69b0

File tree

1 file changed

+14
-17
lines changed

1 file changed

+14
-17
lines changed

src/mdio/core/grid.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import zarr
1212
from dask.array.core import normalize_chunks
1313
from dask.array.rechunk import _balance_chunksizes
14+
from zarr import Array
1415

1516
from mdio.constants import INT32_MAX
1617
from mdio.constants import UINT32_MAX
@@ -35,6 +36,8 @@ class Grid:
3536
"""
3637

3738
dims: list[Dimension]
39+
map: Array | None = None
40+
live_mask: Array | None = None
3841

3942
def __post_init__(self):
4043
"""Initialize convenience properties."""
@@ -94,27 +97,21 @@ def build_map(self, index_headers):
9497
dim_hdr = index_headers[dim.name]
9598
live_dim_indices += (np.searchsorted(dim, dim_hdr),)
9699

97-
# There were cases where ingestion would overflow a signed int32.
98-
# It's unlikely that we overflow the uint32_max, but this helps
99-
# prevent any issues while keeping the memory footprint as low as possible.
100+
# Determine the appropriate data type for the map based on grid size
100101
grid_size = np.prod(self.shape[:-1])
101-
if grid_size > UINT32_MAX - 1:
102-
# We use UINT32_MAX-1 to ensure that the assumption below is not violated.
103-
# "far away" is relative.
102+
use_uint64 = grid_size > UINT32_MAX - 1
103+
dtype = "uint64" if use_uint64 else "uint32"
104+
fill_value = UINT64_MAX if use_uint64 else UINT32_MAX
105+
106+
if use_uint64:
104107
logging.warning(
105-
f"Grid size {grid_size} exceeds UINT32_MAX ({UINT32_MAX - 1}). "
106-
"Using uint64 for trace map which will use more memory."
108+
f"Grid size {grid_size} exceeds threshold {UINT32_MAX - 1}. "
109+
"Using uint64 for trace map, which increases memory usage."
107110
)
108-
dtype = "uint64"
109-
fill_value = UINT64_MAX
110-
else:
111-
dtype = "uint32"
112-
fill_value = UINT32_MAX
113-
114-
# We set dead traces to max uint32/uint64 value.
115-
# Should be far away from actual trace counts.
111+
112+
# Create map of trace indices and a bool mask for live traces
116113
self.map = zarr.full(self.shape[:-1], dtype=dtype, fill_value=fill_value)
117-
self.map.vindex[live_dim_indices] = range(len(live_dim_indices[0]))
114+
self.map.vindex[live_dim_indices] = np.arange(len(live_dim_indices[0]))
118115

119116
self.live_mask = zarr.zeros(self.shape[:-1], dtype="bool")
120117
self.live_mask.vindex[live_dim_indices] = 1

0 commit comments

Comments
 (0)