Skip to content

Commit 5e80cd3

Browse files
author
Luke Shaw
committed
Small fix for failing indexing test case plus comments
1 parent 5359bfc commit 5e80cd3

File tree

2 files changed

+26
-11
lines changed

2 files changed

+26
-11
lines changed

src/blosc2/lazyexpr.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -565,15 +565,18 @@ def compute_broadcast_shape(arrays):
565565
return np.broadcast_shapes(*shapes) if shapes else None
566566

567567

568-
def check_smaller_shape(value, shape, slice_shape):
568+
def check_smaller_shape(value_shape, shape, slice_shape):
569569
"""Check whether the shape of the value is smaller than the shape of the array.
570570
571571
This follows the NumPy broadcasting rules.
572572
"""
573+
# slice_shape must be as long as shape
574+
if len(slice_shape) != len(shape):
575+
raise ValueError("slice_shape must be as long as shape")
573576
is_smaller_shape = any(
574-
s > (1 if i >= len(value.shape) else value.shape[i]) for i, s in enumerate(slice_shape)
577+
s > (1 if i >= len(value_shape) else value_shape[i]) for i, s in enumerate(slice_shape)
575578
)
576-
return len(value.shape) < len(shape) or is_smaller_shape
579+
return len(value_shape) < len(shape) or is_smaller_shape
577580

578581

579582
def _compute_smaller_slice(larger_shape, smaller_shape, larger_slice):
@@ -1366,7 +1369,7 @@ def slices_eval( # noqa: C901
13661369
getitem: bool, optional
13671370
Indicates whether the expression is being evaluated for a getitem operation or compute().
13681371
Default is False.
1369-
_slice: slice, list of slices, optional
1372+
_slice: int, slice, list of slices, optional
13701373
If provided, only the chunks that intersect with this slice
13711374
will be evaluated.
13721375
kwargs: Any, optional
@@ -1487,7 +1490,7 @@ def slices_eval( # noqa: C901
14871490
if value.shape == ():
14881491
chunk_operands[key] = value[()]
14891492
continue
1490-
if check_smaller_shape(value, shape, slice_shape):
1493+
if check_smaller_shape(value.shape, shape, slice_shape):
14911494
# We need to fetch the part of the value that broadcasts with the operand
14921495
smaller_slice = compute_smaller_slice(shape, value.shape, slice_)
14931496
chunk_operands[key] = value[smaller_slice]
@@ -1644,7 +1647,7 @@ def slices_eval_getitem(
16441647
The expression or user-defined (udf) to evaluate.
16451648
operands: dict
16461649
A dictionary containing the operands for the expression.
1647-
_slice: slice, list of slices, optional
1650+
_slice: int, slice, list of slices, optional
16481651
If provided, this slice will be evaluated.
16491652
kwargs: Any, optional
16501653
Additional keyword arguments that are supported by the :func:`empty` constructor.
@@ -1667,8 +1670,10 @@ def slices_eval_getitem(
16671670
else:
16681671
shape = out.shape
16691672

1670-
# Provided the slice, compute the shape of the output array
1671-
slice_shape = compute_slice_shape(shape, _slice)
1673+
# compute the shape of the output array, broadcasting-compatible
1674+
_slice = ndindex.ndindex(_slice).expand(shape).raw # make sure slice is tuple
1675+
_slice_bcast = tuple(slice(i, i + 1) if isinstance(i, int) else i for i in _slice)
1676+
slice_shape = compute_slice_shape(shape, _slice_bcast)
16721677

16731678
# Get the slice of each operand
16741679
slice_operands = {}
@@ -1679,7 +1684,7 @@ def slices_eval_getitem(
16791684
if value.shape == ():
16801685
slice_operands[key] = value[()]
16811686
continue
1682-
if check_smaller_shape(value, shape, slice_shape):
1687+
if check_smaller_shape(value.shape, shape, slice_shape):
16831688
# We need to fetch the part of the value that broadcasts with the operand
16841689
smaller_slice = compute_smaller_slice(shape, value.shape, _slice)
16851690
slice_operands[key] = value[smaller_slice]
@@ -1736,7 +1741,7 @@ def reduce_slices( # noqa: C901
17361741
A dictionary containing the operands for the operands.
17371742
reduce_args: dict
17381743
A dictionary with arguments to be passed to the reduction function.
1739-
_slice: slice, list of slices, optional
1744+
_slice: int, slice, list of slices, optional
17401745
If provided, only the chunks that intersect with this slice
17411746
will be evaluated.
17421747
kwargs: Any, optional
@@ -1828,6 +1833,7 @@ def reduce_slices( # noqa: C901
18281833
reduced_slice = tuple(sl for i, sl in enumerate(slice_) if i not in axis)
18291834
offset = tuple(s.start for s in slice_) # offset for the udf
18301835
# Check whether current slice_ intersects with _slice
1836+
# TODO: Is this necessary, shouldn't slice always be None for a reduction?
18311837
if _slice is not None and _slice != ():
18321838
# Ensure that slices do not have any None as start or stop
18331839
_slice = tuple(slice(s.start or 0, s.stop or shape[i], s.step) for i, s in enumerate(_slice))
@@ -1860,7 +1866,7 @@ def reduce_slices( # noqa: C901
18601866
if value.shape == ():
18611867
chunk_operands[key] = value[()]
18621868
continue
1863-
if check_smaller_shape(value, shape, chunks_):
1869+
if check_smaller_shape(value.shape, shape, chunks_):
18641870
# We need to fetch the part of the value that broadcasts with the operand
18651871
smaller_slice = compute_smaller_slice(operand.shape, value.shape, slice_)
18661872
chunk_operands[key] = value[smaller_slice]

tests/ndarray/test_lazyexpr.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,6 +1076,15 @@ def test_eval_getitem(array_fixture):
10761076
np.testing.assert_allclose(expr[:10], nres[:10])
10771077
np.testing.assert_allclose(expr[0:10:2], nres[0:10:2])
10781078

1079+
# Small test
1080+
shape = (2, 10, 5)
1081+
test_arr = blosc2.linspace(0, 10, np.prod(shape), shape=shape)
1082+
expr = test_arr * 30
1083+
nres = test_arr[:] * 30
1084+
np.testing.assert_allclose(expr[0], nres[0])
1085+
np.testing.assert_allclose(expr[:10], nres[:10])
1086+
np.testing.assert_allclose(expr[0:10:2], nres[0:10:2])
1087+
10791088

10801089
# Test lazyexpr's slice method
10811090
def test_eval_slice(array_fixture):

0 commit comments

Comments
 (0)