diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py
index 1cfc83a97..2190deeec 100644
--- a/cubed/primitive/blockwise.py
+++ b/cubed/primitive/blockwise.py
@@ -687,6 +687,15 @@ def make_blockwise_key_function(
         False,
     )
 
+    for axes, (arg, _) in zip(concat_axes, argpairs):
+        for ax in axes:
+            if numblocks[arg][ax] > 1:
+                raise ValueError(
+                    f"Cannot have multiple chunks in dropped axis {ax}. "
+                    "To fix, use a reduction after calling map_blocks "
+                    "without specifying drop_axis, or rechunk first."
+                )
+
     def key_function(out_key):
         out_coords = out_key[1:]
diff --git a/cubed/tests/primitive/test_blockwise.py b/cubed/tests/primitive/test_blockwise.py
index 06eaa624b..c8a01b8e3 100644
--- a/cubed/tests/primitive/test_blockwise.py
+++ b/cubed/tests/primitive/test_blockwise.py
@@ -266,11 +266,11 @@ def test_make_blockwise_key_function_contract():
     func = lambda x: 0
 
     key_fn = make_blockwise_key_function(
-        func, "z", "ik", "x", "ij", "y", "jk", numblocks={"x": (2, 2), "y": (2, 2)}
+        func, "z", "ik", "x", "ij", "y", "jk", numblocks={"x": (2, 1), "y": (1, 2)}
     )
 
     graph = make_blockwise_graph(
-        func, "z", "ik", "x", "ij", "y", "jk", numblocks={"x": (2, 2), "y": (2, 2)}
+        func, "z", "ik", "x", "ij", "y", "jk", numblocks={"x": (2, 1), "y": (1, 2)}
     )
 
     check_consistent_with_graph(key_fn, graph)
@@ -290,10 +290,10 @@ def test_make_blockwise_key_function_contract_0d():
     func = lambda x: 0
 
     key_fn = make_blockwise_key_function(
-        func, "z", "", "x", "ij", numblocks={"x": (2, 2)}
+        func, "z", "", "x", "ij", numblocks={"x": (1, 1)}
     )
-    graph = make_blockwise_graph(func, "z", "", "x", "ij", numblocks={"x": (2, 2)})
+    graph = make_blockwise_graph(func, "z", "", "x", "ij", numblocks={"x": (1, 1)})
 
     check_consistent_with_graph(key_fn, graph)
diff --git a/cubed/tests/test_core.py b/cubed/tests/test_core.py
index a40702373..715b81b77 100644
--- a/cubed/tests/test_core.py
+++ b/cubed/tests/test_core.py
@@ -235,6 +235,28 @@ def func(x, y):
     assert_array_equal(c.compute(), np.array([[[12, 13]]]))
 
 
+def test_map_blocks_drop_axis_chunking(spec):
+    # This tests the case illustrated in https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html
+    # Unlike Dask, Cubed does not support concatenating chunks, and will fail if the dropped axis has multiple chunks.
+
+    def func(x):
+        return nxp.sum(x, axis=2)
+
+    an = np.arange(8 * 6 * 2).reshape((8, 6, 2))
+
+    # single chunk in axis=2 works fine
+    a = xp.asarray(an, chunks=(5, 4, 2), spec=spec)
+    b = cubed.map_blocks(func, a, drop_axis=2)
+    assert_array_equal(b.compute(), np.sum(an, axis=2))
+
+    # multiple chunks in axis=2 raises
+    a = xp.asarray(an, chunks=(5, 4, 1), spec=spec)
+    with pytest.raises(
+        ValueError, match=r"Cannot have multiple chunks in dropped axis 2."
+    ):
+        cubed.map_blocks(func, a, drop_axis=2)
+
+
 def test_map_blocks_with_non_cubed_array(spec):
     a = xp.arange(10, dtype="int64", chunks=(2,), spec=spec)
     b = np.array([1, 2], dtype="int64")  # numpy array will be coerced to cubed
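
Note for reviewers: below is a minimal sketch (not part of the patch) of the workaround the new error message suggests, namely keeping the reduced axis inside map_blocks and applying a reduction afterwards instead of passing drop_axis over a multi-chunk axis. It reuses the array and the cubed.array_api namespace from the test above; the helper name partial_sum is illustrative only, and it assumes the default NumPy backend so plain np.sum works inside each block. Rechunking to a single chunk along axis 2 before calling map_blocks with drop_axis=2 should also avoid the error.

import numpy as np
from numpy.testing import assert_array_equal

import cubed
import cubed.array_api as xp

an = np.arange(8 * 6 * 2).reshape((8, 6, 2))
a = xp.asarray(an, chunks=(5, 4, 1))  # multiple chunks along axis 2

def partial_sum(x):
    # keepdims=True leaves the block shape unchanged, so no axis is dropped
    # inside map_blocks and the default output chunks are still correct
    return np.sum(x, axis=2, keepdims=True)

b = cubed.map_blocks(partial_sum, a)  # no drop_axis; per-block partial sums
c = xp.sum(b, axis=2)                 # reduction combines the chunked axis

assert_array_equal(c.compute(), np.sum(an, axis=2))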