Skip to content

Commit 9d49b0d

Browse files
committed
API: Require reduce promoters to start with None to match
Additionally, we add an (implicit) fallback promoter for the reduction case that fills in the first dtype with the second one.
1 parent 56dab50 commit 9d49b0d

File tree

1 file changed

+56
-43
lines changed

1 file changed

+56
-43
lines changed

numpy/_core/src/umath/dispatching.c

Lines changed: 56 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -274,21 +274,20 @@ resolve_implementation_info(PyUFuncObject *ufunc,
274274
/* Unspecified out always matches (see below for inputs) */
275275
continue;
276276
}
277+
assert(i == 0);
277278
/*
278-
* This is a reduce-like operation, which always have the form
279-
* `(res_DType, op_DType, res_DType)`. If the first and last
280-
* dtype of the loops match, this should be reduce-compatible.
279+
* This is a reduce-like operation; we enforce that these
280+
* register with None as the first DType. If a reduction
281+
* uses the same DType, we will do that promotion.
282+
* A `(res_DType, op_DType, res_DType)` pattern can make sense
283+
* in other contexts as well and could be confusing.
281284
*/
282-
if (PyTuple_GET_ITEM(curr_dtypes, 0)
283-
== PyTuple_GET_ITEM(curr_dtypes, 2)) {
285+
if (PyTuple_GET_ITEM(curr_dtypes, 0) == Py_None) {
284286
continue;
285287
}
286-
/*
287-
* This should be a reduce, but doesn't follow the reduce
288-
* pattern. So (for now?) consider this not a match.
289-
*/
288+
/* Otherwise, this is not considered a match */
290289
matches = NPY_FALSE;
291-
continue;
290+
break;
292291
}
293292

294293
if (resolver_dtype == (PyArray_DTypeMeta *)Py_None) {
@@ -488,7 +487,7 @@ resolve_implementation_info(PyUFuncObject *ufunc,
488487
* those defined by the `signature` unmodified).
489488
*/
490489
static PyObject *
491-
call_promoter_and_recurse(PyUFuncObject *ufunc, PyObject *promoter,
490+
call_promoter_and_recurse(PyUFuncObject *ufunc, PyObject *info,
492491
PyArray_DTypeMeta *op_dtypes[], PyArray_DTypeMeta *signature[],
493492
PyArrayObject *const operands[])
494493
{
@@ -498,37 +497,51 @@ call_promoter_and_recurse(PyUFuncObject *ufunc, PyObject *promoter,
498497
int promoter_result;
499498
PyArray_DTypeMeta *new_op_dtypes[NPY_MAXARGS];
500499

501-
if (PyCapsule_CheckExact(promoter)) {
502-
/* We could also go the other way and wrap up the python function... */
503-
PyArrayMethod_PromoterFunction *promoter_function = PyCapsule_GetPointer(
504-
promoter, "numpy._ufunc_promoter");
505-
if (promoter_function == NULL) {
500+
if (info != NULL) {
501+
PyObject *promoter = PyTuple_GET_ITEM(info, 1);
502+
if (PyCapsule_CheckExact(promoter)) {
503+
/* We could also go the other way and wrap up the python function... */
504+
PyArrayMethod_PromoterFunction *promoter_function = PyCapsule_GetPointer(
505+
promoter, "numpy._ufunc_promoter");
506+
if (promoter_function == NULL) {
507+
return NULL;
508+
}
509+
promoter_result = promoter_function((PyObject *)ufunc,
510+
op_dtypes, signature, new_op_dtypes);
511+
}
512+
else {
513+
PyErr_SetString(PyExc_NotImplementedError,
514+
"Calling python functions for promotion is not implemented.");
506515
return NULL;
507516
}
508-
promoter_result = promoter_function((PyObject *)ufunc,
509-
op_dtypes, signature, new_op_dtypes);
510-
}
511-
else {
512-
PyErr_SetString(PyExc_NotImplementedError,
513-
"Calling python functions for promotion is not implemented.");
514-
return NULL;
515-
}
516-
if (promoter_result < 0) {
517-
return NULL;
518-
}
519-
/*
520-
* If none of the dtypes changes, we would recurse infinitely, abort.
521-
* (Of course it is nevertheless possible to recurse infinitely.)
522-
*/
523-
int dtypes_changed = 0;
524-
for (int i = 0; i < nargs; i++) {
525-
if (new_op_dtypes[i] != op_dtypes[i]) {
526-
dtypes_changed = 1;
527-
break;
517+
if (promoter_result < 0) {
518+
return NULL;
519+
}
520+
/*
521+
* If none of the dtypes changes, we would recurse infinitely, abort.
522+
* (Of course it is nevertheless possible to recurse infinitely.)
523+
*
524+
* TODO: We could allow users to signal this directly and also move
525+
* the call to be (almost) immediate. That would call it
526+
* unnecessarily sometimes, but may allow additional flexibility.
527+
*/
528+
int dtypes_changed = 0;
529+
for (int i = 0; i < nargs; i++) {
530+
if (new_op_dtypes[i] != op_dtypes[i]) {
531+
dtypes_changed = 1;
532+
break;
533+
}
534+
}
535+
if (!dtypes_changed) {
536+
goto finish;
528537
}
529538
}
530-
if (!dtypes_changed) {
531-
goto finish;
539+
else {
540+
/* Reduction special path */
541+
new_op_dtypes[0] = NPY_DT_NewRef(op_dtypes[1]);
542+
new_op_dtypes[1] = NPY_DT_NewRef(op_dtypes[1]);
543+
Py_XINCREF(op_dtypes[2]);
544+
new_op_dtypes[2] = op_dtypes[2];
532545
}
533546

534547
/*
@@ -788,13 +801,13 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
788801

789802
/*
790803
* At this point `info` is NULL if there is no matching loop, or it is
791-
* a promoter that needs to be used/called:
804+
* a promoter that needs to be used/called.
805+
* TODO: It may be nice to find a better reduce-solution, but this way
806+
* it is a true fallback (not registered, so lowest priority).
792807
*/
793-
if (info != NULL) {
794-
PyObject *promoter = PyTuple_GET_ITEM(info, 1);
795-
808+
if (info != NULL || op_dtypes[0] == NULL) {
796809
info = call_promoter_and_recurse(ufunc,
797-
promoter, op_dtypes, signature, ops);
810+
info, op_dtypes, signature, ops);
798811
if (info == NULL && PyErr_Occurred()) {
799812
return NULL;
800813
}

0 commit comments

Comments
 (0)