Skip to content

Commit 79ecee9

Browse files
committed
better approach
1 parent cef873f commit 79ecee9

File tree

2 files changed

+102
-12
lines changed

2 files changed

+102
-12
lines changed

src/zarr/testing/strategies.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,3 +562,78 @@ def chunk_paths(draw: st.DrawFn, ndim: int, numblocks: tuple[int, ...], subset:
562562
)
563563
subset_slicer = slice(draw(st.integers(min_value=0, max_value=ndim))) if subset else slice(None)
564564
return "/".join(map(str, blockidx[subset_slicer]))
565+
566+
567+
@st.composite
def complex_chunk_grids(draw: st.DrawFn) -> RectilinearChunkGrid:
    """Draw a ``RectilinearChunkGrid`` with 1-3 dimensions and varying chunk sizes.

    Every dimension receives the same number of entries (``nchunks``, drawn
    from 10..100), each chunk size lying in 1..10.  One of two input forms is
    chosen at random and reported through ``hypothesis.event``:

    * explicit: one list of chunk sizes per dimension;
    * run-length encoded: one list of ``[chunk_size, repeat]`` pairs per
      dimension, with ``repeat`` in 1..20.

    Parameters
    ----------
    draw
        The draw function injected by ``st.composite``.

    Returns
    -------
    RectilinearChunkGrid
        Built from the drawn ``chunk_shapes``.
    """
    ndim = draw(st.integers(min_value=1, max_value=3))
    nchunks = draw(st.integers(min_value=10, max_value=100))
    # Chunk sizes must NOT be required to be unique: only 10 distinct values
    # exist in [1, 10], so a unique list of more than 10 entries can never be
    # generated.  Repeated sizes are legitimate here anyway -- the run-length
    # encoded form below repeats sizes by construction.
    dim_chunks = st.lists(
        st.integers(min_value=1, max_value=10), min_size=nchunks, max_size=nchunks
    )

    if draw(st.booleans()):
        event("using RectilinearChunkGrid")
        chunk_shapes = draw(st.lists(dim_chunks, min_size=ndim, max_size=ndim))
        return RectilinearChunkGrid(chunk_shapes=chunk_shapes)

    event("using RectilinearChunkGrid (run length encoded)")
    repeats = st.lists(
        st.integers(min_value=1, max_value=20), min_size=nchunks, max_size=nchunks
    )
    # Both lists have exactly ``nchunks`` entries, so ``strict=True`` only
    # guards against a future change that lets their lengths diverge.
    chunk_shapes_rle = [
        [[c, r] for c, r in zip(draw(dim_chunks), draw(repeats), strict=True)]
        for _ in range(ndim)
    ]
    return RectilinearChunkGrid(chunk_shapes=chunk_shapes_rle)
589+
590+
591+
@st.composite
def complex_chunked_arrays(
    draw: st.DrawFn,
    *,
    stores: st.SearchStrategy[StoreLike] = stores,
) -> Array:
    """Draw a zarr ``Array`` backed by a rectilinear chunk grid and filled with data.

    A store and a chunk grid are drawn, the array shape is taken from the last
    cumulative size of each grid dimension, and a matching numpy array is
    drawn.  The array is created at ``"/foo"`` in a group opened with
    ``mode="w"`` on the store, sanity-checked, and then the numpy data is
    written into it.

    Parameters
    ----------
    draw
        The draw function injected by ``st.composite``.
    stores
        Strategy producing the store the array is created in.

    Returns
    -------
    Array
        The created array, already populated with the drawn data.
    """
    backing_store = draw(stores, label="store")
    grid = draw(complex_chunk_grids(), label="chunk grid")
    assert isinstance(grid, RectilinearChunkGrid)

    # The array extent along each dimension is the total of that dimension's
    # chunk sizes, i.e. the final entry of its cumulative sizes.
    extent = tuple(sizes[-1] for sizes in grid._cumulative_sizes)
    data = draw(numpy_arrays(shapes=st.just(extent)), label="array data")

    group = zarr.open_group(backing_store, mode="w")
    arr = group.create_array(
        "/foo",
        shape=data.shape,
        chunks=grid,
        shards=None,
        dtype=data.dtype,
        attributes={},
        fill_value=None,
        dimension_names=None,
    )

    assert isinstance(arr, Array)
    if arr.metadata.zarr_format == 3:
        assert arr.fill_value is not None
    assert data.shape == arr.shape

    # Only a regular grid has its ``chunks`` compared; for a rectilinear grid
    # the comparison is skipped (the original note says ``chunks`` raises
    # NotImplementedError there).
    created_grid = arr.metadata.chunk_grid
    if not isinstance(created_grid, RectilinearChunkGrid):
        # Compare against the grid's own chunk_shape rather than the input,
        # since the stored value may have been normalized.
        from zarr.core.chunk_grids import RegularChunkGrid

        assert isinstance(created_grid, RegularChunkGrid)
        assert created_grid.chunk_shape == arr.chunks

    # Sharding is never requested by this strategy.
    # NOTE(review): an earlier comment claimed ``shards`` raises
    # NotImplementedError for rectilinear grids; if so this assert cannot pass
    # -- confirm against the Array implementation.
    assert arr.shards is None

    arr[:] = data
    return arr

