Skip to content

Commit 5d01ffe

Browse files
committed
Rework chunking computation and expand test coverage
1 parent e7cb830 commit 5d01ffe

File tree

2 files changed

+221
-68
lines changed

2 files changed

+221
-68
lines changed

src/mdio/converters/segy.py

Lines changed: 120 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -506,35 +506,125 @@ def _calculate_live_mask_chunksize(grid: Grid) -> Sequence[int]:
506506
grid: The grid to calculate the chunksize for.
507507
508508
Returns:
509-
A sequence of integers representing the optimal chunk size for each dimension of the grid.
509+
A sequence of integers representing the optimal chunk size for each dimension
510+
of the grid.
510511
"""
511-
# Use nbytes for the initial check, since we are limited by Blosc's maximum
512-
# chunk size in bytes.
513-
if grid.live_mask.nbytes < INT32_MAX:
514-
# Base case where we don't need to chunk the live mask
515-
return grid.live_mask.shape
516-
517-
# Calculate the optimal chunksize for the live mask
518-
total_elements = np.prod(grid.shape[:-1]) # Exclude sample dimension
519-
num_chunks = np.ceil(total_elements / INT32_MAX).astype(int)
520-
521-
# Calculate chunk size for each dimension
522-
chunks = []
523-
remaining_elements = total_elements
524-
525-
for dim_size in grid.shape[:-1]: # Exclude sample dimension
526-
# Calculate how many chunks we need in this dimension
527-
# We want to distribute chunks evenly across dimensions
528-
dim_chunks = max(
529-
1,
530-
int(
531-
np.ceil(
532-
dim_size / np.ceil(np.power(num_chunks, 1 / len(grid.shape[:-1])))
533-
)
534-
),
535-
)
536-
chunk_size = int(np.ceil(dim_size / dim_chunks))
537-
chunks.append(chunk_size)
538-
remaining_elements //= dim_chunks
512+
return _calculate_optimal_chunksize(grid.live_mask, INT32_MAX)
513+
514+
515+
def _calculate_optimal_chunksize( # noqa: C901
516+
volume: np.ndarray | zarr.Array, n_bytes: int
517+
) -> Sequence[int]:
518+
"""Calculate a uniform chunk shape for an N-dimensional data volume such that...
519+
520+
0. The product of the chunk dimensions multiplied by the element size does not
521+
exceed n_bytes.
522+
1. The chunk shape is "regular" – each chunk dimension is a divisor of the
523+
overall volume shape.
524+
2. If an exact match is impossible, the chunk shape chosen maximizes the number of
525+
elements (minimizing the unused bytes).
526+
3. The computation is efficient.
527+
528+
The computation efficiency is broken down as follows:
529+
530+
- Divisor Computation: For each of the N dimensions (assume size ~ n), it checks
531+
up to n numbers, so this part is roughly O(N * n).
532+
For example, if you have a 3D array where each dimension is about 100,
533+
it does around 3*100 = 300 steps.
534+
- DFS Search: In the worst-case, the DFS explores about D choices per dimension
535+
(D = average number of divisors) leading to O(D^N) combinations.
536+
In practice, D is small (often < 10), so for a 2D array this is around 10^2
537+
(about 100 combinations) and for a 3D array about 10^3 (roughly 1,000 combinations).
538+
Since N is typically small (often <6), this exponential term behaves like a
539+
constant factor.
540+
541+
Args:
542+
volume : np.ndarray | zarr.Array
543+
An N-dimensional array-like object (e.g. np.ndarray or zarr array).
544+
n_bytes : int
545+
Maximum allowed number of bytes per chunk (>= 1).
539546
540-
return tuple(chunks)
547+
Returns:
548+
Sequence[int]
549+
A tuple representing the optimal chunk shape (number of elements along each axis).
550+
551+
Raises:
552+
ValueError if n_bytes is less than the number of bytes of one element.
553+
"""
554+
# Get volume shape and element size.
555+
shape = volume.shape
556+
557+
if volume.size == 0:
558+
logging.warning("Chunking calculation received empty volume shape...")
559+
return volume.shape
560+
561+
itemsize = volume.dtype.itemsize
562+
563+
# Maximum number of elements that can fit in a chunk
564+
# (we ignore any extra bytes; must not exceed n_bytes).
565+
max_elements_allowed = n_bytes // itemsize
566+
if max_elements_allowed < 1:
567+
raise ValueError("n_bytes is too small to hold even one element of the volume.")
568+
569+
n_dims = len(shape)
570+
571+
def get_divisors(n: int) -> list[int]:
572+
"""Return a sorted list of all positive divisors of n.
573+
574+
Args:
575+
n: The number to compute the divisors of.
576+
577+
Returns:
578+
A sorted list of all positive divisors of n.
579+
"""
580+
divs = []
581+
# It is efficient enough for typical dimension sizes.
582+
for i in range(1, n + 1):
583+
if n % i == 0:
584+
divs.append(i)
585+
return sorted(divs)
586+
587+
# For each dimension, compute the list of allowed chunk sizes (divisors).
588+
divisors_list = [get_divisors(d) for d in shape]
589+
590+
# For pruning: precompute the maximum possible product achievable from axis i to N-1.
591+
# This is the product of the maximum divisors for each remaining axis.
592+
max_possible = [1] * (n_dims + 1)
593+
for i in range(n_dims - 1, -1, -1):
594+
max_possible[i] = max(divisors_list[i]) * max_possible[i + 1]
595+
596+
best_product = 0
597+
best_combination = [None] * n_dims
598+
current_chunk = [None] * n_dims
599+
600+
def dfs(dim: int, current_product: int) -> None:
601+
"""Depth-first search to find the optimal chunk shape.
602+
603+
Args:
604+
dim: The current dimension to process.
605+
current_product: The current product of the chunk dimensions.
606+
"""
607+
nonlocal best_product
608+
# If all dimensions have been processed, update best combination if needed.
609+
if dim == n_dims:
610+
if current_product > best_product:
611+
best_product = current_product
612+
best_combination[:] = current_chunk[:]
613+
return
614+
615+
# Prune branches: even if we take the maximum allowed for all remaining dimensions,
616+
# if we cannot exceed best_product, then skip.
617+
if current_product * max_possible[dim] < best_product:
618+
return
619+
620+
# Iterate over allowed divisors for the current axis,
621+
# trying larger candidates first so that high products are found early.
622+
for candidate in sorted(divisors_list[dim], reverse=True):
623+
new_product = current_product * candidate
624+
if new_product > max_elements_allowed:
625+
continue # This candidate would exceed the byte restriction.
626+
current_chunk[dim] = candidate
627+
dfs(dim + 1, new_product)
628+
629+
dfs(0, 1)
630+
return tuple(best_combination)

tests/unit/test_live_mask_chunksize.py

Lines changed: 101 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""Test live mask chunk size calculation."""
22

