Skip to content

Commit 405c6ee

Browse files
authored
Merge pull request numpy#18905 from seberg/ufunc-refactor-2021
MAINT: Refactor reductions to use NEP 43 style dispatching/promotion
2 parents dd2eaaa + 27f3b03 commit 405c6ee

16 files changed

+866
-521
lines changed

numpy/core/src/multiarray/array_method.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,13 @@ _masked_stridedloop_data_free(NpyAuxData *auxdata)
780780
* This function wraps a regular unmasked strided-loop as a
781781
* masked strided-loop, only calling the function for elements
782782
* where the mask is True.
783+
*
784+
* TODO: Reductions also use this code to implement masked reductions.
785+
* Before consolidating them, reductions had a special case for
786+
* broadcasts: when the mask stride was 0 the code does not check all
787+
* elements as `npy_memchr` currently does.
788+
* It may be worthwhile to add such an optimization again if broadcasted
789+
* masks are common enough.
783790
*/
784791
static int
785792
generic_masked_strided_loop(PyArrayMethod_Context *context,

numpy/core/src/multiarray/array_method.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,17 @@ typedef enum {
2121
NPY_METH_NO_FLOATINGPOINT_ERRORS = 1 << 2,
2222
/* Whether the method supports unaligned access (not runtime) */
2323
NPY_METH_SUPPORTS_UNALIGNED = 1 << 3,
24+
/*
25+
* Private flag for now for *logic* functions. The logical functions
26+
* `logical_or` and `logical_and` can always cast the inputs to booleans
27+
* "safely" (because that is how the cast to bool is defined).
28+
* @seberg: I am not sure this is the best way to handle this, so its
29+
* private for now (also it is very limited anyway).
30+
* There is one "exception". NA aware dtypes cannot cast to bool
31+
* (hopefully), so the `??->?` loop should error even with this flag.
32+
* But a second NA fallback loop will be necessary.
33+
*/
34+
_NPY_METH_FORCE_CAST_INPUTS = 1 << 17,
2435

2536
/* All flags which can change at runtime */
2637
NPY_METH_RUNTIME_FLAGS = (

numpy/core/src/umath/_scaled_float_dtype.c

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,42 @@ float_to_from_sfloat_resolve_descriptors(
398398
}
399399

400400

401+
/*
402+
* Cast to boolean (for testing the logical functions a bit better).
403+
*/
404+
static int
405+
cast_sfloat_to_bool(PyArrayMethod_Context *NPY_UNUSED(context),
406+
char *const data[], npy_intp const dimensions[],
407+
npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata))
408+
{
409+
npy_intp N = dimensions[0];
410+
char *in = data[0];
411+
char *out = data[1];
412+
for (npy_intp i = 0; i < N; i++) {
413+
*(npy_bool *)out = *(double *)in != 0;
414+
in += strides[0];
415+
out += strides[1];
416+
}
417+
return 0;
418+
}
419+
420+
static NPY_CASTING
421+
sfloat_to_bool_resolve_descriptors(
422+
PyArrayMethodObject *NPY_UNUSED(self),
423+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
424+
PyArray_Descr *given_descrs[2],
425+
PyArray_Descr *loop_descrs[2])
426+
{
427+
Py_INCREF(given_descrs[0]);
428+
loop_descrs[0] = given_descrs[0];
429+
if (loop_descrs[0] == NULL) {
430+
return -1;
431+
}
432+
loop_descrs[1] = PyArray_DescrFromType(NPY_BOOL); /* cannot fail */
433+
return NPY_UNSAFE_CASTING;
434+
}
435+
436+
401437
static int
402438
init_casts(void)
403439
{
@@ -453,6 +489,22 @@ init_casts(void)
453489
return -1;
454490
}
455491

492+
slots[0].slot = NPY_METH_resolve_descriptors;
493+
slots[0].pfunc = &sfloat_to_bool_resolve_descriptors;
494+
slots[1].slot = NPY_METH_strided_loop;
495+
slots[1].pfunc = &cast_sfloat_to_bool;
496+
slots[2].slot = 0;
497+
slots[2].pfunc = NULL;
498+
499+
spec.name = "sfloat_to_bool_cast";
500+
dtypes[0] = &PyArray_SFloatDType;
501+
dtypes[1] = PyArray_DTypeFromTypeNum(NPY_BOOL);
502+
Py_DECREF(dtypes[1]); /* immortal anyway */
503+
504+
if (PyArray_AddCastingImplementation_FromSpec(&spec, 0)) {
505+
return -1;
506+
}
507+
456508
return 0;
457509
}
458510

numpy/core/src/umath/dispatching.c

Lines changed: 172 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -267,8 +267,39 @@ resolve_implementation_info(PyUFuncObject *ufunc,
267267
* the subclass should be considered a better match
268268
* (subclasses are always more specific).
269269
*/
270+
/* Whether this (normally output) dtype was specified at all */
271+
if (op_dtypes[i] == NULL) {
272+
/*
273+
* When DType is completely unspecified, prefer abstract
274+
* over concrete, assuming it will resolve.
275+
* Furthermore, we cannot decide which abstract/None
276+
* is "better", only concrete ones which are subclasses
277+
* of Abstract ones are defined as worse.
278+
*/
279+
npy_bool prev_is_concrete = NPY_FALSE;
280+
npy_bool new_is_concrete = NPY_FALSE;
281+
if ((prev_dtype != Py_None) &&
282+
!NPY_DT_is_abstract((PyArray_DTypeMeta *)prev_dtype)) {
283+
prev_is_concrete = NPY_TRUE;
284+
}
285+
if ((new_dtype != Py_None) &&
286+
!NPY_DT_is_abstract((PyArray_DTypeMeta *)new_dtype)) {
287+
new_is_concrete = NPY_TRUE;
288+
}
289+
if (prev_is_concrete == new_is_concrete) {
290+
best = -1;
291+
}
292+
else if (prev_is_concrete) {
293+
unambiguously_equally_good = 0;
294+
best = 1;
295+
}
296+
else {
297+
unambiguously_equally_good = 0;
298+
best = 0;
299+
}
300+
}
270301
/* If either is None, the other is strictly more specific */
271-
if (prev_dtype == Py_None) {
302+
else if (prev_dtype == Py_None) {
272303
unambiguously_equally_good = 0;
273304
best = 1;
274305
}
@@ -289,13 +320,29 @@ resolve_implementation_info(PyUFuncObject *ufunc,
289320
*/
290321
best = -1;
291322
}
323+
else if (!NPY_DT_is_abstract((PyArray_DTypeMeta *)prev_dtype)) {
324+
/* old is not abstract, so better (both not possible) */
325+
unambiguously_equally_good = 0;
326+
best = 0;
327+
}
328+
else if (!NPY_DT_is_abstract((PyArray_DTypeMeta *)new_dtype)) {
329+
/* new is not abstract, so better (both not possible) */
330+
unambiguously_equally_good = 0;
331+
best = 1;
332+
}
292333
/*
293-
* TODO: Unreachable, but we will need logic for abstract
294-
* DTypes to decide if one is a subclass of the other
295-
* (And their subclass relation is well defined.)
334+
* TODO: This will need logic for abstract DTypes to decide if
335+
* one is a subclass of the other (And their subclass
336+
* relation is well defined). For now, we bail out
337+
* in cas someone manages to get here.
296338
*/
297339
else {
298-
assert(0);
340+
PyErr_SetString(PyExc_NotImplementedError,
341+
"deciding which one of two abstract dtypes is "
342+
"a better match is not yet implemented. This "
343+
"will pick the better (or bail) in the future.");
344+
*out_info = NULL;
345+
return -1;
299346
}
300347

301348
if ((current_best != -1) && (current_best != best)) {
@@ -612,6 +659,35 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
612659
}
613660
return info;
614661
}
662+
else if (info == NULL && op_dtypes[0] == NULL) {
663+
/*
664+
* If we have a reduction, fill in the unspecified input/array
665+
* assuming it should have the same dtype as the operand input
666+
* (or the output one if given).
667+
* Then, try again. In some cases, this will choose different
668+
* paths, such as `ll->?` instead of an `??->?` loop for `np.equal`
669+
* when the input is `.l->.` (`.` meaning undefined). This will
670+
* then cause an error. But cast to `?` would always lose
671+
* information, and in many cases important information:
672+
*
673+
* ```python
674+
* from operator import eq
675+
* from functools import reduce
676+
*
677+
* reduce(eq, [1, 2, 3]) != reduce(eq, [True, True, True])
678+
* ```
679+
*
680+
* The special cases being `logical_(and|or|xor)` which can always
681+
* cast to boolean ahead of time and still give the right answer
682+
* (unsafe cast to bool is fine here). We special case these at
683+
* the time of this comment (NumPy 1.21).
684+
*/
685+
assert(ufunc->nin == 2 && ufunc->nout == 1);
686+
op_dtypes[0] = op_dtypes[2] != NULL ? op_dtypes[2] : op_dtypes[1];
687+
Py_INCREF(op_dtypes[0]);
688+
return promote_and_get_info_and_ufuncimpl(ufunc,
689+
ops, signature, op_dtypes, allow_legacy_promotion, 1);
690+
}
615691
}
616692

617693
/*
@@ -743,3 +819,94 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
743819

744820
return method;
745821
}
822+
823+
824+
/*
825+
* Special promoter for the logical ufuncs. The logical ufuncs can always
826+
* use the ??->? and still get the correct output (as long as the output
827+
* is not supposed to be `object`).
828+
*/
829+
static int
830+
logical_ufunc_promoter(PyUFuncObject *NPY_UNUSED(ufunc),
831+
PyArray_DTypeMeta *op_dtypes[], PyArray_DTypeMeta *signature[],
832+
PyArray_DTypeMeta *new_op_dtypes[])
833+
{
834+
/*
835+
* If we find any object DType at all, we currently force to object.
836+
* However, if the output is specified and not object, there is no point,
837+
* it should be just as well to cast the input rather than doing the
838+
* unsafe out cast.
839+
*/
840+
int force_object = 0;
841+
842+
for (int i = 0; i < 3; i++) {
843+
PyArray_DTypeMeta *item;
844+
if (signature[i] != NULL) {
845+
item = signature[i];
846+
Py_INCREF(item);
847+
if (item->type_num == NPY_OBJECT) {
848+
force_object = 1;
849+
}
850+
}
851+
else {
852+
/* Always override to boolean */
853+
item = PyArray_DTypeFromTypeNum(NPY_BOOL);
854+
if (op_dtypes[i] != NULL && op_dtypes[i]->type_num == NPY_OBJECT) {
855+
force_object = 1;
856+
}
857+
}
858+
new_op_dtypes[i] = item;
859+
}
860+
861+
if (!force_object || (op_dtypes[2] != NULL
862+
&& op_dtypes[2]->type_num != NPY_OBJECT)) {
863+
return 0;
864+
}
865+
/*
866+
* Actually, we have to use the OBJECT loop after all, set all we can
867+
* to object (that might not work out, but try).
868+
*
869+
* NOTE: Change this to check for `op_dtypes[0] == NULL` to STOP
870+
* returning `object` for `np.logical_and.reduce(obj_arr)`
871+
* which will also affect `np.all` and `np.any`!
872+
*/
873+
for (int i = 0; i < 3; i++) {
874+
if (signature[i] != NULL) {
875+
continue;
876+
}
877+
Py_SETREF(new_op_dtypes[i], PyArray_DTypeFromTypeNum(NPY_OBJECT));
878+
}
879+
return 0;
880+
}
881+
882+
883+
NPY_NO_EXPORT int
884+
install_logical_ufunc_promoter(PyObject *ufunc)
885+
{
886+
if (PyObject_Type(ufunc) != (PyObject *)&PyUFunc_Type) {
887+
PyErr_SetString(PyExc_RuntimeError,
888+
"internal numpy array, logical ufunc was not a ufunc?!");
889+
return -1;
890+
}
891+
PyObject *dtype_tuple = PyTuple_Pack(3,
892+
&PyArrayDescr_Type, &PyArrayDescr_Type, &PyArrayDescr_Type, NULL);
893+
if (dtype_tuple == NULL) {
894+
return -1;
895+
}
896+
PyObject *promoter = PyCapsule_New(&logical_ufunc_promoter,
897+
"numpy._ufunc_promoter", NULL);
898+
if (promoter == NULL) {
899+
Py_DECREF(dtype_tuple);
900+
return -1;
901+
}
902+
903+
PyObject *info = PyTuple_Pack(2, dtype_tuple, promoter);
904+
Py_DECREF(dtype_tuple);
905+
Py_DECREF(promoter);
906+
if (info == NULL) {
907+
return -1;
908+
}
909+
910+
return PyUFunc_AddLoop((PyUFuncObject *)ufunc, info, 0);
911+
}
912+

numpy/core/src/umath/dispatching.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,8 @@ NPY_NO_EXPORT PyObject *
2626
add_and_return_legacy_wrapping_ufunc_loop(PyUFuncObject *ufunc,
2727
PyArray_DTypeMeta *operation_dtypes[], int ignore_duplicate);
2828

29+
NPY_NO_EXPORT int
30+
install_logical_ufunc_promoter(PyObject *ufunc);
31+
32+
2933
#endif /*_NPY_DISPATCHING_H */

numpy/core/src/umath/legacy_array_method.c

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,25 @@ PyArray_NewLegacyWrappingArrayMethod(PyUFuncObject *ufunc,
217217
*/
218218
int any_output_flexible = 0;
219219
NPY_ARRAYMETHOD_FLAGS flags = 0;
220+
if (ufunc->nargs == 3 &&
221+
signature[0]->type_num == NPY_BOOL &&
222+
signature[1]->type_num == NPY_BOOL &&
223+
signature[2]->type_num == NPY_BOOL && (
224+
strcmp(ufunc->name, "logical_or") == 0 ||
225+
strcmp(ufunc->name, "logical_and") == 0 ||
226+
strcmp(ufunc->name, "logical_xor") == 0)) {
227+
/*
228+
* This is a logical ufunc, and the `??->?` loop`. It is always OK
229+
* to cast any input to bool, because that cast is defined by
230+
* truthiness.
231+
* This allows to ensure two things:
232+
* 1. `np.all`/`np.any` know that force casting the input is OK
233+
* (they must do this since there are no `?l->?`, etc. loops)
234+
* 2. The logical functions automatically work for any DType
235+
* implementing a cast to boolean.
236+
*/
237+
flags = _NPY_METH_FORCE_CAST_INPUTS;
238+
}
220239

221240
for (int i = 0; i < ufunc->nin+ufunc->nout; i++) {
222241
if (signature[i]->singleton->flags & (

0 commit comments

Comments
 (0)