Skip to content

Commit 9d4115e

Browse files
committed
TST: Add basic tests for custom DType (scaled float) ufuncs
This does not actually test a lot of new paths, since the normal ufuncs do cover most of the typical paths. One thing it does test is that the ufunc implementation (as in ArrayMethod) can customize the cast-safety. This is used here since the parametric DType considers different scalings as same-kind casts and not safe/equiv casts.
1 parent 2c1a34d commit 9d4115e

File tree

5 files changed

+279
-7
lines changed

5 files changed

+279
-7
lines changed

numpy/core/src/multiarray/array_method.c

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -833,10 +833,12 @@ generic_masked_strided_loop(PyArrayMethod_Context *context,
833833

834834

835835
/*
836-
* Identical to the `get_loop` functions and wraps it. This adds support
837-
* to a boolean mask being passed in as a last, additional, operand.
838-
* The wrapped loop will only be called for unmasked elements.
839-
* (Does not support `move_references` or inner dimensions!)
836+
* Fetches a strided-loop function that supports a boolean mask as additional
837+
* (last) operand to the strided-loop. It is otherwise largely identical to
838+
* the `get_loop` method which it wraps.
839+
* This is the core implementation for the ufunc `where=...` keyword argument.
840+
*
841+
* NOTE: This function does not support `move_references` or inner dimensions.
840842
*/
841843
NPY_NO_EXPORT int
842844
PyArrayMethod_GetMaskedStridedLoop(

numpy/core/src/umath/_scaled_float_dtype.c

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "numpy/npy_math.h"
2424
#include "convert_datatype.h"
2525
#include "dtypemeta.h"
26+
#include "dispatching.h"
2627

2728

2829
typedef struct {
@@ -456,6 +457,225 @@ init_casts(void)
456457
}
457458

458459

460+
/*
461+
* We also wish to test very simple ufunc functionality. So create two
462+
* ufunc loops:
463+
* 1. Multiplication, which can multiply the factors and work with that.
464+
* 2. Addition, which needs to use the common instance, and runs into
465+
* cast safety subtleties since we will implement it without an additional
466+
* cast.
467+
*
468+
* NOTE: When first writing this, promotion did not exist for new-style loops,
469+
* if it exists, we could use promotion to implement double * sfloat.
470+
*/
471+
static int
472+
multiply_sfloats(PyArrayMethod_Context *NPY_UNUSED(context),
473+
char *const data[], npy_intp const dimensions[],
474+
npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata))
475+
{
476+
npy_intp N = dimensions[0];
477+
char *in1 = data[0];
478+
char *in2 = data[1];
479+
char *out = data[2];
480+
for (npy_intp i = 0; i < N; i++) {
481+
*(double *)out = *(double *)in1 * *(double *)in2;
482+
in1 += strides[0];
483+
in2 += strides[1];
484+
out += strides[2];
485+
}
486+
return 0;
487+
}
488+
489+
490+
static NPY_CASTING
491+
multiply_sfloats_resolve_descriptors(
492+
PyArrayMethodObject *NPY_UNUSED(self),
493+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[3]),
494+
PyArray_Descr *given_descrs[3],
495+
PyArray_Descr *loop_descrs[3])
496+
{
497+
/*
498+
* Multiply the scaling for the result. If the result was passed in we
499+
* simply ignore it and let the casting machinery fix it up here.
500+
*/
501+
double factor = ((PyArray_SFloatDescr *)given_descrs[1])->scaling;
502+
loop_descrs[2] = sfloat_scaled_copy(
503+
(PyArray_SFloatDescr *)given_descrs[0], factor);
504+
if (loop_descrs[2] == 0) {
505+
return -1;
506+
}
507+
Py_INCREF(given_descrs[0]);
508+
loop_descrs[0] = given_descrs[0];
509+
Py_INCREF(given_descrs[1]);
510+
loop_descrs[1] = given_descrs[1];
511+
return NPY_NO_CASTING;
512+
}
513+
514+
515+
/*
516+
* Unlike the multiplication implementation above, this loops deals with
517+
* scaling (casting) internally. This allows to test some different paths.
518+
*/
519+
static int
520+
add_sfloats(PyArrayMethod_Context *context,
521+
char *const data[], npy_intp const dimensions[],
522+
npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata))
523+
{
524+
double fin1 = ((PyArray_SFloatDescr *)context->descriptors[0])->scaling;
525+
double fin2 = ((PyArray_SFloatDescr *)context->descriptors[1])->scaling;
526+
double fout = ((PyArray_SFloatDescr *)context->descriptors[2])->scaling;
527+
528+
double fact1 = fin1 / fout;
529+
double fact2 = fin2 / fout;
530+
if (check_factor(fact1) < 0) {
531+
return -1;
532+
}
533+
if (check_factor(fact2) < 0) {
534+
return -1;
535+
}
536+
537+
npy_intp N = dimensions[0];
538+
char *in1 = data[0];
539+
char *in2 = data[1];
540+
char *out = data[2];
541+
for (npy_intp i = 0; i < N; i++) {
542+
*(double *)out = (*(double *)in1 * fact1) + (*(double *)in2 * fact2);
543+
in1 += strides[0];
544+
in2 += strides[1];
545+
out += strides[2];
546+
}
547+
return 0;
548+
}
549+
550+
551+
static NPY_CASTING
552+
add_sfloats_resolve_descriptors(
553+
PyArrayMethodObject *NPY_UNUSED(self),
554+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[3]),
555+
PyArray_Descr *given_descrs[3],
556+
PyArray_Descr *loop_descrs[3])
557+
{
558+
/*
559+
* Here we accept an output descriptor (the inner loop can deal with it),
560+
* if none is given, we use the "common instance":
561+
*/
562+
if (given_descrs[2] == NULL) {
563+
loop_descrs[2] = sfloat_common_instance(
564+
given_descrs[0], given_descrs[1]);
565+
if (loop_descrs[2] == 0) {
566+
return -1;
567+
}
568+
}
569+
else {
570+
Py_INCREF(given_descrs[2]);
571+
loop_descrs[2] = given_descrs[2];
572+
}
573+
Py_INCREF(given_descrs[0]);
574+
loop_descrs[0] = given_descrs[0];
575+
Py_INCREF(given_descrs[1]);
576+
loop_descrs[1] = given_descrs[1];
577+
578+
/* If the factors mismatch, we do implicit casting inside the ufunc! */
579+
double fin1 = ((PyArray_SFloatDescr *)loop_descrs[0])->scaling;
580+
double fin2 = ((PyArray_SFloatDescr *)loop_descrs[1])->scaling;
581+
double fout = ((PyArray_SFloatDescr *)loop_descrs[2])->scaling;
582+
583+
if (fin1 == fout && fin2 == fout) {
584+
return NPY_NO_CASTING;
585+
}
586+
if (npy_fabs(fin1) == npy_fabs(fout) && npy_fabs(fin2) == npy_fabs(fout)) {
587+
return NPY_EQUIV_CASTING;
588+
}
589+
return NPY_SAME_KIND_CASTING;
590+
}
591+
592+
593+
static int
594+
add_loop(const char *ufunc_name, PyBoundArrayMethodObject *bmeth)
595+
{
596+
PyObject *mod = PyImport_ImportModule("numpy");
597+
if (mod == NULL) {
598+
return -1;
599+
}
600+
PyObject *ufunc = PyObject_GetAttrString(mod, ufunc_name);
601+
Py_DECREF(mod);
602+
if (!PyObject_TypeCheck(ufunc, &PyUFunc_Type)) {
603+
Py_DECREF(ufunc);
604+
PyErr_Format(PyExc_TypeError,
605+
"numpy.%s was not a ufunc!", ufunc_name);
606+
return -1;
607+
}
608+
PyObject *dtype_tup = PyArray_TupleFromItems(
609+
3, (PyObject **)bmeth->dtypes, 0);
610+
if (dtype_tup == NULL) {
611+
Py_DECREF(ufunc);
612+
return -1;
613+
}
614+
PyObject *info = PyTuple_Pack(2, dtype_tup, bmeth->method);
615+
Py_DECREF(dtype_tup);
616+
if (info == NULL) {
617+
Py_DECREF(ufunc);
618+
return -1;
619+
}
620+
int res = PyUFunc_AddLoop((PyUFuncObject *)ufunc, info, 0);
621+
Py_DECREF(ufunc);
622+
Py_DECREF(info);
623+
return res;
624+
}
625+
626+
627+
/*
628+
* Add new ufunc loops (this is somewhat clumsy as of writing it, but should
629+
* get less so with the introduction of public API).
630+
*/
631+
static int
632+
init_ufuncs(void) {
633+
PyArray_DTypeMeta *dtypes[3] = {
634+
&PyArray_SFloatDType, &PyArray_SFloatDType, &PyArray_SFloatDType};
635+
PyType_Slot slots[3] = {{0, NULL}};
636+
PyArrayMethod_Spec spec = {
637+
.nin = 2,
638+
.nout =1,
639+
.dtypes = dtypes,
640+
.slots = slots,
641+
};
642+
spec.name = "sfloat_multiply";
643+
spec.casting = NPY_NO_CASTING;
644+
645+
slots[0].slot = NPY_METH_resolve_descriptors;
646+
slots[0].pfunc = &multiply_sfloats_resolve_descriptors;
647+
slots[1].slot = NPY_METH_strided_loop;
648+
slots[1].pfunc = &multiply_sfloats;
649+
PyBoundArrayMethodObject *bmeth = PyArrayMethod_FromSpec_int(&spec, 0);
650+
if (bmeth == NULL) {
651+
return -1;
652+
}
653+
int res = add_loop("multiply", bmeth);
654+
Py_DECREF(bmeth);
655+
if (res < 0) {
656+
return -1;
657+
}
658+
659+
spec.name = "sfloat_add";
660+
spec.casting = NPY_SAME_KIND_CASTING;
661+
662+
slots[0].slot = NPY_METH_resolve_descriptors;
663+
slots[0].pfunc = &add_sfloats_resolve_descriptors;
664+
slots[1].slot = NPY_METH_strided_loop;
665+
slots[1].pfunc = &add_sfloats;
666+
bmeth = PyArrayMethod_FromSpec_int(&spec, 0);
667+
if (bmeth == NULL) {
668+
return -1;
669+
}
670+
res = add_loop("add", bmeth);
671+
Py_DECREF(bmeth);
672+
if (res < 0) {
673+
return -1;
674+
}
675+
return 0;
676+
}
677+
678+
459679
/*
460680
* Python entry point, exported via `umathmodule.h` and `multiarraymodule.c`.
461681
* TODO: Should be moved when the necessary API is not internal anymore.
@@ -491,6 +711,10 @@ get_sfloat_dtype(PyObject *NPY_UNUSED(mod), PyObject *NPY_UNUSED(args))
491711
return NULL;
492712
}
493713

714+
if (init_ufuncs() < 0) {
715+
return NULL;
716+
}
717+
494718
initalized = NPY_TRUE;
495719
return (PyObject *)&PyArray_SFloatDType;
496720
}

numpy/core/src/umath/dispatching.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
6969
* @param ignore_duplicate If 1 and a loop with the same `dtype_tuple` is
7070
* found, the function does nothing.
7171
*/
72-
static int
73-
add_ufunc_loop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate)
72+
NPY_NO_EXPORT int
73+
PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate)
7474
{
7575
/*
7676
* Validate the info object, this should likely move to to a different
@@ -495,7 +495,7 @@ add_and_return_legacy_wrapping_ufunc_loop(PyUFuncObject *ufunc,
495495
if (info == NULL) {
496496
return NULL;
497497
}
498-
if (add_ufunc_loop(ufunc, info, ignore_duplicate) < 0) {
498+
if (PyUFunc_AddLoop(ufunc, info, ignore_duplicate) < 0) {
499499
Py_DECREF(info);
500500
return NULL;
501501
}

numpy/core/src/umath/dispatching.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
#include "array_method.h"
88

99

10+
NPY_NO_EXPORT int
11+
PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate);
12+
1013
NPY_NO_EXPORT PyArrayMethodObject *
1114
promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
1215
PyArrayObject *const ops[],

numpy/core/tests/test_custom_dtypes.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,46 @@ def test_sfloat_promotion(self):
9090
# Test an undefined promotion:
9191
with pytest.raises(TypeError):
9292
np.result_type(SF(1.), np.int64)
93+
94+
def test_basic_multiply(self):
95+
a = self._get_array(2.)
96+
b = self._get_array(4.)
97+
98+
res = a * b
99+
# multiplies dtype scaling and content separately:
100+
assert res.dtype.get_scaling() == 8.
101+
expected_view = a.view(np.float64) * b.view(np.float64)
102+
assert_array_equal(res.view(np.float64), expected_view)
103+
104+
def test_basic_addition(self):
105+
a = self._get_array(2.)
106+
b = self._get_array(4.)
107+
108+
res = a + b
109+
# addition uses the type promotion rules for the result:
110+
assert res.dtype == np.result_type(a.dtype, b.dtype)
111+
expected_view = (a.astype(res.dtype).view(np.float64) +
112+
b.astype(res.dtype).view(np.float64))
113+
assert_array_equal(res.view(np.float64), expected_view)
114+
115+
def test_addition_cast_safety(self):
116+
"""The addition method is special for the scaled float, because it
117+
includes the "cast" between different factors, thus cast-safety
118+
is influenced by the implementation.
119+
"""
120+
a = self._get_array(2.)
121+
b = self._get_array(-2.)
122+
c = self._get_array(3.)
123+
124+
# sign change is "equiv":
125+
np.add(a, b, casting="equiv")
126+
with pytest.raises(TypeError):
127+
np.add(a, b, casting="no")
128+
129+
# Different factor is "same_kind" (default) so check that "safe" fails
130+
with pytest.raises(TypeError):
131+
np.add(a, c, casting="safe")
132+
133+
# Check that casting the output fails also (done by the ufunc here)
134+
with pytest.raises(TypeError):
135+
np.add(a, a, out=c, casting="safe")

0 commit comments

Comments
 (0)