Skip to content

Commit f25905b

Browse files
authored
Merge pull request numpy#19578 from seberg/scaled_float_ufunc_loops
TST: Add basic tests for custom DType (scaled float) ufuncs
2 parents 6822c09 + 9d4115e commit f25905b

File tree

5 files changed

+279
-7
lines changed

5 files changed

+279
-7
lines changed

numpy/core/src/multiarray/array_method.c

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -833,10 +833,12 @@ generic_masked_strided_loop(PyArrayMethod_Context *context,
833833

834834

835835
/*
836-
* Identical to the `get_loop` functions and wraps it. This adds support
837-
* to a boolean mask being passed in as a last, additional, operand.
838-
* The wrapped loop will only be called for unmasked elements.
839-
* (Does not support `move_references` or inner dimensions!)
836+
* Fetches a strided-loop function that supports a boolean mask as additional
837+
* (last) operand to the strided-loop. It is otherwise largely identical to
838+
* the `get_loop` method which it wraps.
839+
* This is the core implementation for the ufunc `where=...` keyword argument.
840+
*
841+
* NOTE: This function does not support `move_references` or inner dimensions.
840842
*/
841843
NPY_NO_EXPORT int
842844
PyArrayMethod_GetMaskedStridedLoop(

numpy/core/src/umath/_scaled_float_dtype.c

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "numpy/npy_math.h"
2424
#include "convert_datatype.h"
2525
#include "dtypemeta.h"
26+
#include "dispatching.h"
2627

2728

2829
typedef struct {
@@ -456,6 +457,225 @@ init_casts(void)
456457
}
457458

458459

460+
/*
461+
* We also wish to test very simple ufunc functionality. So create two
462+
* ufunc loops:
463+
* 1. Multiplication, which can multiply the factors and work with that.
464+
* 2. Addition, which needs to use the common instance, and runs into
465+
* cast safety subtleties since we will implement it without an additional
466+
* cast.
467+
*
468+
* NOTE: When first writing this, promotion did not exist for new-style loops,
469+
* if it exists, we could use promotion to implement double * sfloat.
470+
*/
471+
static int
472+
multiply_sfloats(PyArrayMethod_Context *NPY_UNUSED(context),
473+
char *const data[], npy_intp const dimensions[],
474+
npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata))
475+
{
476+
npy_intp N = dimensions[0];
477+
char *in1 = data[0];
478+
char *in2 = data[1];
479+
char *out = data[2];
480+
for (npy_intp i = 0; i < N; i++) {
481+
*(double *)out = *(double *)in1 * *(double *)in2;
482+
in1 += strides[0];
483+
in2 += strides[1];
484+
out += strides[2];
485+
}
486+
return 0;
487+
}
488+
489+
490+
static NPY_CASTING
491+
multiply_sfloats_resolve_descriptors(
492+
PyArrayMethodObject *NPY_UNUSED(self),
493+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[3]),
494+
PyArray_Descr *given_descrs[3],
495+
PyArray_Descr *loop_descrs[3])
496+
{
497+
/*
498+
* Multiply the scaling for the result. If the result was passed in we
499+
* simply ignore it and let the casting machinery fix it up here.
500+
*/
501+
double factor = ((PyArray_SFloatDescr *)given_descrs[1])->scaling;
502+
loop_descrs[2] = sfloat_scaled_copy(
503+
(PyArray_SFloatDescr *)given_descrs[0], factor);
504+
if (loop_descrs[2] == 0) {
505+
return -1;
506+
}
507+
Py_INCREF(given_descrs[0]);
508+
loop_descrs[0] = given_descrs[0];
509+
Py_INCREF(given_descrs[1]);
510+
loop_descrs[1] = given_descrs[1];
511+
return NPY_NO_CASTING;
512+
}
513+
514+
515+
/*
516+
* Unlike the multiplication implementation above, this loops deals with
517+
* scaling (casting) internally. This allows to test some different paths.
518+
*/
519+
static int
520+
add_sfloats(PyArrayMethod_Context *context,
521+
char *const data[], npy_intp const dimensions[],
522+
npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata))
523+
{
524+
double fin1 = ((PyArray_SFloatDescr *)context->descriptors[0])->scaling;
525+
double fin2 = ((PyArray_SFloatDescr *)context->descriptors[1])->scaling;
526+
double fout = ((PyArray_SFloatDescr *)context->descriptors[2])->scaling;
527+
528+
double fact1 = fin1 / fout;
529+
double fact2 = fin2 / fout;
530+
if (check_factor(fact1) < 0) {
531+
return -1;
532+
}
533+
if (check_factor(fact2) < 0) {
534+
return -1;
535+
}
536+
537+
npy_intp N = dimensions[0];
538+
char *in1 = data[0];
539+
char *in2 = data[1];
540+
char *out = data[2];
541+
for (npy_intp i = 0; i < N; i++) {
542+
*(double *)out = (*(double *)in1 * fact1) + (*(double *)in2 * fact2);
543+
in1 += strides[0];
544+
in2 += strides[1];
545+
out += strides[2];
546+
}
547+
return 0;
548+
}
549+
550+
551+
static NPY_CASTING
552+
add_sfloats_resolve_descriptors(
553+
PyArrayMethodObject *NPY_UNUSED(self),
554+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[3]),
555+
PyArray_Descr *given_descrs[3],
556+
PyArray_Descr *loop_descrs[3])
557+
{
558+
/*
559+
* Here we accept an output descriptor (the inner loop can deal with it),
560+
* if none is given, we use the "common instance":
561+
*/
562+
if (given_descrs[2] == NULL) {
563+
loop_descrs[2] = sfloat_common_instance(
564+
given_descrs[0], given_descrs[1]);
565+
if (loop_descrs[2] == 0) {
566+
return -1;
567+
}
568+
}
569+
else {
570+
Py_INCREF(given_descrs[2]);
571+
loop_descrs[2] = given_descrs[2];
572+
}
573+
Py_INCREF(given_descrs[0]);
574+
loop_descrs[0] = given_descrs[0];
575+
Py_INCREF(given_descrs[1]);
576+
loop_descrs[1] = given_descrs[1];
577+
578+
/* If the factors mismatch, we do implicit casting inside the ufunc! */
579+
double fin1 = ((PyArray_SFloatDescr *)loop_descrs[0])->scaling;
580+
double fin2 = ((PyArray_SFloatDescr *)loop_descrs[1])->scaling;
581+
double fout = ((PyArray_SFloatDescr *)loop_descrs[2])->scaling;
582+
583+
if (fin1 == fout && fin2 == fout) {
584+
return NPY_NO_CASTING;
585+
}
586+
if (npy_fabs(fin1) == npy_fabs(fout) && npy_fabs(fin2) == npy_fabs(fout)) {
587+
return NPY_EQUIV_CASTING;
588+
}
589+
return NPY_SAME_KIND_CASTING;
590+
}
591+
592+
593+
static int
594+
add_loop(const char *ufunc_name, PyBoundArrayMethodObject *bmeth)
595+
{
596+
PyObject *mod = PyImport_ImportModule("numpy");
597+
if (mod == NULL) {
598+
return -1;
599+
}
600+
PyObject *ufunc = PyObject_GetAttrString(mod, ufunc_name);
601+
Py_DECREF(mod);
602+
if (!PyObject_TypeCheck(ufunc, &PyUFunc_Type)) {
603+
Py_DECREF(ufunc);
604+
PyErr_Format(PyExc_TypeError,
605+
"numpy.%s was not a ufunc!", ufunc_name);
606+
return -1;
607+
}
608+
PyObject *dtype_tup = PyArray_TupleFromItems(
609+
3, (PyObject **)bmeth->dtypes, 0);
610+
if (dtype_tup == NULL) {
611+
Py_DECREF(ufunc);
612+
return -1;
613+
}
614+
PyObject *info = PyTuple_Pack(2, dtype_tup, bmeth->method);
615+
Py_DECREF(dtype_tup);
616+
if (info == NULL) {
617+
Py_DECREF(ufunc);
618+
return -1;
619+
}
620+
int res = PyUFunc_AddLoop((PyUFuncObject *)ufunc, info, 0);
621+
Py_DECREF(ufunc);
622+
Py_DECREF(info);
623+
return res;
624+
}
625+
626+
627+
/*
628+
* Add new ufunc loops (this is somewhat clumsy as of writing it, but should
629+
* get less so with the introduction of public API).
630+
*/
631+
static int
632+
init_ufuncs(void) {
633+
PyArray_DTypeMeta *dtypes[3] = {
634+
&PyArray_SFloatDType, &PyArray_SFloatDType, &PyArray_SFloatDType};
635+
PyType_Slot slots[3] = {{0, NULL}};
636+
PyArrayMethod_Spec spec = {
637+
.nin = 2,
638+
.nout =1,
639+
.dtypes = dtypes,
640+
.slots = slots,
641+
};
642+
spec.name = "sfloat_multiply";
643+
spec.casting = NPY_NO_CASTING;
644+
645+
slots[0].slot = NPY_METH_resolve_descriptors;
646+
slots[0].pfunc = &multiply_sfloats_resolve_descriptors;
647+
slots[1].slot = NPY_METH_strided_loop;
648+
slots[1].pfunc = &multiply_sfloats;
649+
PyBoundArrayMethodObject *bmeth = PyArrayMethod_FromSpec_int(&spec, 0);
650+
if (bmeth == NULL) {
651+
return -1;
652+
}
653+
int res = add_loop("multiply", bmeth);
654+
Py_DECREF(bmeth);
655+
if (res < 0) {
656+
return -1;
657+
}
658+
659+
spec.name = "sfloat_add";
660+
spec.casting = NPY_SAME_KIND_CASTING;
661+
662+
slots[0].slot = NPY_METH_resolve_descriptors;
663+
slots[0].pfunc = &add_sfloats_resolve_descriptors;
664+
slots[1].slot = NPY_METH_strided_loop;
665+
slots[1].pfunc = &add_sfloats;
666+
bmeth = PyArrayMethod_FromSpec_int(&spec, 0);
667+
if (bmeth == NULL) {
668+
return -1;
669+
}
670+
res = add_loop("add", bmeth);
671+
Py_DECREF(bmeth);
672+
if (res < 0) {
673+
return -1;
674+
}
675+
return 0;
676+
}
677+
678+
459679
/*
460680
* Python entry point, exported via `umathmodule.h` and `multiarraymodule.c`.
461681
* TODO: Should be moved when the necessary API is not internal anymore.
@@ -491,6 +711,10 @@ get_sfloat_dtype(PyObject *NPY_UNUSED(mod), PyObject *NPY_UNUSED(args))
491711
return NULL;
492712
}
493713

714+
if (init_ufuncs() < 0) {
715+
return NULL;
716+
}
717+
494718
initalized = NPY_TRUE;
495719
return (PyObject *)&PyArray_SFloatDType;
496720
}

numpy/core/src/umath/dispatching.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
6969
* @param ignore_duplicate If 1 and a loop with the same `dtype_tuple` is
7070
* found, the function does nothing.
7171
*/
72-
static int
73-
add_ufunc_loop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate)
72+
NPY_NO_EXPORT int
73+
PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate)
7474
{
7575
/*
7676
* Validate the info object, this should likely move to to a different
@@ -495,7 +495,7 @@ add_and_return_legacy_wrapping_ufunc_loop(PyUFuncObject *ufunc,
495495
if (info == NULL) {
496496
return NULL;
497497
}
498-
if (add_ufunc_loop(ufunc, info, ignore_duplicate) < 0) {
498+
if (PyUFunc_AddLoop(ufunc, info, ignore_duplicate) < 0) {
499499
Py_DECREF(info);
500500
return NULL;
501501
}

numpy/core/src/umath/dispatching.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
#include "array_method.h"
88

99

10+
NPY_NO_EXPORT int
11+
PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate);
12+
1013
NPY_NO_EXPORT PyArrayMethodObject *
1114
promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
1215
PyArrayObject *const ops[],

numpy/core/tests/test_custom_dtypes.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,46 @@ def test_sfloat_promotion(self):
9090
# Test an undefined promotion:
9191
with pytest.raises(TypeError):
9292
np.result_type(SF(1.), np.int64)
93+
94+
def test_basic_multiply(self):
95+
a = self._get_array(2.)
96+
b = self._get_array(4.)
97+
98+
res = a * b
99+
# multiplies dtype scaling and content separately:
100+
assert res.dtype.get_scaling() == 8.
101+
expected_view = a.view(np.float64) * b.view(np.float64)
102+
assert_array_equal(res.view(np.float64), expected_view)
103+
104+
def test_basic_addition(self):
105+
a = self._get_array(2.)
106+
b = self._get_array(4.)
107+
108+
res = a + b
109+
# addition uses the type promotion rules for the result:
110+
assert res.dtype == np.result_type(a.dtype, b.dtype)
111+
expected_view = (a.astype(res.dtype).view(np.float64) +
112+
b.astype(res.dtype).view(np.float64))
113+
assert_array_equal(res.view(np.float64), expected_view)
114+
115+
def test_addition_cast_safety(self):
116+
"""The addition method is special for the scaled float, because it
117+
includes the "cast" between different factors, thus cast-safety
118+
is influenced by the implementation.
119+
"""
120+
a = self._get_array(2.)
121+
b = self._get_array(-2.)
122+
c = self._get_array(3.)
123+
124+
# sign change is "equiv":
125+
np.add(a, b, casting="equiv")
126+
with pytest.raises(TypeError):
127+
np.add(a, b, casting="no")
128+
129+
# Different factor is "same_kind" (default) so check that "safe" fails
130+
with pytest.raises(TypeError):
131+
np.add(a, c, casting="safe")
132+
133+
# Check that casting the output fails also (done by the ufunc here)
134+
with pytest.raises(TypeError):
135+
np.add(a, a, out=c, casting="safe")

0 commit comments

Comments
 (0)