tests/test_properties.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import hypothesis.strategies as st
1515
from hypothesis import assume, given, settings
1616

17+
from zarr import Array
1718
from zarr.abc.store import Store
1819
from zarr.core.common import ZARR_JSON, ZARRAY_JSON, ZATTRS_JSON
1920
from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata
@@ -22,6 +23,7 @@
2223
array_metadata,
2324
arrays,
2425
basic_indices,
26+
complex_chunked_arrays,
2527
numpy_arrays,
2628
orthogonal_indices,
2729
simple_arrays,
@@ -106,11 +108,10 @@ def test_array_creates_implicit_groups(array):
106108

107109

108110
@pytest.mark.asyncio
109-
@settings(deadline=None)
111+
@settings(deadline=None, report_multiple_bugs=False)
110112
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
111-
@given(data=st.data())
112-
async def test_basic_indexing(data: st.DataObject) -> None:
113-
zarray = data.draw(simple_arrays())
113+
@given(data=st.data(), zarray=st.one_of([simple_arrays(), complex_chunked_arrays()]))
114+
async def test_basic_indexing(data: st.DataObject, zarray: Array) -> None:
114115
nparray = zarray[:]
115116
indexer = data.draw(basic_indices(shape=nparray.shape))
116117

@@ -133,11 +134,18 @@ async def test_basic_indexing(data: st.DataObject) -> None:
133134

134135

135136
@pytest.mark.asyncio
136-
@given(data=st.data())
137+
@given(
138+
data=st.data(),
139+
zarray=st.one_of(
140+
[
141+
# integer_array_indices can't handle 0-size dimensions.
142+
simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)),
143+
complex_chunked_arrays(),
144+
]
145+
),
146+
)
137147
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
138-
async def test_oindex(data: st.DataObject) -> None:
139-
# integer_array_indices can't handle 0-size dimensions.
140-
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
148+
async def test_oindex(data: st.DataObject, zarray: Array) -> None:
141149
nparray = zarray[:]
142150
zindexer, npindexer = data.draw(orthogonal_indices(shape=nparray.shape))
143151

@@ -165,11 +173,18 @@ async def test_oindex(data: st.DataObject) -> None:
165173

166174

167175
@pytest.mark.asyncio
168-
@given(data=st.data())
176+
@given(
177+
data=st.data(),
178+
zarray=st.one_of(
179+
[
180+
# integer_array_indices can't handle 0-size dimensions.
181+
simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)),
182+
complex_chunked_arrays(),
183+
]
184+
),
185+
)
169186
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
170-
async def test_vindex(data: st.DataObject) -> None:
171-
# integer_array_indices can't handle 0-size dimensions.
172-
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
187+
async def test_vindex(data: st.DataObject, zarray: Array) -> None:
173188
nparray = zarray[:]
174189
indexer = data.draw(
175190
npst.integer_array_indices(

0 commit comments

Comments
 (0)