33
import numpy as np
4+
import pytest
45

5-
from mdio.constants import INT32_MAX
66
from mdio.converters.segy import _calculate_live_mask_chunksize
7+
from mdio.converters.segy import _calculate_optimal_chunksize
78
from mdio.core import Dimension
89
from mdio.core import Grid
910

@@ -20,7 +21,7 @@ def test_small_grid_no_chunking():
2021
grid.live_mask = np.ones((100, 100), dtype=bool)
2122

2223
result = _calculate_live_mask_chunksize(grid)
23-
assert result == -1
24+
assert result == (100, 100)
2425

2526

2627
def test_large_2d_grid_chunking():
@@ -37,13 +38,8 @@ def test_large_2d_grid_chunking():
3738

3839
result = _calculate_live_mask_chunksize(grid)
3940

40-
# Calculate expected values
41-
total_elements = 50000 * 50000
42-
num_chunks = np.ceil(total_elements / INT32_MAX).astype(int)
43-
dim_chunks = int(np.ceil(50000 / np.ceil(np.power(num_chunks, 1 / 2))))
44-
expected_chunk_size = int(np.ceil(50000 / dim_chunks))
45-
46-
assert result == (expected_chunk_size, expected_chunk_size)
41+
# TODO(BrianMichell): Avoid magic numbers.
42+
assert result == (50000, 25000)
4743

