Skip to content

Commit 2f95075

Browse files
authored
Merge pull request #39 from ngoldbaum/metadata-ufuncs
Refactor metadatadtype to use ufunc wrapping
2 parents 5b75e4b + 2a703c3 commit 2f95075

File tree

8 files changed

+123
-84
lines changed

8 files changed

+123
-84
lines changed

metadatadtype/meson.build

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ py.install_sources(
4141
py.extension_module(
4242
'_metadatadtype_main',
4343
srcs,
44-
c_args: ['-g', '-O0'],
4544
install: true,
4645
subdir: 'metadatadtype',
4746
include_directories: includes

metadatadtype/metadatadtype/src/dtype.c

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,12 @@ new_metadatadtype_instance(PyObject *metadata)
8181
return new;
8282
}
8383

84-
/*
85-
* This is used to determine the correct dtype to return when operations mix
86-
* dtypes (I think?). For now just return the first one.
87-
*/
88-
static MetadataDTypeObject *
89-
common_instance(MetadataDTypeObject *dtype1, MetadataDTypeObject *dtype2)
84+
PyArray_Descr *
85+
common_instance(MetadataDTypeObject *dtype1,
86+
MetadataDTypeObject *NPY_UNUSED(dtype2))
9087
{
9188
Py_INCREF(dtype1);
92-
return dtype1;
89+
return (PyArray_Descr *)dtype1;
9390
}
9491

9592
static PyArray_DTypeMeta *

metadatadtype/metadatadtype/src/dtype.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,11 @@ new_metadatadtype_instance(PyObject *metadata);
2727
int
2828
init_metadata_dtype(void);
2929

30+
PyArray_Descr *
31+
common_instance(MetadataDTypeObject *dtype1,
32+
MetadataDTypeObject *NPY_UNUSED(dtype2));
33+
34+
// from numpy's dtypemeta.h, not publicly available
35+
#define NPY_DTYPE(descr) ((PyArray_DTypeMeta *)Py_TYPE(descr))
36+
3037
#endif /*_NPY_DTYPE_H*/

metadatadtype/metadatadtype/src/metadatadtype_main.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ PyInit__metadatadtype_main(void)
5151
goto error;
5252
}
5353

