Skip to content

Commit d0a0bdf

Browse files
committed
Cleanup, mocking, and relocation of autochunking
1 parent 78fe55d commit d0a0bdf

File tree

3 files changed

+50
-160
lines changed

3 files changed

+50
-160
lines changed

src/mdio/converters/segy.py

Lines changed: 1 addition & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,17 @@
1717
from segy.schema import HeaderField
1818

1919
from mdio.api.io_utils import process_url
20-
from mdio.constants import INT32_MAX
2120
from mdio.converters.exceptions import EnvironmentFormatError
2221
from mdio.converters.exceptions import GridTraceCountError
2322
from mdio.converters.exceptions import GridTraceSparsityError
2423
from mdio.core import Grid
24+
from mdio.core.grid import _calculate_live_mask_chunksize
2525
from mdio.core.utils_write import write_attribute
2626
from mdio.segy import blocked_io
2727
from mdio.segy.compat import mdio_segy_spec
2828
from mdio.segy.helpers_segy import create_zarr_hierarchy
2929
from mdio.segy.utilities import get_grid_plan
3030

31-
from dask.array.core import normalize_chunks
32-
from dask.array.rechunk import _balance_chunksizes
33-
3431
logger = logging.getLogger(__name__)
3532

3633
try:
@@ -499,153 +496,3 @@ def segy_to_mdio( # noqa: C901
499496
)
500497

