Skip to content

Commit 0fd633f

Browse files
committed
feature(chunkgrids): add rectilinear chunk grid metadata support
1 parent 14b372c commit 0fd633f

File tree

2 files changed

+499
-2
lines changed

2 files changed

+499
-2
lines changed

src/zarr/core/chunk_grids.py

Lines changed: 281 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from abc import abstractmethod
99
from dataclasses import dataclass
1010
from functools import reduce
11-
from typing import TYPE_CHECKING, Any, Literal
11+
from typing import TYPE_CHECKING, Any, Literal, TypedDict
1212

1313
import numpy as np
1414

@@ -28,6 +28,107 @@
2828

2929
from zarr.core.array import ShardsLike
3030

31+
from collections.abc import Sequence
32+
33+
# Type alias for a single entry of a per-axis chunk edge length specification.
# An entry is either a plain integer (one explicit edge length) or a
# run-length encoded pair [value, count], which expands to `count`
# consecutive copies of `value`.
ChunkEdgeLength = int | tuple[int, int]
36+
37+
38+
class RectilinearChunkGridConfigurationDict(TypedDict):
    """
    Typed description of a rectilinear chunk grid ``configuration`` mapping.

    This is a static-typing aid only; it performs no runtime validation.
    """

    # Discriminator for how the chunk shapes are stored; "inline" is the
    # only value this type declares.
    kind: Literal["inline"]
    # One inner sequence per axis. Each inner sequence lists chunk edge
    # lengths, where an entry may be an int or a run-length encoded pair.
    chunk_shapes: Sequence[Sequence[ChunkEdgeLength]]
43+
44+
45+
def _expand_run_length_encoding(spec: Sequence[ChunkEdgeLength]) -> tuple[int, ...]:
46+
"""
47+
Expand a chunk edge length specification into a tuple of integers.
48+
49+
The specification can contain:
50+
- integers: representing explicit edge lengths
51+
- tuples [value, count]: representing run-length encoded sequences
52+
53+
Parameters
54+
----------
55+
spec : Sequence[ChunkEdgeLength]
56+
The chunk edge length specification for one axis
57+
58+
Returns
59+
-------
60+
tuple[int, ...]
61+
Expanded sequence of chunk edge lengths
62+
63+
Examples
64+
--------
65+
>>> _expand_run_length_encoding([2, 3])
66+
(2, 3)
67+
>>> _expand_run_length_encoding([[2, 3]])
68+
(2, 2, 2)
69+
>>> _expand_run_length_encoding([1, [2, 1], 3])
70+
(1, 2, 3)
71+
>>> _expand_run_length_encoding([[1, 3], 3])
72+
(1, 1, 1, 3)
73+
"""
74+
result: list[int] = []
75+
for item in spec:
76+
if isinstance(item, int):
77+
# Explicit edge length
78+
result.append(item)
79+
elif isinstance(item, (list, tuple)):
80+
# Run-length encoded: [value, count]
81+
if len(item) != 2:
82+
raise TypeError(
83+
f"Run-length encoded items must be [int, int], got list of length {len(item)}"
84+
)
85+
value, count = item
86+
# Runtime validation of JSON data
87+
if not isinstance(value, int) or not isinstance(count, int): # type: ignore[redundant-expr]
88+
raise TypeError(
89+
f"Run-length encoded items must be [int, int], got [{type(value).__name__}, {type(count).__name__}]"
90+
)
91+
if count < 0:
92+
raise ValueError(f"Run-length count must be non-negative, got {count}")
93+
result.extend([value] * count)
94+
else:
95+
raise TypeError(
96+
f"Chunk edge length must be int or [int, int] for run-length encoding, got {type(item)}"
97+
)
98+
return tuple(result)
99+
100+
101+
def _parse_chunk_shapes(
102+
data: Sequence[Sequence[ChunkEdgeLength]],
103+
) -> tuple[tuple[int, ...], ...]:
104+
"""
105+
Parse and expand chunk_shapes from metadata.
106+
107+
Parameters
108+
----------
109+
data : Sequence[Sequence[ChunkEdgeLength]]
110+
The chunk_shapes specification from metadata
111+
112+
Returns
113+
-------
114+
tuple[tuple[int, ...], ...]
115+
Tuple of expanded chunk edge lengths for each axis
116+
"""
117+
# Runtime validation - strings are sequences but we don't want them
118+
# Type annotation is for static typing, this validates actual JSON data
119+
if isinstance(data, str) or not isinstance(data, Sequence): # type: ignore[redundant-expr,unreachable]
120+
raise TypeError(f"chunk_shapes must be a sequence, got {type(data)}")
121+
122+
result = []
123+
for i, axis_spec in enumerate(data):
124+
# Runtime validation for each axis spec
125+
if isinstance(axis_spec, str) or not isinstance(axis_spec, Sequence): # type: ignore[redundant-expr,unreachable]
126+
raise TypeError(f"chunk_shapes[{i}] must be a sequence, got {type(axis_spec)}")
127+
expanded = _expand_run_length_encoding(axis_spec)
128+
result.append(expanded)
129+
130+
return tuple(result)
131+
31132

