Skip to content

Commit 016623e

Browse files
committed
refactor ufunc setup and add isnan loop
1 parent e9449e2 commit 016623e

File tree

2 files changed

+114
-47
lines changed

2 files changed

+114
-47
lines changed

stringdtype/stringdtype/src/umath.c

Lines changed: 107 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ string_equal_strided_loop(PyArrayMethod_Context *context, char *const data[],
4949

5050
static NPY_CASTING
5151
string_equal_resolve_descriptors(PyObject *NPY_UNUSED(self),
52-
PyArray_DTypeMeta *dtypes[],
52+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[]),
5353
PyArray_Descr *given_descrs[],
5454
PyArray_Descr *loop_descrs[],
5555
npy_intp *NPY_UNUSED(view_offset))
@@ -61,7 +61,42 @@ string_equal_resolve_descriptors(PyObject *NPY_UNUSED(self),
6161

6262
loop_descrs[2] = PyArray_DescrFromType(NPY_BOOL); // cannot fail
6363

64-
return NPY_SAFE_CASTING;
64+
return NPY_NO_CASTING;
65+
}
66+
67+
static int
68+
string_isnan_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context),
69+
char *const data[], npy_intp const dimensions[],
70+
npy_intp const strides[],
71+
NpyAuxData *NPY_UNUSED(auxdata))
72+
{
73+
npy_intp N = dimensions[0];
74+
npy_bool *out = (npy_bool *)data[1];
75+
npy_intp out_stride = strides[1];
76+
77+
while (N--) {
78+
// we could represent missing data with a null pointer, but
79+
// should isnan return True in that case?
80+
*out = (npy_bool)0;
81+
82+
out += out_stride;
83+
}
84+
85+
return 0;
86+
}
87+
88+
static NPY_CASTING
89+
string_isnan_resolve_descriptors(PyObject *NPY_UNUSED(self),
90+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[]),
91+
PyArray_Descr *given_descrs[],
92+
PyArray_Descr *loop_descrs[],
93+
npy_intp *NPY_UNUSED(view_offset))
94+
{
95+
Py_INCREF(given_descrs[0]);
96+
loop_descrs[0] = given_descrs[0];
97+
loop_descrs[1] = PyArray_DescrFromType(NPY_BOOL); // cannot fail
98+
99+
return NPY_NO_CASTING;
65100
}
66101

67102
/*
@@ -131,73 +166,70 @@ default_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
131166
}
132167

133168
int
134-
init_equal_ufunc(PyObject *numpy)
169+
init_ufunc(PyObject *numpy, const char *ufunc_name, PyArray_DTypeMeta **dtypes,
170+
resolve_descriptors_function *resolve_func,
171+
PyArrayMethod_StridedLoop *loop_func, const char *loop_name,
172+
int nin, int nout, NPY_CASTING casting, NPY_ARRAYMETHOD_FLAGS flags)
135173
{
136-
PyObject *equal = PyObject_GetAttrString(numpy, "equal");
137-
if (equal == NULL) {
174+
PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
175+
if (ufunc == NULL) {
138176
return -1;
139177
}
140178

141179
/*
142180
* Initialize spec for equality
143181
*/
144-
PyArray_DTypeMeta *eq_dtypes[3] = {&StringDType, &StringDType,
145-
&PyArray_BoolDType};
146-
147-
static PyType_Slot eq_slots[] = {
148-
{NPY_METH_resolve_descriptors, &string_equal_resolve_descriptors},
149-
{NPY_METH_strided_loop, &string_equal_strided_loop},
150-
{0, NULL}};
151-
152-
PyArrayMethod_Spec EqualSpec = {
153-
.name = "string_equal",
154-
.nin = 2,
155-
.nout = 1,
156-
.casting = NPY_NO_CASTING,
157-
.flags = 0,
158-
.dtypes = eq_dtypes,
159-
.slots = eq_slots,
182+
PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, resolve_func},
183+
{NPY_METH_strided_loop, loop_func},
184+
{0, NULL}};
185+
186+
PyArrayMethod_Spec spec = {
187+
.name = loop_name,
188+
.nin = nin,
189+
.nout = nout,
190+
.casting = casting,
191+
.flags = flags,
192+
.dtypes = dtypes,
193+
.slots = slots,
160194
};
161195

