Skip to content

Commit c3f091d

Browse files
authored
Merge pull request numpy#26090 from seberg/minimal-reduce-fix
API: Require reduce promoters to start with None to match
2 parents 71ab906 + e4f2e41 commit c3f091d

File tree

3 files changed

+78
-45
lines changed

3 files changed

+78
-45
lines changed

numpy/_core/src/umath/dispatching.c

Lines changed: 56 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -274,21 +274,20 @@ resolve_implementation_info(PyUFuncObject *ufunc,
274274
/* Unspecified out always matches (see below for inputs) */
275275
continue;
276276
}
277+
assert(i == 0);
277278
/*
278-
* This is a reduce-like operation, which always have the form
279-
* `(res_DType, op_DType, res_DType)`. If the first and last
280-
* dtype of the loops match, this should be reduce-compatible.
279+
* This is a reduce-like operation, we enforce that these
280+
* register with None as the first DType. If a reduction
281+
* uses the same DType, we will do that promotion.
282+
* A `(res_DType, op_DType, res_DType)` pattern can make sense
283+
* in other context as well and could be confusing.
281284
*/
282-
if (PyTuple_GET_ITEM(curr_dtypes, 0)
283-
== PyTuple_GET_ITEM(curr_dtypes, 2)) {
285+
if (PyTuple_GET_ITEM(curr_dtypes, 0) == Py_None) {
284286
continue;
285287
}
286-
/*
287-
* This should be a reduce, but doesn't follow the reduce
288-
* pattern. So (for now?) consider this not a match.
289-
*/
288+
/* Otherwise, this is not considered a match */
290289
matches = NPY_FALSE;
291-
continue;
290+
break;
292291
}
293292