501498
zarr.consolidate_metadata(store_nocache)
502-
503-
504-
def _calculate_live_mask_chunksize(grid: Grid) -> Sequence[int]:
505-
"""Calculate the optimal chunksize for the live mask.
506-
507-
Args:
508-
grid: The grid to calculate the chunksize for.
509-
510-
Returns:
511-
A sequence of integers representing the optimal chunk size for each dimension
512-
of the grid.
513-
"""
514-
return _calculate_optimal_chunksize(grid.live_mask, INT32_MAX//4)
515-
516-
517-
def _calculate_optimal_chunksize( # noqa: C901
518-
volume: np.ndarray | zarr.Array, n_bytes: int
519-
) -> Sequence[int]:
520-
"""Calculate the optimal chunksize for an N-dimensional data volume.
521-
522-
Args:
523-
volume: The volume to calculate the chunksize for.
524-
n_bytes: The maximum allowed number of bytes per chunk.
525-
526-
Returns:
527-
A sequence of integers representing the optimal chunk size for each dimension
528-
of the grid.
529-
"""
530-
shape = volume.shape
531-
chunks = normalize_chunks(
532-
"auto",
533-
shape,
534-
dtype=volume.dtype,
535-
limit=n_bytes,
536-
)
537-
return tuple(_balance_chunksizes(chunk)[0] for chunk in chunks)
538-
539-
540-
541-
# 0. The product of the chunk dimensions multiplied by the element size does not
542-
# exceed n_bytes.
543-
# 1. The chunk shape is "regular" – each chunk dimension is a divisor of the
544-
# overall volume shape.
545-
# 2. If an exact match is impossible, the chunk shape chosen maximizes the number of
546-
# elements (minimizing the unused bytes).
547-
# 3. The computation is efficient.
548-
549-
# The computation efficiency is broken down as follows:
550-
551-
# - Divisor Computation: For each of the N dimensions (assume size ~ n), it checks
552-
# up to n numbers, so this part is roughly O(N * n).
553-
# For example, if you have a 3D array where each dimension is about 100,
554-
# it does around 3*100 = 300 steps.
555-
# - DFS Search: In the worst-case, the DFS explores about D choices per dimension
556-
# (D = average number of divisors) leading to O(D^N) combinations.
557-
# In practice, D is small (often < 10), so for a 2D array this is around 10^2
558-
# (about 100 combinations) and for a 3D array about 10^3 (roughly 1,000 combinations).
559-
# Since N is typically small (often <6), this exponential term behaves like a
560-
# constant factor.
561-
562-
# Args:
563-
# volume : np.ndarray | zarr.Array
564-
# An N-dimensional array-like object (e.g. np.ndarray or zarr array).
565-
# n_bytes : int
566-
# Maximum allowed number of bytes per chunk (>= 1).
567-
568-
# Returns:
569-
# Sequence[int]
570-
# A tuple representing the optimal chunk shape (number of elements along each axis).
571-
572-
# Raises:
573-
# ValueError if n_bytes is less than the number of bytes of one element.
574-
# """
575-
# # Get volume shape and element size.
576-
# shape = volume.shape
577-
578-
# if volume.size == 0:
579-
# logging.warning("Chunking calculation received empty volume shape...")
580-
# return volume.shape
581-
582-
# itemsize = volume.dtype.itemsize
583-
584-
# # Maximum number of elements that can fit in a chunk
585-
# # (we ignore any extra bytes; must not exceed n_bytes).
586-
# max_elements_allowed = n_bytes // itemsize
587-
# if max_elements_allowed < 1:
588-
# raise ValueError("n_bytes is too small to hold even one element of the volume.")
589-
590-
# n_dims = len(shape)
591-
592-
# def get_divisors(n: int) -> list[int]:
593-
# """Return a sorted list of all positive divisors of n.
594-
595-
# Args:
596-
# n: The number to compute the divisors of.
597-
598-
# Returns:
599-
# A sorted list of all positive divisors of n.
600-
# """
601-
# divs = []
602-
# # It is efficient enough for typical dimension sizes.
603-
# for i in range(1, n + 1):
604-
# if n % i == 0:
605-
# divs.append(i)
606-
# return sorted(divs)
607-
608-
# # For each dimension, compute the list of allowed chunk sizes (divisors).
609-
# divisors_list = [get_divisors(d) for d in shape]
610-
611-
# # For pruning: precompute the maximum possible product achievable from axis i to N-1.
612-
# # This is the product of the maximum divisors for each remaining axis.
613-
# max_possible = [1] * (n_dims + 1)
614-
# for i in range(n_dims - 1, -1, -1):
615-
# max_possible[i] = max(divisors_list[i]) * max_possible[i + 1]
616-
617-
# best_product = 0
618-
# best_combination = [None] * n_dims
619-
# current_chunk = [None] * n_dims
620-
621-
# def dfs(dim: int, current_product: int) -> None:
622-
# """Depth-first search to find the optimal chunk shape.
623-
624-
# Args:
625-
# dim: The current dimension to process.
626-
# current_product: The current product of the chunk dimensions.
627-
# """
628-
# nonlocal best_product
629-
# # If all dimensions have been processed, update best combination if needed.
630-
# if dim == n_dims:
631-
# if current_product > best_product:
632-
# best_product = current_product
633-
# best_combination[:] = current_chunk[:]
634-
# return
635-
636-
# # Prune branches: even if we take the maximum allowed for all remaining dimensions,
637-
# # if we cannot exceed best_product, then skip.
638-
# if current_product * max_possible[dim] < best_product:
639-
# return
640-
641-
# # Iterate over allowed divisors for the current axis,
642-
# # trying larger candidates first so that high products are found early.
643-
# for candidate in sorted(divisors_list[dim], reverse=True):
644-
# new_product = current_product * candidate
645-
# if new_product > max_elements_allowed:
646-
# continue # This candidate would exceed the byte restriction.
647-
# current_chunk[dim] = candidate
648-
# dfs(dim + 1, new_product)
649-
650-
# dfs(0, 1)
651-
# return tuple(best_combination)

src/mdio/core/grid.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
from collections.abc import Sequence
6+
57
import inspect
68
import logging
79
from dataclasses import dataclass
@@ -11,6 +13,11 @@
1113

1214
from mdio.constants import UINT32_MAX
1315
from mdio.constants import UINT64_MAX
16+
from mdio.constants import INT32_MAX
17+
18+
from dask.array.core import normalize_chunks
19+
from dask.array.rechunk import _balance_chunksizes
20+
1421
from mdio.core import Dimension
1522
from mdio.core.serialization import Serializer
1623

@@ -134,3 +141,40 @@ def deserialize(self, stream: str) -> Grid:
134141
payload = self.validate_payload(payload, signature)
135142

136143
return Grid(**payload)
144+
145+
146+
def _calculate_live_mask_chunksize(grid: Grid) -> Sequence[int]:
    """Calculate the optimal chunksize for the live mask.

    Delegates to ``_calculate_optimal_chunksize`` with a byte budget of
    ``INT32_MAX // 4`` (~512 MiB) per chunk — presumably to stay well below
    the common 2 GiB per-chunk limit (see the related test-suite comments);
    confirm if the divisor ever needs tuning.

    Args:
        grid: The grid to calculate the chunksize for.

    Returns:
        A sequence of integers representing the optimal chunk size for each
        dimension of the grid.
    """
    return _calculate_optimal_chunksize(grid.live_mask, INT32_MAX // 4)
158+
159+
def _calculate_optimal_chunksize(
    volume: np.ndarray | zarr.Array,
    max_bytes: int,
) -> Sequence[int]:
    """Calculate the optimal chunksize for an N-dimensional data volume.

    Args:
        volume: The volume to calculate the chunksize for.
        max_bytes: The maximum allowed number of bytes per chunk.

    Returns:
        A sequence of integers representing the optimal chunk size for each
        dimension of the grid.
    """
    shape = volume.shape
    # Let dask's "auto" chunking pick per-dimension extents that respect the
    # byte budget for this dtype.
    chunks = normalize_chunks(
        "auto",
        shape,
        dtype=volume.dtype,
        limit=max_bytes,
    )
    # Balance each dimension's chunk extents toward uniform sizes and take
    # the first (representative) extent per dimension as the chunk size.
    return tuple(_balance_chunksizes(chunk)[0] for chunk in chunks)

tests/unit/test_live_mask_chunksize.py renamed to tests/unit/test_auto_chunking.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import numpy as np
44
import pytest
55

6-
from mdio.converters.segy import _calculate_live_mask_chunksize
7-
from mdio.converters.segy import _calculate_optimal_chunksize
86
from mdio.core import Dimension
97
from mdio.core import Grid
8+
from mdio.core.grid import _calculate_live_mask_chunksize
9+
from mdio.core.grid import _calculate_optimal_chunksize
1010

1111

1212
class MockArray:
@@ -222,9 +222,7 @@ def test_altay():
222222
grid = Grid(dims=dims)
223223
grid.live_mask = MockArray(shape, bool)
224224

225-
# Calculate chunk size using the live mask function
226225
result = _calculate_live_mask_chunksize(grid)
227-
print(f"{kind}: {result}")
228226

229227
# Verify that the chunk size is valid
230228
assert all(chunk > 0 for chunk in result), f"Invalid chunk size for {kind}"
@@ -235,5 +233,6 @@ def test_altay():
235233
if kind in ["right_above_2G", "above_2G_v2", "above_2G_v2_asym", "above_4G_v2_asym", "above_3G_4D_asym"]:
236234
# TODO(BrianMichell): Our implementation is taking "limit" pretty liberally.
237235
# This is not overtly an issue because we are well below the 2GiB limit, but it's indicative of an underlying issue.
238-
continue
239-
assert chunk_elements <= INT32_MAX // 4, f"Chunk too large for {kind}"
236+
assert chunk_elements <= (INT32_MAX // 4) * 1.5, f"Chunk too large for {kind}"
237+
else:
238+
assert chunk_elements <= INT32_MAX // 4, f"Chunk too large for {kind}"

0 commit comments

Comments
 (0)