@@ -274,21 +274,20 @@ resolve_implementation_info(PyUFuncObject *ufunc,
274
274
/* Unspecified out always matches (see below for inputs) */
275
275
continue ;
276
276
}
277
+ assert (i == 0 );
277
278
/*
278
- * This is a reduce-like operation, which always have the form
279
- * `(res_DType, op_DType, res_DType)`. If the first and last
280
- * dtype of the loops match, this should be reduce-compatible.
279
+ * This is a reduce-like operation, we enforce that these
280
+ * register with None as the first DType. If a reduction
281
+ * uses the same DType, we will do that promotion.
282
+ * A `(res_DType, op_DType, res_DType)` pattern can make sense
283
+ * in other context as well and could be confusing.
281
284
*/
282
- if (PyTuple_GET_ITEM (curr_dtypes , 0 )
283
- == PyTuple_GET_ITEM (curr_dtypes , 2 )) {
285
+ if (PyTuple_GET_ITEM (curr_dtypes , 0 ) == Py_None ) {
284
286
continue ;
285
287
}
286
- /*
287
- * This should be a reduce, but doesn't follow the reduce
288
- * pattern. So (for now?) consider this not a match.
289
- */
288
+ /* Otherwise, this is not considered a match */
290
289
matches = NPY_FALSE ;
291
- continue ;
290
+ break ;
292
291
}
293
292
294
293
if (resolver_dtype == (PyArray_DTypeMeta * )Py_None ) {
@@ -488,7 +487,7 @@ resolve_implementation_info(PyUFuncObject *ufunc,
488
487
* those defined by the `signature` unmodified).
489
488
*/
490
489
static PyObject *
491
- call_promoter_and_recurse (PyUFuncObject * ufunc , PyObject * promoter ,
490
+ call_promoter_and_recurse (PyUFuncObject * ufunc , PyObject * info ,
492
491
PyArray_DTypeMeta * op_dtypes [], PyArray_DTypeMeta * signature [],
493
492
PyArrayObject * const operands [])
494
493
{
@@ -498,37 +497,51 @@ call_promoter_and_recurse(PyUFuncObject *ufunc, PyObject *promoter,
498
497
int promoter_result ;
499
498
PyArray_DTypeMeta * new_op_dtypes [NPY_MAXARGS ];
500
499
501
- if (PyCapsule_CheckExact (promoter )) {
502
- /* We could also go the other way and wrap up the python function... */
503
- PyArrayMethod_PromoterFunction * promoter_function = PyCapsule_GetPointer (
504
- promoter , "numpy._ufunc_promoter" );
505
- if (promoter_function == NULL ) {
500
+ if (info != NULL ) {
501
+ PyObject * promoter = PyTuple_GET_ITEM (info , 1 );
502
+ if (PyCapsule_CheckExact (promoter )) {
503
+ /* We could also go the other way and wrap up the python function... */
504
+ PyArrayMethod_PromoterFunction * promoter_function = PyCapsule_GetPointer (
505
+ promoter , "numpy._ufunc_promoter" );
506
+ if (promoter_function == NULL ) {
507
+ return NULL ;
508
+ }
509
+ promoter_result = promoter_function ((PyObject * )ufunc ,
510
+ op_dtypes , signature , new_op_dtypes );
511
+ }
512
+ else {
513
+ PyErr_SetString (PyExc_NotImplementedError ,
514
+ "Calling python functions for promotion is not implemented." );
506
515
return NULL ;
507
516
}
508
- promoter_result = promoter_function (( PyObject * ) ufunc ,
509
- op_dtypes , signature , new_op_dtypes ) ;
510
- }
511
- else {
512
- PyErr_SetString ( PyExc_NotImplementedError ,
513
- "Calling python functions for promotion is not implemented." );
514
- return NULL ;
515
- }
516
- if ( promoter_result < 0 ) {
517
- return NULL ;
518
- }
519
- /*
520
- * If none of the dtypes changes, we would recurse infinitely, abort.
521
- * (Of course it is nevertheless possible to recurse infinitely.)
522
- */
523
- int dtypes_changed = 0 ;
524
- for ( int i = 0 ; i < nargs ; i ++ ) {
525
- if ( new_op_dtypes [ i ] != op_dtypes [ i ]) {
526
- dtypes_changed = 1 ;
527
- break ;
517
+ if ( promoter_result < 0 ) {
518
+ return NULL ;
519
+ }
520
+ /*
521
+ * If none of the dtypes changes, we would recurse infinitely, abort.
522
+ * (Of course it is nevertheless possible to recurse infinitely.)
523
+ *
524
+ * TODO: We could allow users to signal this directly and also move
525
+ * the call to be (almost immediate). That would call it
526
+ * unnecessarily sometimes, but may allow additional flexibility.
527
+ */
528
+ int dtypes_changed = 0 ;
529
+ for ( int i = 0 ; i < nargs ; i ++ ) {
530
+ if ( new_op_dtypes [ i ] != op_dtypes [ i ]) {
531
+ dtypes_changed = 1 ;
532
+ break ;
533
+ }
534
+ }
535
+ if (! dtypes_changed ) {
536
+ goto finish ;
528
537
}
529
538
}
530
- if (!dtypes_changed ) {
531
- goto finish ;
539
+ else {
540
+ /* Reduction special path */
541
+ new_op_dtypes [0 ] = NPY_DT_NewRef (op_dtypes [1 ]);
542
+ new_op_dtypes [1 ] = NPY_DT_NewRef (op_dtypes [1 ]);
543
+ Py_XINCREF (op_dtypes [2 ]);
544
+ new_op_dtypes [2 ] = op_dtypes [2 ];
532
545
}
533
546
534
547
/*
@@ -788,13 +801,13 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
788
801
789
802
/*
790
803
* At this point `info` is NULL if there is no matching loop, or it is
791
- * a promoter that needs to be used/called:
804
+ * a promoter that needs to be used/called.
805
+ * TODO: It may be nice to find a better reduce-solution, but this way
806
+ * it is a True fallback (not registered so lowest priority)
792
807
*/
793
- if (info != NULL ) {
794
- PyObject * promoter = PyTuple_GET_ITEM (info , 1 );
795
-
808
+ if (info != NULL || op_dtypes [0 ] == NULL ) {
796
809
info = call_promoter_and_recurse (ufunc ,
797
- promoter , op_dtypes , signature , ops );
810
+ info , op_dtypes , signature , ops );
798
811
if (info == NULL && PyErr_Occurred ()) {
799
812
return NULL ;
800
813
}
0 commit comments