@@ -1799,7 +1799,8 @@ def reduce_slices( # noqa: C901
17991799 add_idx = np .cumsum (mask_slice )
18001800 axis = tuple (a + add_idx [a ] for a in axis ) # axis now refers to new shape with dummy dims
18011801 if reduce_args ["axis" ] is not None :
1802- reduce_args ["axis" ] = axis
1802+ # conserve as integer if was not tuple originally
1803+ reduce_args ["axis" ] = axis [0 ] if np .isscalar (reduce_args ["axis" ]) else axis
18031804 if keepdims :
18041805 reduced_shape = tuple (1 if i in axis else s for i , s in enumerate (shape_slice ))
18051806 else :
@@ -1875,7 +1876,7 @@ def reduce_slices( # noqa: C901
18751876 cslice = step_handler (cslice , _slice )
18761877 chunks_ = tuple (s .stop - s .start for s in cslice )
18771878 unit_steps = np .all ([s .step == 1 for s in cslice ])
1878- # Get the starts for the slice (needed for offset calculations when intra-chunk slicing)
1879+ # Starts for slice
18791880 starts = [s .start if s .start is not None else 0 for s in cslice ]
18801881 if _slice == () and fast_path and unit_steps :
18811882 # Fast path
@@ -1961,23 +1962,34 @@ def reduce_slices( # noqa: C901
19611962 elif reduce_op == ReduceOp .ALL :
19621963 result = np .all (result , ** reduce_args )
19631964 elif reduce_op == ReduceOp .ARGMAX or reduce_op == ReduceOp .ARGMIN :
1964- result_val = result
1965- result = (
1965+ # offset for start of slice
1966+ slice_ref = (
1967+ starts
1968+ if _slice == ()
1969+ else [
1970+ (s - sl .start - np .sign (sl .step )) // sl .step + 1
1971+ for s , sl in zip (starts , _slice , strict = True )
1972+ ]
1973+ )
1974+ result_idx = (
19661975 np .argmin (result , ** reduce_args )
19671976 if reduce_op == ReduceOp .ARGMIN
19681977 else np .argmax (result , ** reduce_args )
19691978 )
19701979 if reduce_args ["axis" ] is None : # indexing into flattened array
1971- result_val = result_val [np .unravel_index (result , shape = result_val .shape )]
1972- idx_within_cslice = np .unravel_index (result , shape = chunks_ )
1973- result = np .ravel_multi_index (
1974- tuple (o + i for o , i in zip (starts , idx_within_cslice , strict = True )), shape
1980+ result = result [np .unravel_index (result_idx , shape = result .shape )]
1981+ idx_within_cslice = np .unravel_index (result_idx , shape = chunks_ )
1982+ result_idx = np .ravel_multi_index (
1983+ tuple (o + i for o , i in zip (slice_ref , idx_within_cslice , strict = True )), shape_slice
19751984 )
19761985 else : # axis is an integer
1977- result_val = np .take_along_axis (
1978- result_val , np .expand_dims (result , axis = reduce_args ["axis" ]), axis = reduce_args ["axis" ]
1986+ result = np .take_along_axis (
1987+ result ,
1988+ np .expand_dims (result_idx , axis = reduce_args ["axis" ]) if not keepdims else result_idx ,
1989+ axis = reduce_args ["axis" ],
19791990 )
1980- result += starts [reduce_args ["axis" ]]
1991+ result = result if keepdims else result .squeeze (axis = reduce_args ["axis" ])
1992+ result_idx += slice_ref [reduce_args ["axis" ]]
19811993 else :
19821994 result = reduce_op .value .reduce (result , ** reduce_args )
19831995
@@ -1997,17 +2009,17 @@ def reduce_slices( # noqa: C901
19972009 out [reduced_slice ] *= result
19982010 elif res_out_ is not None : # i.e. ReduceOp.ARGMAX or ReduceOp.ARGMIN
19992011 # need lowest index for which optimum attained
2000- cond = (res_out_ [reduced_slice ] == result_val ) & (result < out [reduced_slice ])
2012+ cond = (res_out_ [reduced_slice ] == result ) & (result_idx < out [reduced_slice ])
20012013 if reduce_op == ReduceOp .ARGMAX :
2002- cond |= res_out_ [reduced_slice ] < result_val
2014+ cond |= res_out_ [reduced_slice ] < result
20032015 else : # ARGMIN
2004- cond |= res_out_ [reduced_slice ] > result_val
2016+ cond |= res_out_ [reduced_slice ] > result
20052017 if reduced_slice == ():
2006- out = np .where (cond , result , out [reduced_slice ])
2007- res_out_ = np .where (cond , result_val , res_out_ [reduced_slice ])
2018+ out = np .where (cond , result_idx , out [reduced_slice ])
2019+ res_out_ = np .where (cond , result , res_out_ [reduced_slice ])
20082020 else :
2009- out [reduced_slice ] = np .where (cond , result , out [reduced_slice ])
2010- res_out_ [reduced_slice ] = np .where (cond , result_val , res_out_ [reduced_slice ])
2021+ out [reduced_slice ] = np .where (cond , result_idx , out [reduced_slice ])
2022+ res_out_ [reduced_slice ] = np .where (cond , result , res_out_ [reduced_slice ])
20112023 else :
20122024 if reduced_slice == ():
20132025 out = reduce_op .value (out , result )
0 commit comments