Skip to content

Commit 2aeaa44

Browse files
committed
TST: Add test for reduceat/accumulate output shape mismatch
At least the reduceat path seems to have been untested before. The slight change in code layout (added assert) is just to make the code slightly easier to read. Since otherwise it looks like there is an additional `else` branch when `out` is given.
1 parent b5a76ef commit 2aeaa44

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

numpy/core/src/umath/ufunc_object.c

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3524,8 +3524,12 @@ PyUFunc_Reduceat(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *ind,
35243524
Py_INCREF(out);
35253525
}
35263526
}
3527-
/* Allocate the output for when there's no outer iterator */
3528-
else if (out == NULL) {
3527+
else {
3528+
/*
3529+
* Allocate the output for when there's no outer iterator, we always
3530+
* use the outer_iteration path when `out` is passed.
3531+
*/
3532+
assert(out == NULL);
35293533
Py_INCREF(descrs[0]);
35303534
op[0] = out = (PyArrayObject *)PyArray_NewFromDescr(
35313535
&PyArray_Type, descrs[0],

numpy/core/tests/test_ufunc.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2155,6 +2155,22 @@ def test_reduce_noncontig_output(self):
21552155
assert_equal(y_base[1,:], y_base_copy[1,:])
21562156
assert_equal(y_base[3,:], y_base_copy[3,:])
21572157

2158+
@pytest.mark.parametrize("with_cast", [True, False])
2159+
def test_reduceat_and_accumulate_out_shape_mismatch(self, with_cast):
2160+
# Should raise an error mentioning "shape" or "size"
2161+
arr = np.arange(5)
2162+
out = np.arange(3) # definitely wrong shape
2163+
if with_cast:
2164+
# If a cast is necessary on the output, we can be sure to use
2165+
# the generic NpyIter (non-fast) path.
2166+
out = out.astype(np.float64)
2167+
2168+
with pytest.raises(ValueError, match="(shape|size)"):
2169+
np.add.reduceat(arr, [0, 3], out=out)
2170+
2171+
with pytest.raises(ValueError, match="(shape|size)"):
2172+
np.add.accumulate(arr, out=out)
2173+
21582174
@pytest.mark.parametrize('out_shape',
21592175
[(), (1,), (3,), (1, 1), (1, 3), (4, 3)])
21602176
@pytest.mark.parametrize('keepdims', [True, False])

0 commit comments

Comments
 (0)