4844

4945
def test_large_3d_grid_chunking():
@@ -62,12 +58,13 @@ def test_large_3d_grid_chunking():
6258
result = _calculate_live_mask_chunksize(grid)
6359

6460
# Calculate expected values
65-
total_elements = 1500 * 1500 * 1500
66-
num_chunks = np.ceil(total_elements / INT32_MAX).astype(int)
67-
dim_chunks = int(np.ceil(1500 / np.ceil(np.power(num_chunks, 1 / 3))))
68-
expected_chunk_size = int(np.ceil(1500 / dim_chunks))
61+
# total_elements = 1500 * 1500 * 1500
62+
# num_chunks = np.ceil(total_elements / INT32_MAX).astype(int)
63+
# dim_chunks = int(np.ceil(1500 / np.ceil(np.power(num_chunks, 1 / 3))))
64+
# expected_chunk_size = int(np.ceil(1500 / dim_chunks))
6965

70-
assert result == (expected_chunk_size, expected_chunk_size, expected_chunk_size)
66+
# assert result == (expected_chunk_size, expected_chunk_size, expected_chunk_size)
67+
assert result == (1500, 1500, 750)
7168

7269

7370
def test_uneven_dimensions_chunking():
@@ -84,14 +81,7 @@ def test_uneven_dimensions_chunking():
8481
grid.live_mask = np.ones((50000, 50000), dtype=bool)
8582

8683
result = _calculate_live_mask_chunksize(grid)
87-
88-
# Calculate expected values
89-
total_elements = 50000 * 50000
90-
num_chunks = np.ceil(total_elements / INT32_MAX).astype(int)
91-
dim_chunks = int(np.ceil(50000 / np.ceil(np.power(num_chunks, 1 / 2))))
92-
expected_chunk_size = int(np.ceil(50000 / dim_chunks))
93-
94-
assert result == (expected_chunk_size, expected_chunk_size)
84+
assert result == (50000, 25000)
9585

9686

9787
def test_prestack_land_survey_chunking():
@@ -114,21 +104,7 @@ def test_prestack_land_survey_chunking():
114104
grid.live_mask = np.ones((1000, 1000, 100, 36), dtype=bool)
115105

116106
result = _calculate_live_mask_chunksize(grid)
117-
118-
# Calculate expected values
119-
total_elements = 1000 * 1000 * 100 * 36
120-
num_chunks = np.ceil(total_elements / INT32_MAX).astype(int)
121-
dim_chunks = int(np.ceil(1000 / np.ceil(np.power(num_chunks, 1 / 4))))
122-
expected_chunk_size = int(np.ceil(1000 / dim_chunks))
123-
124-
# For a 4D grid, we expect chunk sizes to be distributed across all dimensions
125-
# The chunk size should be the same for all dimensions since they're all equally important
126-
assert result == (
127-
expected_chunk_size,
128-
expected_chunk_size,
129-
expected_chunk_size,
130-
expected_chunk_size,
131-
)
107+
assert result == (1000, 1000, 100, 18)
132108

133109

134110
def test_edge_case_empty_grid():
@@ -142,4 +118,91 @@ def test_edge_case_empty_grid():
142118
grid.live_mask = np.zeros((0, 0), dtype=bool)
143119

144120
result = _calculate_live_mask_chunksize(grid)
145-
assert result == -1 # Empty grid shouldn't need chunking
121+
assert result == (0, 0)
122+
123+
124+
# Additional tests for _calculate_optimal_chunksize function
125+
def test_empty_volume():
    """An empty volume is returned unchunked, keeping its own shape."""
    volume = np.zeros((0, 10), dtype=np.int8)
    chunks = _calculate_optimal_chunksize(volume, 100)
    assert chunks == (0, 10)