54-
if (init_multiply_ufunc() < 0) {
54+
if (init_ufuncs() < 0) {
5555
goto error;
5656
}
5757

metadatadtype/metadatadtype/src/umath.c

Lines changed: 96 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -12,96 +12,118 @@
1212
#include "umath.h"
1313

1414
static int
15-
metadata_multiply_strided_loop(PyArrayMethod_Context *context,
16-
char *const data[], npy_intp const dimensions[],
17-
npy_intp const strides[], NpyAuxData *auxdata)
15+
translate_given_descrs(int nin, int nout,
16+
PyArray_DTypeMeta *NPY_UNUSED(wrapped_dtypes[]),
17+
PyArray_Descr *given_descrs[],
18+
PyArray_Descr *new_descrs[])
1819
{
19-
npy_intp N = dimensions[0];
20-
char *in1 = data[0], *in2 = data[1];
21-
char *out = data[2];
22-
npy_intp in1_stride = strides[0];
23-
npy_intp in2_stride = strides[1];
24-
npy_intp out_stride = strides[2];
25-
26-
while (N--) {
27-
*(double *)out = *(double *)in1 * *(double *)in2;
28-
in1 += in1_stride;
29-
in2 += in2_stride;
30-
out += out_stride;
20+
for (int i = 0; i < nin + nout; i++) {
21+
if (given_descrs[i] == NULL) {
22+
new_descrs[i] = NULL;
23+
}
24+
else {
25+
if (NPY_DTYPE(given_descrs[i]) == &PyArray_BoolDType) {
26+
new_descrs[i] = PyArray_DescrFromType(NPY_BOOL);
27+
}
28+
else {
29+
new_descrs[i] = PyArray_DescrFromType(NPY_DOUBLE);
30+
}
31+
}
3132
}
3233
return 0;
3334
}
3435

35-
static NPY_CASTING
36-
metadata_multiply_resolve_descriptors(PyObject *self,
37-
PyArray_DTypeMeta *dtypes[],
38-
PyArray_Descr *given_descrs[],
39-
PyArray_Descr *loop_descrs[],
40-
npy_intp *unused)
36+
static int
37+
translate_loop_descrs(int nin, int NPY_UNUSED(nout),
38+
PyArray_DTypeMeta *NPY_UNUSED(new_dtypes[]),
39+
PyArray_Descr *given_descrs[],
40+
PyArray_Descr *original_descrs[],
41+
PyArray_Descr *loop_descrs[])
4142
{
42-
// for now just the take the metadata of the first operand
43-
PyObject *meta1 = ((MetadataDTypeObject *)given_descrs[0])->metadata;
44-
45-
/* Create new DType from the new unit: */
46-
loop_descrs[2] = (PyArray_Descr *)new_metadatadtype_instance(meta1);
47-
if (loop_descrs[2] == NULL) {
43+
if (nin == 2) {
44+
loop_descrs[0] =
45+
common_instance((MetadataDTypeObject *)given_descrs[0],
46+
(MetadataDTypeObject *)given_descrs[1]);
47+
if (loop_descrs[0] == NULL) {
48+
return -1;
49+
}
50+
Py_INCREF(loop_descrs[0]);
51+
loop_descrs[1] = loop_descrs[0];
52+
Py_INCREF(loop_descrs[1]);
53+
if (NPY_DTYPE(original_descrs[2]) == &PyArray_BoolDType) {
54+
loop_descrs[2] = PyArray_DescrFromType(NPY_BOOL);
55+
}
56+
else {
57+
loop_descrs[2] = loop_descrs[0];
58+
}
59+
Py_INCREF(loop_descrs[2]);
60+
}
61+
else if (nin == 1) {
62+
loop_descrs[0] = given_descrs[0];
63+
Py_INCREF(loop_descrs[0]);
64+
if (NPY_DTYPE(original_descrs[1]) == &PyArray_BoolDType) {
65+
loop_descrs[1] = PyArray_DescrFromType(NPY_BOOL);
66+
}
67+
else {
68+
loop_descrs[1] = loop_descrs[0];
69+
}
70+
Py_INCREF(loop_descrs[1]);
71+
}
72+
else {
4873
return -1;
4974
}
50-
/* The other operand units can be used as-is: */
51-
Py_INCREF(given_descrs[0]);
52-
loop_descrs[0] = given_descrs[0];
53-
Py_INCREF(given_descrs[1]);
54-
loop_descrs[1] = given_descrs[1];
55-
56-
return NPY_NO_CASTING;
75+
return 0;
5776
}
5877

59-
/*
60-
* Function that adds our multiply loop to NumPy's multiply ufunc.
61-
*/
62-
int
63-
init_multiply_ufunc(void)
78+
static PyObject *
79+
get_ufunc(const char *ufunc_name)
6480
{
65-
/*
66-
* Get the multiply ufunc:
67-
*/
68-
PyObject *numpy = PyImport_ImportModule("numpy");
69-
if (numpy == NULL) {
70-
return -1;
71-
}
72-
PyObject *multiply = PyObject_GetAttrString(numpy, "multiply");
73-
Py_DECREF(numpy);
74-
if (multiply == NULL) {
75-
return -1;
81+
PyObject *mod = PyImport_ImportModule("numpy");
82+
if (mod == NULL) {
83+
return NULL;
7684
}
85+
PyObject *ufunc = PyObject_GetAttrString(mod, ufunc_name);
86+
Py_DECREF(mod);
7787

78-
/*
79-
* The initializing "wrap up" code from the slides (plus one error check)
80-
*/
81-
static PyArray_DTypeMeta *dtypes[3] = {&MetadataDType, &MetadataDType,
82-
&MetadataDType};
88+
return ufunc;
89+
}
8390

84-
static PyType_Slot slots[] = {
85-
{NPY_METH_resolve_descriptors,
86-
&metadata_multiply_resolve_descriptors},
87-
{NPY_METH_strided_loop, &metadata_multiply_strided_loop},
88-
{0, NULL}};
91+
static int
92+
add_wrapping_loop(const char *ufunc_name, PyArray_DTypeMeta **dtypes,
93+
PyArray_DTypeMeta **wrapped_dtypes)
94+
{
95+
PyObject *ufunc = get_ufunc(ufunc_name);
96+
if (ufunc == NULL) {
97+
return -1;
98+
}
99+
int res = PyUFunc_AddWrappingLoop(ufunc, dtypes, wrapped_dtypes,
100+
&translate_given_descrs,
101+
&translate_loop_descrs);
102+
return res;
103+
}
89104

90-
PyArrayMethod_Spec MultiplySpec = {
91-
.name = "metadata_multiply",
92-
.nin = 2,
93-
.nout = 1,
94-
.dtypes = dtypes,
95-
.slots = slots,
96-
.flags = 0,
97-
.casting = NPY_NO_CASTING,
98-
};
105+
int
106+
init_ufuncs(void)
107+
{
108+
PyArray_DTypeMeta *binary_orig_dtypes[3] = {&MetadataDType, &MetadataDType,
109+
&MetadataDType};
110+
PyArray_DTypeMeta *binary_wrapped_dtypes[3] = {
111+
&PyArray_DoubleDType, &PyArray_DoubleDType, &PyArray_DoubleDType};
112+
if (add_wrapping_loop("multiply", binary_orig_dtypes,
113+
binary_wrapped_dtypes) == -1) {
114+
goto error;
115+
}
99116

100-
/* Register */
101-
if (PyUFunc_AddLoopFromSpec(multiply, &MultiplySpec) < 0) {
102-
Py_DECREF(multiply);
103-
return -1;
117+
PyArray_DTypeMeta *unary_boolean_dtypes[2] = {&MetadataDType,
118+
&PyArray_BoolDType};
119+
PyArray_DTypeMeta *unary_boolean_wrapped_dtypes[2] = {&PyArray_DoubleDType,
120+
&PyArray_BoolDType};
121+
if (add_wrapping_loop("isnan", unary_boolean_dtypes,
122+
unary_boolean_wrapped_dtypes) == -1) {
123+
goto error;
104124
}
105-
Py_DECREF(multiply);
125+
106126
return 0;
127+
error:
128+
return -1;
107129
}

metadatadtype/metadatadtype/src/umath.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22
#define _NPY_UFUNC_H
33

44
int
5-
init_multiply_ufunc(void);
5+
init_ufuncs(void);
66

77
#endif /*_NPY_UFUNC_H */

metadatadtype/pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,9 @@ requires-python = ">=3.9.0"
1818
dependencies = [
1919
"numpy",
2020
]
21+
22+
[tool.meson-python.args]
23+
dist = []
24+
setup = ["-Ddebug=true", "-Doptimization=0"]
25+
compile = []
26+
install = []

metadatadtype/tests/test_metadatadtype.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ def test_multiplication():
3737
assert str(res) == "[2.0 test 2.0 test 2.0 test]"
3838

3939

40+
def test_isnan():
41+
dtype = MetadataDType("test")
42+
num_scalar = MetadataScalar(1, dtype)
43+
nan_scalar = MetadataScalar(np.nan, dtype)
44+
arr = np.array([num_scalar, nan_scalar, nan_scalar])
45+
np.testing.assert_array_equal(np.isnan(arr), np.array([False, True, True]))
46+
47+
4048
def test_cast_to_different_metadata():
4149
dtype = MetadataDType("test")
4250
scalar = MetadataScalar(1, dtype)

0 commit comments

Comments
 (0)