294293
if (resolver_dtype == (PyArray_DTypeMeta *)Py_None) {
@@ -488,7 +487,7 @@ resolve_implementation_info(PyUFuncObject *ufunc,
488487
* those defined by the `signature` unmodified).
489488
*/
490489
static PyObject *
491-
call_promoter_and_recurse(PyUFuncObject *ufunc, PyObject *promoter,
490+
call_promoter_and_recurse(PyUFuncObject *ufunc, PyObject *info,
492491
PyArray_DTypeMeta *op_dtypes[], PyArray_DTypeMeta *signature[],
493492
PyArrayObject *const operands[])
494493
{
@@ -498,37 +497,51 @@ call_promoter_and_recurse(PyUFuncObject *ufunc, PyObject *promoter,
498497
int promoter_result;
499498
PyArray_DTypeMeta *new_op_dtypes[NPY_MAXARGS];
500499

501-
if (PyCapsule_CheckExact(promoter)) {
502-
/* We could also go the other way and wrap up the python function... */
503-
PyArrayMethod_PromoterFunction *promoter_function = PyCapsule_GetPointer(
504-
promoter, "numpy._ufunc_promoter");
505-
if (promoter_function == NULL) {
500+
if (info != NULL) {
501+
PyObject *promoter = PyTuple_GET_ITEM(info, 1);
502+
if (PyCapsule_CheckExact(promoter)) {
503+
/* We could also go the other way and wrap up the python function... */
504+
PyArrayMethod_PromoterFunction *promoter_function = PyCapsule_GetPointer(
505+
promoter, "numpy._ufunc_promoter");
506+
if (promoter_function == NULL) {
507+
return NULL;
508+
}
509+
promoter_result = promoter_function((PyObject *)ufunc,
510+
op_dtypes, signature, new_op_dtypes);
511+
}
512+
else {
513+
PyErr_SetString(PyExc_NotImplementedError,
514+
"Calling python functions for promotion is not implemented.");
506515
return NULL;
507516
}
508-
promoter_result = promoter_function((PyObject *)ufunc,
509-
op_dtypes, signature, new_op_dtypes);
510-
}
511-
else {
512-
PyErr_SetString(PyExc_NotImplementedError,
513-
"Calling python functions for promotion is not implemented.");
514-
return NULL;
515-
}
516-
if (promoter_result < 0) {
517-
return NULL;
518-
}
519-
/*
520-
* If none of the dtypes changes, we would recurse infinitely, abort.
521-
* (Of course it is nevertheless possible to recurse infinitely.)
522-
*/
523-
int dtypes_changed = 0;
524-
for (int i = 0; i < nargs; i++) {
525-
if (new_op_dtypes[i] != op_dtypes[i]) {
526-
dtypes_changed = 1;
527-
break;
517+
if (promoter_result < 0) {
518+
return NULL;
519+
}
520+
/*
521+
* If none of the dtypes changes, we would recurse infinitely, abort.
522+
* (Of course it is nevertheless possible to recurse infinitely.)
523+
*
524+
* TODO: We could allow users to signal this directly and also move
525+
* the call to be (almost immediate). That would call it
526+
* unnecessarily sometimes, but may allow additional flexibility.
527+
*/
528+
int dtypes_changed = 0;
529+
for (int i = 0; i < nargs; i++) {
530+
if (new_op_dtypes[i] != op_dtypes[i]) {
531+
dtypes_changed = 1;
532+
break;
533+
}
534+
}
535+
if (!dtypes_changed) {
536+
goto finish;
528537
}
529538
}
530-
if (!dtypes_changed) {
531-
goto finish;
539+
else {
540+
/* Reduction special path */
541+
new_op_dtypes[0] = NPY_DT_NewRef(op_dtypes[1]);
542+
new_op_dtypes[1] = NPY_DT_NewRef(op_dtypes[1]);
543+
Py_XINCREF(op_dtypes[2]);
544+
new_op_dtypes[2] = op_dtypes[2];
532545
}
533546

534547
/*
@@ -788,13 +801,13 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
788801

789802
/*
790803
* At this point `info` is NULL if there is no matching loop, or it is
791-
* a promoter that needs to be used/called:
804+
* a promoter that needs to be used/called.
805+
* TODO: It may be nice to find a better reduce-solution, but this way
806+
* it is a True fallback (not registered so lowest priority)
792807
*/
793-
if (info != NULL) {
794-
PyObject *promoter = PyTuple_GET_ITEM(info, 1);
795-
808+
if (info != NULL || op_dtypes[0] == NULL) {
796809
info = call_promoter_and_recurse(ufunc,
797-
promoter, op_dtypes, signature, ops);
810+
info, op_dtypes, signature, ops);
798811
if (info == NULL && PyErr_Occurred()) {
799812
return NULL;
800813
}

numpy/_core/src/umath/reduction.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "array_coercion.h"
2222
#include "array_method.h"
2323
#include "ctors.h"
24+
#include "refcount.h"
2425

2526
#include "numpy/ufuncobject.h"
2627
#include "lowlevel_strided_loops.h"
@@ -438,7 +439,7 @@ PyUFunc_ReduceWrapper(PyArrayMethod_Context *context,
438439
Py_INCREF(result);
439440

440441
if (initial_buf != NULL && PyDataType_REFCHK(PyArray_DESCR(result))) {
441-
PyArray_Item_XDECREF(initial_buf, PyArray_DESCR(result));
442+
PyArray_ClearBuffer(PyArray_DESCR(result), initial_buf, 0, 1, 1);
442443
}
443444
PyMem_FREE(initial_buf);
444445
NPY_AUXDATA_FREE(auxdata);
@@ -450,7 +451,7 @@ PyUFunc_ReduceWrapper(PyArrayMethod_Context *context,
450451

451452
fail:
452453
if (initial_buf != NULL && PyDataType_REFCHK(op_dtypes[0])) {
453-
PyArray_Item_XDECREF(initial_buf, op_dtypes[0]);
454+
PyArray_ClearBuffer(op_dtypes[0], initial_buf, 0, 1, 1);
454455
}
455456
PyMem_FREE(initial_buf);
456457
NPY_AUXDATA_FREE(auxdata);

numpy/_core/tests/test_stringdtype.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,25 @@ def test_add_promoter(string_list):
746746
assert_array_equal(arr + op, rresult)
747747

748748

749+
def test_add_promoter_reduce():
750+
# Exact TypeError could change, but ensure StringDtype doesn't match
751+
with pytest.raises(TypeError, match="the resolved dtypes are not"):
752+
np.add.reduce(np.array(["a", "b"], dtype="U"))
753+
754+
# On the other hand, using `dtype=T` in the *ufunc* should work.
755+
np.add.reduce(np.array(["a", "b"], dtype="U"), dtype=np.dtypes.StringDType)
756+
757+
758+
def test_multiply_reduce():
759+
# At the time of writing (NumPy 2.0) this is very limited (and rather
760+
# ridiculous anyway). But it works and actually makes some sense...
761+
# (NumPy does not allow non-scalar initial values)
762+
repeats = np.array([2, 3, 4])
763+
val = "school-🚌"
764+
res = np.multiply.reduce(repeats, initial=val, dtype=np.dtypes.StringDType)
765+
assert res == val * np.prod(repeats)
766+
767+
749768
@pytest.mark.parametrize("use_out", [True, False])
750769
@pytest.mark.parametrize("other", [2, [2, 1, 3, 4, 1, 3]])
751770
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)