Skip to content

Commit b5c0a6d

Browse files
author
Luke Shaw
committed
Rename result and result_idx variables for clarity
1 parent 558f363 commit b5c0a6d

File tree

2 files changed

+37
-22
lines changed

2 files changed

+37
-22
lines changed

src/blosc2/lazyexpr.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1799,7 +1799,8 @@ def reduce_slices( # noqa: C901
17991799
add_idx = np.cumsum(mask_slice)
18001800
axis = tuple(a + add_idx[a] for a in axis) # axis now refers to new shape with dummy dims
18011801
if reduce_args["axis"] is not None:
1802-
reduce_args["axis"] = axis
1802+
# conserve as integer if was not tuple originally
1803+
reduce_args["axis"] = axis[0] if np.isscalar(reduce_args["axis"]) else axis
18031804
if keepdims:
18041805
reduced_shape = tuple(1 if i in axis else s for i, s in enumerate(shape_slice))
18051806
else:
@@ -1875,7 +1876,7 @@ def reduce_slices( # noqa: C901
18751876
cslice = step_handler(cslice, _slice)
18761877
chunks_ = tuple(s.stop - s.start for s in cslice)
18771878
unit_steps = np.all([s.step == 1 for s in cslice])
1878-
# Get the starts for the slice (needed for offset calculations when intra-chunk slicing)
1879+
# Starts for slice
18791880
starts = [s.start if s.start is not None else 0 for s in cslice]
18801881
if _slice == () and fast_path and unit_steps:
18811882
# Fast path
@@ -1961,23 +1962,34 @@ def reduce_slices( # noqa: C901
19611962
elif reduce_op == ReduceOp.ALL:
19621963
result = np.all(result, **reduce_args)
19631964
elif reduce_op == ReduceOp.ARGMAX or reduce_op == ReduceOp.ARGMIN:
1964-
result_val = result
1965-
result = (
1965+
# offset for start of slice
1966+
slice_ref = (
1967+
starts
1968+
if _slice == ()
1969+
else [
1970+
(s - sl.start - np.sign(sl.step)) // sl.step + 1
1971+
for s, sl in zip(starts, _slice, strict=True)
1972+
]
1973+
)
1974+
result_idx = (
19661975
np.argmin(result, **reduce_args)
19671976
if reduce_op == ReduceOp.ARGMIN
19681977
else np.argmax(result, **reduce_args)
19691978
)
19701979
if reduce_args["axis"] is None: # indexing into flattened array
1971-
result_val = result_val[np.unravel_index(result, shape=result_val.shape)]
1972-
idx_within_cslice = np.unravel_index(result, shape=chunks_)
1973-
result = np.ravel_multi_index(
1974-
tuple(o + i for o, i in zip(starts, idx_within_cslice, strict=True)), shape
1980+
result = result[np.unravel_index(result_idx, shape=result.shape)]
1981+
idx_within_cslice = np.unravel_index(result_idx, shape=chunks_)
1982+
result_idx = np.ravel_multi_index(
1983+
tuple(o + i for o, i in zip(slice_ref, idx_within_cslice, strict=True)), shape_slice
19751984
)
19761985
else: # axis is an integer
1977-
result_val = np.take_along_axis(
1978-
result_val, np.expand_dims(result, axis=reduce_args["axis"]), axis=reduce_args["axis"]
1986+
result = np.take_along_axis(
1987+
result,
1988+
np.expand_dims(result_idx, axis=reduce_args["axis"]) if not keepdims else result_idx,
1989+
axis=reduce_args["axis"],
19791990
)
1980-
result += starts[reduce_args["axis"]]
1991+
result = result if keepdims else result.squeeze(axis=reduce_args["axis"])
1992+
result_idx += slice_ref[reduce_args["axis"]]
19811993
else:
19821994
result = reduce_op.value.reduce(result, **reduce_args)
19831995

@@ -1997,17 +2009,17 @@ def reduce_slices( # noqa: C901
19972009
out[reduced_slice] *= result
19982010
elif res_out_ is not None: # i.e. ReduceOp.ARGMAX or ReduceOp.ARGMIN
19992011
# need lowest index for which optimum attained
2000-
cond = (res_out_[reduced_slice] == result_val) & (result < out[reduced_slice])
2012+
cond = (res_out_[reduced_slice] == result) & (result_idx < out[reduced_slice])
20012013
if reduce_op == ReduceOp.ARGMAX:
2002-
cond |= res_out_[reduced_slice] < result_val
2014+
cond |= res_out_[reduced_slice] < result
20032015
else: # ARGMIN
2004-
cond |= res_out_[reduced_slice] > result_val
2016+
cond |= res_out_[reduced_slice] > result
20052017
if reduced_slice == ():
2006-
out = np.where(cond, result, out[reduced_slice])
2007-
res_out_ = np.where(cond, result_val, res_out_[reduced_slice])
2018+
out = np.where(cond, result_idx, out[reduced_slice])
2019+
res_out_ = np.where(cond, result, res_out_[reduced_slice])
20082020
else:
2009-
out[reduced_slice] = np.where(cond, result, out[reduced_slice])
2010-
res_out_[reduced_slice] = np.where(cond, result_val, res_out_[reduced_slice])
2021+
out[reduced_slice] = np.where(cond, result_idx, out[reduced_slice])
2022+
res_out_[reduced_slice] = np.where(cond, result, res_out_[reduced_slice])
20112023
else:
20122024
if reduced_slice == ():
20132025
out = reduce_op.value(out, result)

tests/ndarray/test_reductions.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def test_fast_path(chunks, blocks, disk, fill_value, reduce_op, axis):
258258
na = a[:]
259259

260260
res = getattr(a, reduce_op)(axis=axis)
261-
nres = getattr(na[:], reduce_op)(axis=axis)
261+
nres = getattr(na, reduce_op)(axis=axis)
262262

263263
assert np.allclose(res, nres)
264264

@@ -270,9 +270,10 @@ def test_fast_path(chunks, blocks, disk, fill_value, reduce_op, axis):
270270
)
271271
@pytest.mark.parametrize("axis", [0, (0, 1), None])
272272
def test_save_version1(disk, fill_value, reduce_op, axis):
273+
shape = (20, 50, 100)
273274
if isinstance(axis, tuple) and (reduce_op in ("argmax", "argmin")):
274275
axis = 1
275-
shape = (20, 50, 100)
276+
shape = (20, 20, 100)
276277
urlpath = "a1.b2nd" if disk else None
277278
if fill_value != 0:
278279
a = blosc2.full(shape, fill_value, urlpath=urlpath, mode="w")
@@ -310,9 +311,10 @@ def test_save_version1(disk, fill_value, reduce_op, axis):
310311
)
311312
@pytest.mark.parametrize("axis", [0, (0, 1), None])
312313
def test_save_version2(disk, fill_value, reduce_op, axis):
314+
shape = (20, 50, 100)
313315
if isinstance(axis, tuple) and (reduce_op in ("argmax", "argmin")):
314316
axis = 1
315-
shape = (20, 50, 100)
317+
shape = (20, 20, 100)
316318
urlpath = "a1.b2nd" if disk else None
317319
if fill_value != 0:
318320
a = blosc2.full(shape, fill_value, urlpath=urlpath, mode="w")
@@ -349,9 +351,10 @@ def test_save_version2(disk, fill_value, reduce_op, axis):
349351
)
350352
@pytest.mark.parametrize("axis", [0, (0, 1), None])
351353
def test_save_version3(disk, fill_value, reduce_op, axis):
354+
shape = (20, 50, 100)
352355
if isinstance(axis, tuple) and (reduce_op in ("argmax", "argmin")):
353356
axis = 1
354-
shape = (20, 50, 100)
357+
shape = (20, 20, 100)
355358
urlpath = "a1.b2nd" if disk else None
356359
if fill_value != 0:
357360
a = blosc2.full(shape, fill_value, urlpath=urlpath, mode="w")

0 commit comments

Comments
 (0)