162-
if (PyUFunc_AddLoopFromSpec(equal, &EqualSpec) < 0) {
163-
Py_DECREF(equal);
196+
if (PyUFunc_AddLoopFromSpec(ufunc, &spec) < 0) {
197+
Py_DECREF(ufunc);
164198
return -1;
165199
}
166200

167-
/*
168-
* Add promoter to ufunc, ensures operations that mix StringDType and
169-
* UnicodeDType cast the unicode argument to string.
170-
*/
201+
Py_DECREF(ufunc);
202+
return 0;
203+
}
171204

172-
PyObject *DTypes[] = {
173-
PyTuple_Pack(3, &StringDType, &PyArray_UnicodeDType,
174-
&PyArray_BoolDType),
175-
PyTuple_Pack(3, &PyArray_UnicodeDType, &StringDType,
176-
&PyArray_BoolDType),
177-
};
205+
int
206+
add_promoter(PyObject *numpy, const char *ufunc_name,
207+
PyArray_DTypeMeta **dtypes)
208+
{
209+
PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
210+
if (ufunc == NULL) {
211+
return -1;
212+
}
178213

179-
if ((DTypes[0] == NULL) || (DTypes[1] == NULL)) {
180-
Py_DECREF(equal);
214+
PyObject *DType_tuple = PyTuple_Pack(3, dtypes[0], dtypes[1], dtypes[2]);
215+
if (DType_tuple == NULL) {
216+
Py_DECREF(ufunc);
181217
return -1;
182218
}
183219

184220
PyObject *promoter_capsule = PyCapsule_New((void *)&default_ufunc_promoter,
185221
"numpy._ufunc_promoter", NULL);
186222

187-
for (int i = 0; i < 2; i++) {
188-
if (PyUFunc_AddPromoter(equal, DTypes[i], promoter_capsule) < 0) {
189-
Py_DECREF(promoter_capsule);
190-
Py_DECREF(DTypes[0]);
191-
Py_DECREF(DTypes[1]);
192-
Py_DECREF(equal);
193-
return -1;
194-
}
223+
if (PyUFunc_AddPromoter(ufunc, DType_tuple, promoter_capsule) < 0) {
224+
Py_DECREF(promoter_capsule);
225+
Py_DECREF(DType_tuple);
226+
Py_DECREF(ufunc);
227+
return -1;
195228
}
196229

197230
Py_DECREF(promoter_capsule);
198-
Py_DECREF(DTypes[0]);
199-
Py_DECREF(DTypes[1]);
200-
Py_DECREF(equal);
231+
Py_DECREF(DType_tuple);
232+
Py_DECREF(ufunc);
201233

202234
return 0;
203235
}
@@ -210,7 +242,35 @@ init_ufuncs(void)
210242
return -1;
211243
}
212244

213-
if (init_equal_ufunc(numpy) < 0) {
245+
PyArray_DTypeMeta *eq_dtypes[] = {&StringDType, &StringDType,
246+
&PyArray_BoolDType};
247+
248+
if (init_ufunc(numpy, "equal", eq_dtypes,
249+
&string_equal_resolve_descriptors,
250+
&string_equal_strided_loop, "string_equal", 2, 1,
251+
NPY_NO_CASTING, 0) < 0) {
252+
goto error;
253+
}
254+
255+
PyArray_DTypeMeta *promoter_dtypes[2][3] = {
256+
{&StringDType, &PyArray_UnicodeDType, &PyArray_BoolDType},
257+
{&PyArray_UnicodeDType, &StringDType, &PyArray_BoolDType},
258+
};
259+
260+
if (add_promoter(numpy, "equal", promoter_dtypes[0]) < 0) {
261+
goto error;
262+
}
263+
264+
if (add_promoter(numpy, "equal", promoter_dtypes[1]) < 0) {
265+
goto error;
266+
}
267+
268+
PyArray_DTypeMeta *isnan_dtypes[] = {&StringDType, &PyArray_BoolDType};
269+
270+
if (init_ufunc(numpy, "isnan", isnan_dtypes,
271+
&string_isnan_resolve_descriptors,
272+
&string_isnan_strided_loop, "string_isnan", 1, 1,
273+
NPY_NO_CASTING, 0) < 0) {
214274
goto error;
215275
}
216276

stringdtype/tests/test_stringdtype.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,10 @@ def test_equality_promotion(string_list):
104104

105105
np.testing.assert_array_equal(sarr, uarr)
106106
np.testing.assert_array_equal(uarr, sarr)
107+
108+
109+
def test_isnan(string_list):
110+
sarr = np.array(string_list, dtype=StringDType())
111+
np.testing.assert_array_equal(
112+
np.isnan(sarr), np.zeros_like(sarr, dtype=np.bool_)
113+
)

0 commit comments

Comments
 (0)