@@ -565,15 +565,18 @@ def compute_broadcast_shape(arrays):
565565 return np .broadcast_shapes (* shapes ) if shapes else None
566566
567567
568- def check_smaller_shape (value , shape , slice_shape ):
568+ def check_smaller_shape (value_shape , shape , slice_shape ):
569569 """Check whether the shape of the value is smaller than the shape of the array.
570570
571571 This follows the NumPy broadcasting rules.
572572 """
573+ # slice_shape must be as long as shape
574+ if len (slice_shape ) != len (shape ):
575+ raise ValueError ("slice_shape must be as long as shape" )
573576 is_smaller_shape = any (
574- s > (1 if i >= len (value . shape ) else value . shape [i ]) for i , s in enumerate (slice_shape )
577+ s > (1 if i >= len (value_shape ) else value_shape [i ]) for i , s in enumerate (slice_shape )
575578 )
576- return len (value . shape ) < len (shape ) or is_smaller_shape
579+ return len (value_shape ) < len (shape ) or is_smaller_shape
577580
578581
579582def _compute_smaller_slice (larger_shape , smaller_shape , larger_slice ):
@@ -1366,7 +1369,7 @@ def slices_eval( # noqa: C901
13661369 getitem: bool, optional
13671370 Indicates whether the expression is being evaluated for a getitem operation or compute().
13681371 Default is False.
1369- _slice: slice, list of slices, optional
1372+ _slice: int, slice, list of slices, optional
13701373 If provided, only the chunks that intersect with this slice
13711374 will be evaluated.
13721375 kwargs: Any, optional
@@ -1487,7 +1490,7 @@ def slices_eval( # noqa: C901
14871490 if value .shape == ():
14881491 chunk_operands [key ] = value [()]
14891492 continue
1490- if check_smaller_shape (value , shape , slice_shape ):
1493+ if check_smaller_shape (value . shape , shape , slice_shape ):
14911494 # We need to fetch the part of the value that broadcasts with the operand
14921495 smaller_slice = compute_smaller_slice (shape , value .shape , slice_ )
14931496 chunk_operands [key ] = value [smaller_slice ]
@@ -1644,7 +1647,7 @@ def slices_eval_getitem(
16441647 The expression or user-defined (udf) to evaluate.
16451648 operands: dict
16461649 A dictionary containing the operands for the expression.
1647- _slice: slice, list of slices, optional
1650+ _slice: int, slice, list of slices, optional
16481651 If provided, this slice will be evaluated.
16491652 kwargs: Any, optional
16501653 Additional keyword arguments that are supported by the :func:`empty` constructor.
@@ -1667,8 +1670,10 @@ def slices_eval_getitem(
16671670 else :
16681671 shape = out .shape
16691672
1670- # Provided the slice, compute the shape of the output array
1671- slice_shape = compute_slice_shape (shape , _slice )
1673+ # compute the shape of the output array, broadcasting-compatible
1674+ _slice = ndindex .ndindex (_slice ).expand (shape ).raw # make sure slice is tuple
1675+ _slice_bcast = tuple (slice (i , i + 1 ) if isinstance (i , int ) else i for i in _slice )
1676+ slice_shape = compute_slice_shape (shape , _slice_bcast )
16721677
16731678 # Get the slice of each operand
16741679 slice_operands = {}
@@ -1679,7 +1684,7 @@ def slices_eval_getitem(
16791684 if value .shape == ():
16801685 slice_operands [key ] = value [()]
16811686 continue
1682- if check_smaller_shape (value , shape , slice_shape ):
1687+ if check_smaller_shape (value . shape , shape , slice_shape ):
16831688 # We need to fetch the part of the value that broadcasts with the operand
16841689 smaller_slice = compute_smaller_slice (shape , value .shape , _slice )
16851690 slice_operands [key ] = value [smaller_slice ]
@@ -1736,7 +1741,7 @@ def reduce_slices( # noqa: C901
17361741 A dictionary containing the operands for the operands.
17371742 reduce_args: dict
17381743 A dictionary with arguments to be passed to the reduction function.
1739- _slice: slice, list of slices, optional
1744+ _slice: int, slice, list of slices, optional
17401745 If provided, only the chunks that intersect with this slice
17411746 will be evaluated.
17421747 kwargs: Any, optional
@@ -1828,6 +1833,7 @@ def reduce_slices( # noqa: C901
18281833 reduced_slice = tuple (sl for i , sl in enumerate (slice_ ) if i not in axis )
18291834 offset = tuple (s .start for s in slice_ ) # offset for the udf
18301835 # Check whether current slice_ intersects with _slice
1836+ # TODO: Is this necessary, shouldn't slice always be None for a reduction?
18311837 if _slice is not None and _slice != ():
18321838 # Ensure that slices do not have any None as start or stop
18331839 _slice = tuple (slice (s .start or 0 , s .stop or shape [i ], s .step ) for i , s in enumerate (_slice ))
@@ -1860,7 +1866,7 @@ def reduce_slices( # noqa: C901
18601866 if value .shape == ():
18611867 chunk_operands [key ] = value [()]
18621868 continue
1863- if check_smaller_shape (value , shape , chunks_ ):
1869+ if check_smaller_shape (value . shape , shape , chunks_ ):
18641870 # We need to fetch the part of the value that broadcasts with the operand
18651871 smaller_slice = compute_smaller_slice (operand .shape , value .shape , slice_ )
18661872 chunk_operands [key ] = value [smaller_slice ]
0 commit comments