32133
def _guess_chunks(
33134
shape: tuple[int, ...] | int,
@@ -159,6 +260,8 @@ def from_dict(cls, data: dict[str, JSON] | ChunkGrid) -> ChunkGrid:
159260
name_parsed, _ = parse_named_configuration(data)
160261
if name_parsed == "regular":
161262
return RegularChunkGrid._from_dict(data)
263+
elif name_parsed == "rectilinear":
264+
return RectilinearChunkGrid._from_dict(data)
162265
raise ValueError(f"Unknown chunk grid. Got {name_parsed}.")
163266

164267
@abstractmethod
@@ -201,6 +304,183 @@ def get_nchunks(self, array_shape: tuple[int, ...]) -> int:
201304
)
202305

203306

307+
@dataclass(frozen=True)
class RectilinearChunkGrid(ChunkGrid):
    """
    A rectilinear chunk grid where chunk sizes vary along each axis.

    Attributes
    ----------
    chunk_shapes : tuple[tuple[int, ...], ...]
        For each axis, a tuple of chunk edge lengths along that axis.
        The sum of edge lengths must equal the array shape along that axis;
        this is checked when an array shape is supplied to
        ``all_chunk_coords`` or ``get_nchunks``, not at construction time.
    """

    chunk_shapes: tuple[tuple[int, ...], ...]

    def __init__(self, *, chunk_shapes: Sequence[Sequence[int]]) -> None:
        """
        Initialize a RectilinearChunkGrid.

        Parameters
        ----------
        chunk_shapes : Sequence[Sequence[int]]
            For each axis, a sequence of chunk edge lengths.

        Raises
        ------
        TypeError
            If an axis entry is a string or not a sequence, or if an edge
            length is not an integer (booleans are rejected).
        ValueError
            If an edge length is not strictly positive.
        """
        # Convert to nested tuples and validate
        parsed_shapes: list[tuple[int, ...]] = []
        for i, axis_chunks in enumerate(chunk_shapes):
            # A string is a Sequence, and an empty one would otherwise pass
            # as an axis with no chunks.
            if isinstance(axis_chunks, str) or not isinstance(axis_chunks, Sequence):
                raise TypeError(f"chunk_shapes[{i}] must be a sequence, got {type(axis_chunks)}")
            axis_tuple = tuple(axis_chunks)
            for j, size in enumerate(axis_tuple):
                # bool is a subclass of int; True/False are not edge lengths
                if isinstance(size, bool) or not isinstance(size, int):
                    raise TypeError(
                        f"chunk_shapes[{i}][{j}] must be an int, got {type(size).__name__}"
                    )
                if size <= 0:
                    raise ValueError(f"chunk_shapes[{i}][{j}] must be positive, got {size}")
            parsed_shapes.append(axis_tuple)

        # The dataclass is frozen, so normal attribute assignment is blocked
        object.__setattr__(self, "chunk_shapes", tuple(parsed_shapes))

    @classmethod
    def _from_dict(cls, data: dict[str, JSON]) -> Self:
        """
        Parse a RectilinearChunkGrid from metadata dict.

        Parameters
        ----------
        data : dict[str, JSON]
            Metadata dictionary with 'name' and 'configuration' keys

        Returns
        -------
        Self
            A RectilinearChunkGrid instance

        Raises
        ------
        TypeError
            If the configuration is not a dict or 'chunk_shapes' is malformed.
        ValueError
            If 'kind' is not 'inline' or 'chunk_shapes' is missing.
        """
        _, configuration = parse_named_configuration(data, "rectilinear")

        if not isinstance(configuration, dict):
            raise TypeError(f"configuration must be a dict, got {type(configuration)}")

        # Validate kind field
        kind = configuration.get("kind")
        if kind != "inline":
            raise ValueError(f"Only 'inline' kind is supported, got {kind!r}")

        # Parse chunk_shapes with run-length encoding support
        chunk_shapes_raw = configuration.get("chunk_shapes")
        if chunk_shapes_raw is None:
            raise ValueError("configuration must contain 'chunk_shapes'")

        # Type ignore: JSON data validated at runtime by _parse_chunk_shapes
        chunk_shapes_expanded = _parse_chunk_shapes(chunk_shapes_raw)  # type: ignore[arg-type]

        return cls(chunk_shapes=chunk_shapes_expanded)

    def to_dict(self) -> dict[str, JSON]:
        """
        Convert to metadata dict format.

        The edge lengths are written out fully expanded; no run-length
        encoding is applied.

        Returns
        -------
        dict[str, JSON]
            Metadata dictionary with 'name' and 'configuration' keys
        """
        # Convert to list for JSON serialization
        chunk_shapes_list = [list(axis_chunks) for axis_chunks in self.chunk_shapes]

        return {
            "name": "rectilinear",
            "configuration": {
                "kind": "inline",
                "chunk_shapes": chunk_shapes_list,
            },
        }

    def _validate_array_shape(self, array_shape: tuple[int, ...]) -> None:
        """
        Check that ``array_shape`` is compatible with this grid.

        Parameters
        ----------
        array_shape : tuple[int, ...]
            Shape of the array

        Raises
        ------
        ValueError
            If the number of dimensions differs, or if the chunk edge lengths
            along any axis do not sum to the array size along that axis.
        """
        if len(array_shape) != len(self.chunk_shapes):
            raise ValueError(
                f"array_shape has {len(array_shape)} dimensions but "
                f"chunk_shapes has {len(self.chunk_shapes)} dimensions"
            )

        # Lengths were just checked to be equal, so strict=True cannot raise
        for axis, (arr_size, axis_chunks) in enumerate(
            zip(array_shape, self.chunk_shapes, strict=True)
        ):
            chunk_sum = sum(axis_chunks)
            if chunk_sum != arr_size:
                raise ValueError(
                    f"Sum of chunk sizes along axis {axis} is {chunk_sum} "
                    f"but array shape is {arr_size}"
                )

    def all_chunk_coords(self, array_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
        """
        Generate all chunk coordinates for the given array shape.

        Validation happens eagerly when this method is called; only the
        coordinates themselves are produced lazily.

        Parameters
        ----------
        array_shape : tuple[int, ...]
            Shape of the array

        Yields
        ------
        tuple[int, ...]
            Chunk coordinates

        Raises
        ------
        ValueError
            If array_shape doesn't match chunk_shapes
        """
        self._validate_array_shape(array_shape)

        # Each axis contributes len(axis_chunks) chunk indices
        return itertools.product(*(range(len(axis_chunks)) for axis_chunks in self.chunk_shapes))

    def get_nchunks(self, array_shape: tuple[int, ...]) -> int:
        """
        Get the total number of chunks for the given array shape.

        Parameters
        ----------
        array_shape : tuple[int, ...]
            Shape of the array

        Returns
        -------
        int
            Total number of chunks

        Raises
        ------
        ValueError
            If array_shape doesn't match chunk_shapes
        """
        self._validate_array_shape(array_shape)

        # Total chunks is the product of number of chunks per axis
        return reduce(operator.mul, (len(axis_chunks) for axis_chunks in self.chunk_shapes), 1)
482+
483+
204484
def _auto_partition(
205485
*,
206486
array_shape: tuple[int, ...],

0 commit comments

Comments
 (0)