|
12 | 12 | #include "umath.h"
|
13 | 13 |
|
14 | 14 | static int
|
15 |
| -metadata_multiply_strided_loop(PyArrayMethod_Context *context, |
16 |
| - char *const data[], npy_intp const dimensions[], |
17 |
| - npy_intp const strides[], NpyAuxData *auxdata) |
| 15 | +translate_given_descrs(int nin, int nout, |
| 16 | + PyArray_DTypeMeta *NPY_UNUSED(wrapped_dtypes[]), |
| 17 | + PyArray_Descr *given_descrs[], |
| 18 | + PyArray_Descr *new_descrs[]) |
18 | 19 | {
|
19 |
| - npy_intp N = dimensions[0]; |
20 |
| - char *in1 = data[0], *in2 = data[1]; |
21 |
| - char *out = data[2]; |
22 |
| - npy_intp in1_stride = strides[0]; |
23 |
| - npy_intp in2_stride = strides[1]; |
24 |
| - npy_intp out_stride = strides[2]; |
25 |
| - |
26 |
| - while (N--) { |
27 |
| - *(double *)out = *(double *)in1 * *(double *)in2; |
28 |
| - in1 += in1_stride; |
29 |
| - in2 += in2_stride; |
30 |
| - out += out_stride; |
| 20 | + for (int i = 0; i < nin + nout; i++) { |
| 21 | + if (given_descrs[i] == NULL) { |
| 22 | + new_descrs[i] = NULL; |
| 23 | + } |
| 24 | + else { |
| 25 | + if (NPY_DTYPE(given_descrs[i]) == &PyArray_BoolDType) { |
| 26 | + new_descrs[i] = PyArray_DescrFromType(NPY_BOOL); |
| 27 | + } |
| 28 | + else { |
| 29 | + new_descrs[i] = PyArray_DescrFromType(NPY_DOUBLE); |
| 30 | + } |
| 31 | + } |
31 | 32 | }
|
32 | 33 | return 0;
|
33 | 34 | }
|
34 | 35 |
|
35 |
| -static NPY_CASTING |
36 |
| -metadata_multiply_resolve_descriptors(PyObject *self, |
37 |
| - PyArray_DTypeMeta *dtypes[], |
38 |
| - PyArray_Descr *given_descrs[], |
39 |
| - PyArray_Descr *loop_descrs[], |
40 |
| - npy_intp *unused) |
| 36 | +static int |
| 37 | +translate_loop_descrs(int nin, int NPY_UNUSED(nout), |
| 38 | + PyArray_DTypeMeta *NPY_UNUSED(new_dtypes[]), |
| 39 | + PyArray_Descr *given_descrs[], |
| 40 | + PyArray_Descr *original_descrs[], |
| 41 | + PyArray_Descr *loop_descrs[]) |
41 | 42 | {
|
42 |
| - // for now just the take the metadata of the first operand |
43 |
| - PyObject *meta1 = ((MetadataDTypeObject *)given_descrs[0])->metadata; |
44 |
| - |
45 |
| - /* Create new DType from the new unit: */ |
46 |
| - loop_descrs[2] = (PyArray_Descr *)new_metadatadtype_instance(meta1); |
47 |
| - if (loop_descrs[2] == NULL) { |
| 43 | + if (nin == 2) { |
| 44 | + loop_descrs[0] = |
| 45 | + common_instance((MetadataDTypeObject *)given_descrs[0], |
| 46 | + (MetadataDTypeObject *)given_descrs[1]); |
| 47 | + if (loop_descrs[0] == NULL) { |
| 48 | + return -1; |
| 49 | + } |
| 50 | + Py_INCREF(loop_descrs[0]); |
| 51 | + loop_descrs[1] = loop_descrs[0]; |
| 52 | + Py_INCREF(loop_descrs[1]); |
| 53 | + if (NPY_DTYPE(original_descrs[2]) == &PyArray_BoolDType) { |
| 54 | + loop_descrs[2] = PyArray_DescrFromType(NPY_BOOL); |
| 55 | + } |
| 56 | + else { |
| 57 | + loop_descrs[2] = loop_descrs[0]; |
| 58 | + } |
| 59 | + Py_INCREF(loop_descrs[2]); |
| 60 | + } |
| 61 | + else if (nin == 1) { |
| 62 | + loop_descrs[0] = given_descrs[0]; |
| 63 | + Py_INCREF(loop_descrs[0]); |
| 64 | + if (NPY_DTYPE(original_descrs[1]) == &PyArray_BoolDType) { |
| 65 | + loop_descrs[1] = PyArray_DescrFromType(NPY_BOOL); |
| 66 | + } |
| 67 | + else { |
| 68 | + loop_descrs[1] = loop_descrs[0]; |
| 69 | + } |
| 70 | + Py_INCREF(loop_descrs[1]); |
| 71 | + } |
| 72 | + else { |
48 | 73 | return -1;
|
49 | 74 | }
|
50 |
| - /* The other operand units can be used as-is: */ |
51 |
| - Py_INCREF(given_descrs[0]); |
52 |
| - loop_descrs[0] = given_descrs[0]; |
53 |
| - Py_INCREF(given_descrs[1]); |
54 |
| - loop_descrs[1] = given_descrs[1]; |
55 |
| - |
56 |
| - return NPY_NO_CASTING; |
| 75 | + return 0; |
57 | 76 | }
|
58 | 77 |
|
59 |
| -/* |
60 |
| - * Function that adds our multiply loop to NumPy's multiply ufunc. |
61 |
| - */ |
62 |
| -int |
63 |
| -init_multiply_ufunc(void) |
| 78 | +static PyObject * |
| 79 | +get_ufunc(const char *ufunc_name) |
64 | 80 | {
|
65 |
| - /* |
66 |
| - * Get the multiply ufunc: |
67 |
| - */ |
68 |
| - PyObject *numpy = PyImport_ImportModule("numpy"); |
69 |
| - if (numpy == NULL) { |
70 |
| - return -1; |
71 |
| - } |
72 |
| - PyObject *multiply = PyObject_GetAttrString(numpy, "multiply"); |
73 |
| - Py_DECREF(numpy); |
74 |
| - if (multiply == NULL) { |
75 |
| - return -1; |
| 81 | + PyObject *mod = PyImport_ImportModule("numpy"); |
| 82 | + if (mod == NULL) { |
| 83 | + return NULL; |
76 | 84 | }
|
| 85 | + PyObject *ufunc = PyObject_GetAttrString(mod, ufunc_name); |
| 86 | + Py_DECREF(mod); |
77 | 87 |
|
78 |
| - /* |
79 |
| - * The initializing "wrap up" code from the slides (plus one error check) |
80 |
| - */ |
81 |
| - static PyArray_DTypeMeta *dtypes[3] = {&MetadataDType, &MetadataDType, |
82 |
| - &MetadataDType}; |
| 88 | + return ufunc; |
| 89 | +} |
83 | 90 |
|
84 |
| - static PyType_Slot slots[] = { |
85 |
| - {NPY_METH_resolve_descriptors, |
86 |
| - &metadata_multiply_resolve_descriptors}, |
87 |
| - {NPY_METH_strided_loop, &metadata_multiply_strided_loop}, |
88 |
| - {0, NULL}}; |
| 91 | +static int |
| 92 | +add_wrapping_loop(const char *ufunc_name, PyArray_DTypeMeta **dtypes, |
| 93 | + PyArray_DTypeMeta **wrapped_dtypes) |
| 94 | +{ |
| 95 | + PyObject *ufunc = get_ufunc(ufunc_name); |
| 96 | + if (ufunc == NULL) { |
| 97 | + return -1; |
| 98 | + } |
| 99 | + int res = PyUFunc_AddWrappingLoop(ufunc, dtypes, wrapped_dtypes, |
| 100 | + &translate_given_descrs, |
| 101 | + &translate_loop_descrs); |
| 102 | + return res; |
| 103 | +} |
89 | 104 |
|
90 |
| - PyArrayMethod_Spec MultiplySpec = { |
91 |
| - .name = "metadata_multiply", |
92 |
| - .nin = 2, |
93 |
| - .nout = 1, |
94 |
| - .dtypes = dtypes, |
95 |
| - .slots = slots, |
96 |
| - .flags = 0, |
97 |
| - .casting = NPY_NO_CASTING, |
98 |
| - }; |
| 105 | +int |
| 106 | +init_ufuncs(void) |
| 107 | +{ |
| 108 | + PyArray_DTypeMeta *binary_orig_dtypes[3] = {&MetadataDType, &MetadataDType, |
| 109 | + &MetadataDType}; |
| 110 | + PyArray_DTypeMeta *binary_wrapped_dtypes[3] = { |
| 111 | + &PyArray_DoubleDType, &PyArray_DoubleDType, &PyArray_DoubleDType}; |
| 112 | + if (add_wrapping_loop("multiply", binary_orig_dtypes, |
| 113 | + binary_wrapped_dtypes) == -1) { |
| 114 | + goto error; |
| 115 | + } |
99 | 116 |
|
100 |
| - /* Register */ |
101 |
| - if (PyUFunc_AddLoopFromSpec(multiply, &MultiplySpec) < 0) { |
102 |
| - Py_DECREF(multiply); |
103 |
| - return -1; |
| 117 | + PyArray_DTypeMeta *unary_boolean_dtypes[2] = {&MetadataDType, |
| 118 | + &PyArray_BoolDType}; |
| 119 | + PyArray_DTypeMeta *unary_boolean_wrapped_dtypes[2] = {&PyArray_DoubleDType, |
| 120 | + &PyArray_BoolDType}; |
| 121 | + if (add_wrapping_loop("isnan", unary_boolean_dtypes, |
| 122 | + unary_boolean_wrapped_dtypes) == -1) { |
| 123 | + goto error; |
104 | 124 | }
|
105 |
| - Py_DECREF(multiply); |
| 125 | + |
106 | 126 | return 0;
|
| 127 | +error: |
| 128 | + return -1; |
107 | 129 | }
|
0 commit comments