130+
131+
132+
def test_nbytes_too_small():
    """A ValueError is raised when n_bytes cannot hold a single element."""
    volume = np.zeros((10,), dtype=np.int8)  # itemsize is 1
    expected = r"n_bytes is too small to hold even one element"
    with pytest.raises(ValueError, match=expected):
        _calculate_optimal_chunksize(volume, 0)
139+
140+
141+
def test_one_dim_full_chunk():
    """A 1-D volume that fits entirely within n_bytes is one single chunk."""
    volume = np.zeros((100,), dtype=np.int8)
    # n_bytes = 100 allows exactly 100 elements, i.e. the whole axis.
    chunks = _calculate_optimal_chunksize(volume, 100)
    assert chunks == (100,)
147+
148+
149+
def test_two_dim_optimal():
    """Verify the optimal chunk for a 2-D volume under a tight byte budget.

    A (8, 6) int8 volume with n_bytes=20 should chunk as (8, 2): 16 elements,
    the largest divisor product not exceeding 20.
    """
    volume = np.zeros((8, 6), dtype=np.int8)
    chunks = _calculate_optimal_chunksize(volume, 20)
    assert chunks == (8, 2)
157+
158+
159+
def test_three_dim_optimal():
    """Verify the optimal chunk for a 3-D volume.

    A (9, 6, 4) int8 volume with n_bytes=100 should chunk as (9, 2, 4):
    72 elements, the largest divisor product not exceeding 100.
    """
    volume = np.zeros((9, 6, 4), dtype=np.int8)
    chunks = _calculate_optimal_chunksize(volume, 100)
    assert chunks == (9, 2, 4)
167+
168+
169+
def test_minimal_chunk_for_large_dtype():
    """A budget of exactly one element forces the all-ones chunk shape.

    int32 has itemsize 4, so n_bytes=4 admits a single element: (1, 1).
    """
    volume = np.zeros((4, 5), dtype=np.int32)
    chunks = _calculate_optimal_chunksize(volume, 4)
    assert chunks == (1, 1)
177+
178+
179+
def test_large_nbytes():
    """A generous byte budget yields the full volume shape as one chunk."""
    volume = np.zeros((10, 10), dtype=np.int8)
    chunks = _calculate_optimal_chunksize(volume, 1000)
    assert chunks == (10, 10)
184+
185+
186+
def test_two_dim_non_int8():
    """The byte budget is interpreted via the dtype's itemsize.

    int16 has itemsize 2, so 6 * 8 * 2 = 96 bytes covers the full volume and
    the optimal chunk is the whole shape.
    """
    volume = np.zeros((6, 8), dtype=np.int16)
    chunks = _calculate_optimal_chunksize(volume, 96)
    assert chunks == (6, 8)
192+
193+
194+
def test_irregular_dimensions():
    """Prime dimensions with an exactly-fitting budget return the full shape.

    For shape (7, 5) with n_bytes=35 the whole volume fits (7 * 5 = 35).
    """
    volume = np.zeros((7, 5), dtype=np.int8)
    chunks = _calculate_optimal_chunksize(volume, 35)
    assert chunks == (7, 5)
202+
203+
204+
def test_primes():
    """Prime dimensions with a budget smaller than the full volume.

    For shape (7, 5) with n_bytes=23 the whole volume (35 bytes at itemsize 1)
    exceeds the budget, and the only admissible chunk sizes per axis are the
    divisors {1, 7} and {1, 5}. The largest divisor product not exceeding 23
    is 7, so the expected chunk is (7, 1) — not (7, 5), which would violate
    the byte limit.
    """
    arr = np.zeros((7, 5), dtype=np.int8)
    result = _calculate_optimal_chunksize(arr, 23)
    assert result == (7, 1)

0 commit comments

Comments
 (0)