Skip to content

Commit e4f2e41

Browse files
committed
BUG: Fix small reduce bug and test string multiply-reduce
1 parent a5e4adf commit e4f2e41

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

numpy/_core/src/umath/reduction.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "array_coercion.h"
2222
#include "array_method.h"
2323
#include "ctors.h"
24+
#include "refcount.h"
2425

2526
#include "numpy/ufuncobject.h"
2627
#include "lowlevel_strided_loops.h"
@@ -438,7 +439,7 @@ PyUFunc_ReduceWrapper(PyArrayMethod_Context *context,
438439
Py_INCREF(result);
439440

440441
if (initial_buf != NULL && PyDataType_REFCHK(PyArray_DESCR(result))) {
441-
PyArray_Item_XDECREF(initial_buf, PyArray_DESCR(result));
442+
PyArray_ClearBuffer(PyArray_DESCR(result), initial_buf, 0, 1, 1);
442443
}
443444
PyMem_FREE(initial_buf);
444445
NPY_AUXDATA_FREE(auxdata);
@@ -450,7 +451,7 @@ PyUFunc_ReduceWrapper(PyArrayMethod_Context *context,
450451

451452
fail:
452453
if (initial_buf != NULL && PyDataType_REFCHK(op_dtypes[0])) {
453-
PyArray_Item_XDECREF(initial_buf, op_dtypes[0]);
454+
PyArray_ClearBuffer(op_dtypes[0], initial_buf, 0, 1, 1);
454455
}
455456
PyMem_FREE(initial_buf);
456457
NPY_AUXDATA_FREE(auxdata);

numpy/_core/tests/test_stringdtype.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,16 @@ def test_add_promoter_reduce():
755755
np.add.reduce(np.array(["a", "b"], dtype="U"), dtype=np.dtypes.StringDType)
756756

757757

758+
def test_multiply_reduce():
759+
# At the time of writing (NumPy 2.0) this is very limited (and rather
760+
# ridiculous anyway). But it works and actually makes some sense...
761+
# (NumPy does not allow non-scalar initial values)
762+
repeats = np.array([2, 3, 4])
763+
val = "school-🚌"
764+
res = np.multiply.reduce(repeats, initial=val, dtype=np.dtypes.StringDType)
765+
assert res == val * np.prod(repeats)
766+
767+
758768
@pytest.mark.parametrize("use_out", [True, False])
759769
@pytest.mark.parametrize("other", [2, [2, 1, 3, 4, 1, 3]])
760770
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)