Skip to content

Commit e843895

Browse files
committed
refactor auto chunking and optimize live mask and grid map creation
1 parent 42ad30b commit e843895

File tree

3 files changed

+86
-88
lines changed

3 files changed

+86
-88
lines changed

src/mdio/core/factory.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from mdio import MDIOWriter
3535
from mdio.api.io_utils import process_url
3636
from mdio.core import Grid
37-
from mdio.core.grid import _calculate_live_mask_chunksize
37+
from mdio.core.utils_write import get_live_mask_chunksize
3838
from mdio.core.utils_write import write_attribute
3939
from mdio.segy.helpers_segy import create_zarr_hierarchy
4040

@@ -146,10 +146,12 @@ def create_empty(
146146
write_attribute(name="text_header", zarr_group=meta_group, attribute=DEFAULT_TEXT)
147147
write_attribute(name="binary_header", zarr_group=meta_group, attribute={})
148148

149+
live_shape = config.grid.shape[:-1]
150+
live_chunks = get_live_mask_chunksize(live_shape)
149151
meta_group.create_dataset(
150152
name="live_mask",
151-
shape=config.grid.shape[:-1],
152-
chunks=_calculate_live_mask_chunksize(config.grid),
153+
shape=live_shape,
154+
chunks=live_chunks,
153155
dtype="bool",
154156
dimension_separator="/",
155157
)

src/mdio/core/grid.py

Lines changed: 38 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,17 @@
33
from __future__ import annotations
44

55
import inspect
6-
import logging
7-
from collections.abc import Sequence
86
from dataclasses import dataclass
97

108
import numpy as np
119
import zarr
12-
from dask.array.core import normalize_chunks
13-
from dask.array.rechunk import _balance_chunksizes
14-
from zarr import Array
10+
from segy.arrays import HeaderArray
11+
from zarr import Array as ZarrArray
1512

16-
from mdio.constants import INT32_MAX
1713
from mdio.constants import UINT32_MAX
18-
from mdio.constants import UINT64_MAX
1914
from mdio.core import Dimension
2015
from mdio.core.serialization import Serializer
16+
from mdio.core.utils_write import get_constrained_chunksize
2117

2218

2319
@dataclass
@@ -32,12 +28,13 @@ class Grid:
3228
3329
Args:
3430
dims: List of dimension instances.
35-
3631
"""
3732

3833
dims: list[Dimension]
39-
map: Array | None = None
40-
live_mask: Array | None = None
34+
map: ZarrArray | None = None
35+
live_mask: ZarrArray | None = None
36+
37+
_TARGET_MEMORY_PER_BATCH = 1 * 1024**3 # 1 GiB memory budget for each batch of index headers processed in build_map
4138

4239
def __post_init__(self):
4340
"""Initialize convenience properties."""
@@ -86,35 +83,46 @@ def from_zarr(cls, zarr_root: zarr.Group):
8683

8784
return cls(dims_list)
8885

89-
def build_map(self, index_headers):
86+
def build_map(self, index_headers: HeaderArray) -> None:
9087
"""Build a map for live traces based on `index_headers`.
9188
9289
Args:
9390
index_headers: Headers to be normalized (indexed)
9491
"""
95-
live_dim_indices = tuple()
96-
for dim in self.dims[:-1]:
97-
dim_hdr = index_headers[dim.name]
98-
live_dim_indices += (np.searchsorted(dim, dim_hdr),)
99-
100-
# Determine the appropriate data type for the map based on grid size
92+
# Determine data type for the map based on grid size
10193
grid_size = np.prod(self.shape[:-1])
102-
use_uint64 = grid_size > UINT32_MAX - 1
103-
dtype = "uint64" if use_uint64 else "uint32"
104-
fill_value = UINT64_MAX if use_uint64 else UINT32_MAX
94+
map_dtype = "uint64" if grid_size > UINT32_MAX else "uint32"
95+
fill_value = np.iinfo(map_dtype).max
96+
97+
# Initialize Zarr arrays for the map and live mask
98+
live_shape = self.shape[:-1]
99+
chunks = get_constrained_chunksize(live_shape, map_dtype, 10 * 1024**2)
100+
self.map = zarr.full(live_shape, fill_value, dtype=map_dtype, chunks=chunks)
101+
self.live_mask = zarr.zeros(live_shape, dtype="bool", chunks=chunks)
102+
103+
# Calculate batch size for processing
104+
memory_per_trace_index = index_headers.itemsize
105+
batch_size = int(self._TARGET_MEMORY_PER_BATCH / memory_per_trace_index)
106+
total_live_traces = index_headers.size # Total live traces
107+
108+
# Process live traces in batches
109+
for start in range(0, total_live_traces, batch_size):
110+
end = min(start + batch_size, total_live_traces)
105111

106-
if use_uint64:
107-
logging.warning(
108-
f"Grid size {grid_size} exceeds threshold {UINT32_MAX - 1}. "
109-
"Using uint64 for trace map, which increases memory usage."
110-
)
112+
# Compute indices for the current batch
113+
live_dim_indices = []
114+
for dim in self.dims[:-1]:
115+
dim_hdr = index_headers[dim.name][start:end]
116+
indices = np.searchsorted(dim, dim_hdr).astype(np.uint32) # Use uint32
117+
live_dim_indices.append(indices)
118+
live_dim_indices = tuple(live_dim_indices)
111119

112-
# Create map of trace indices and a bool mask for live traces
113-
self.map = zarr.full(self.shape[:-1], dtype=dtype, fill_value=fill_value)
114-
self.map.vindex[live_dim_indices] = np.arange(len(live_dim_indices[0]))
120+
# Generate trace indices for the batch
121+
trace_indices = np.arange(start, end, dtype=np.uint64)
115122

116-
self.live_mask = zarr.zeros(self.shape[:-1], dtype="bool")
117-
self.live_mask.vindex[live_dim_indices] = 1
123+
# Update Zarr arrays for the batch
124+
self.map.vindex[live_dim_indices] = trace_indices
125+
self.live_mask.vindex[live_dim_indices] = True
118126

119127

120128
class GridSerializer(Serializer):
@@ -135,56 +143,3 @@ def deserialize(self, stream: str) -> Grid:
135143
payload = self.validate_payload(payload, signature)
136144

137145
return Grid(**payload)
138-
139-
140-
class _EmptyGrid:
141-
"""Empty volume for Grid mocking."""
142-
143-
def __init__(self, shape: Sequence[int], dtype: np.dtype = np.bool):
144-
"""Initialize the empty grid."""
145-
self.shape = shape
146-
self.dtype = dtype
147-
148-
def __getitem__(self, item):
149-
"""Get item from the empty grid."""
150-
return self.dtype.type(0)
151-
152-
153-
def _calculate_live_mask_chunksize(grid: Grid) -> Sequence[int]:
154-
"""Calculate the optimal chunksize for the live mask.
155-
156-
Args:
157-
grid: The grid to calculate the chunksize for.
158-
159-
Returns:
160-
A sequence of integers representing the optimal chunk size for each dimension
161-
of the grid.
162-
"""
163-
try:
164-
return _calculate_optimal_chunksize(grid.live_mask, INT32_MAX // 4)
165-
except AttributeError:
166-
# Create an empty array with the same shape and dtype as the live mask would have
167-
return _calculate_optimal_chunksize(_EmptyGrid(grid.shape[:-1]), INT32_MAX // 4)
168-
169-
170-
def _calculate_optimal_chunksize( # noqa: C901
171-
volume: np.ndarray | zarr.Array, max_bytes: int
172-
) -> Sequence[int]:
173-
"""Calculate the optimal chunksize for an N-dimensional data volume.
174-
175-
Args:
176-
volume: The volume to calculate the chunksize for.
177-
max_bytes: The maximum allowed number of bytes per chunk.
178-
179-
Returns:
180-
A sequence of integers representing the optimal chunk size for each dimension
181-
of the grid.
182-
"""
183-
shape = volume.shape
184-
chunks = normalize_chunks(
185-
"auto",
186-
shape,
187-
dtype=volume.dtype,
188-
limit=max_bytes,
189-
)
190-
return tuple(_balance_chunksizes(chunk)[0] for chunk in chunks)

src/mdio/core/utils_write.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
11
"""Convenience utilities for writing to Zarr."""
22

3+
from typing import TYPE_CHECKING
34
from typing import Any
45

5-
import zarr
6+
from dask.array.core import normalize_chunks
7+
from dask.array.rechunk import _balance_chunksizes
68

79

8-
def write_attribute(name: str, attribute: Any, zarr_group: zarr.Group) -> None:
10+
if TYPE_CHECKING:
11+
from numpy.typing import DTypeLike
12+
from zarr import Group
13+
14+
15+
MAX_SIZE_LIVE_MASK = 512 * 1024**2
16+
17+
18+
def write_attribute(name: str, attribute: Any, zarr_group: "Group") -> None:
919
"""Write a mappable to Zarr array or group attribute.
1020
1121
Args:
@@ -14,3 +24,34 @@ def write_attribute(name: str, attribute: Any, zarr_group: zarr.Group) -> None:
1424
zarr_group: Output group or array.
1525
"""
1626
zarr_group.attrs[name] = attribute
27+
28+
29+
def get_constrained_chunksize(
    shape: tuple[int, ...],
    dtype: "DTypeLike",
    max_bytes: int,
) -> tuple[int, ...]:
    """Calculate the optimal chunk size for N-D array based on max_bytes.

    Args:
        shape: The shape of the array.
        dtype: The data dtype to be used in calculation.
        max_bytes: The maximum allowed number of bytes per chunk.

    Returns:
        A tuple of calculated chunk sizes, one entry per dimension of `shape`.
    """
    # Ask dask for "auto" chunking of every dimension under the byte limit.
    chunks = normalize_chunks("auto", shape, dtype=dtype, limit=max_bytes)
    # Balance each dimension's chunk list and keep its first (leading) chunk
    # size, yielding a single uniform chunk size per dimension.
    return tuple(_balance_chunksizes(chunk)[0] for chunk in chunks)
46+
47+
48+
def get_live_mask_chunksize(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Given a live_mask shape, calculate the optimal write chunk size.

    The live mask is stored as a boolean array, so the chunk size is computed
    for the "bool" dtype with `MAX_SIZE_LIVE_MASK` as the per-chunk byte limit.

    Args:
        shape: The shape of the array.

    Returns:
        A tuple of calculated chunk sizes, one entry per dimension of `shape`.
    """
    return get_constrained_chunksize(shape, "bool", MAX_SIZE_LIVE_MASK)

0 commit comments

Comments
 (0)