@@ -49,7 +49,7 @@ string_equal_strided_loop(PyArrayMethod_Context *context, char *const data[],
49
49
50
50
static NPY_CASTING
51
51
string_equal_resolve_descriptors (PyObject * NPY_UNUSED (self ),
52
- PyArray_DTypeMeta * dtypes [],
52
+ PyArray_DTypeMeta * NPY_UNUSED ( dtypes []) ,
53
53
PyArray_Descr * given_descrs [],
54
54
PyArray_Descr * loop_descrs [],
55
55
npy_intp * NPY_UNUSED (view_offset ))
@@ -61,7 +61,42 @@ string_equal_resolve_descriptors(PyObject *NPY_UNUSED(self),
61
61
62
62
loop_descrs [2 ] = PyArray_DescrFromType (NPY_BOOL ); // cannot fail
63
63
64
- return NPY_SAFE_CASTING ;
64
+ return NPY_NO_CASTING ;
65
+ }
66
+
67
+ static int
68
+ string_isnan_strided_loop (PyArrayMethod_Context * NPY_UNUSED (context ),
69
+ char * const data [], npy_intp const dimensions [],
70
+ npy_intp const strides [],
71
+ NpyAuxData * NPY_UNUSED (auxdata ))
72
+ {
73
+ npy_intp N = dimensions [0 ];
74
+ npy_bool * out = (npy_bool * )data [1 ];
75
+ npy_intp out_stride = strides [1 ];
76
+
77
+ while (N -- ) {
78
+ // we could represent missing data with a null pointer, but
79
+ // should isnan return True in that case?
80
+ * out = (npy_bool )0 ;
81
+
82
+ out += out_stride ;
83
+ }
84
+
85
+ return 0 ;
86
+ }
87
+
88
+ static NPY_CASTING
89
+ string_isnan_resolve_descriptors (PyObject * NPY_UNUSED (self ),
90
+ PyArray_DTypeMeta * NPY_UNUSED (dtypes []),
91
+ PyArray_Descr * given_descrs [],
92
+ PyArray_Descr * loop_descrs [],
93
+ npy_intp * NPY_UNUSED (view_offset ))
94
+ {
95
+ Py_INCREF (given_descrs [0 ]);
96
+ loop_descrs [0 ] = given_descrs [0 ];
97
+ loop_descrs [1 ] = PyArray_DescrFromType (NPY_BOOL ); // cannot fail
98
+
99
+ return NPY_NO_CASTING ;
65
100
}
66
101
67
102
/*
@@ -131,73 +166,70 @@ default_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
131
166
}
132
167
133
168
int
134
- init_equal_ufunc (PyObject * numpy )
169
+ init_ufunc (PyObject * numpy , const char * ufunc_name , PyArray_DTypeMeta * * dtypes ,
170
+ resolve_descriptors_function * resolve_func ,
171
+ PyArrayMethod_StridedLoop * loop_func , const char * loop_name ,
172
+ int nin , int nout , NPY_CASTING casting , NPY_ARRAYMETHOD_FLAGS flags )
135
173
{
136
- PyObject * equal = PyObject_GetAttrString (numpy , "equal" );
137
- if (equal == NULL ) {
174
+ PyObject * ufunc = PyObject_GetAttrString (numpy , ufunc_name );
175
+ if (ufunc == NULL ) {
138
176
return -1 ;
139
177
}
140
178
141
179
/*
142
180
* Initialize spec for equality
143
181
*/
144
- PyArray_DTypeMeta * eq_dtypes [3 ] = {& StringDType , & StringDType ,
145
- & PyArray_BoolDType };
146
-
147
- static PyType_Slot eq_slots [] = {
148
- {NPY_METH_resolve_descriptors , & string_equal_resolve_descriptors },
149
- {NPY_METH_strided_loop , & string_equal_strided_loop },
150
- {0 , NULL }};
151
-
152
- PyArrayMethod_Spec EqualSpec = {
153
- .name = "string_equal" ,
154
- .nin = 2 ,
155
- .nout = 1 ,
156
- .casting = NPY_NO_CASTING ,
157
- .flags = 0 ,
158
- .dtypes = eq_dtypes ,
159
- .slots = eq_slots ,
182
+ PyType_Slot slots [] = {{NPY_METH_resolve_descriptors , resolve_func },
183
+ {NPY_METH_strided_loop , loop_func },
184
+ {0 , NULL }};
185
+
186
+ PyArrayMethod_Spec spec = {
187
+ .name = loop_name ,
188
+ .nin = nin ,
189
+ .nout = nout ,
190
+ .casting = casting ,
191
+ .flags = flags ,
192
+ .dtypes = dtypes ,
193
+ .slots = slots ,
160
194
};
161
195
162
- if (PyUFunc_AddLoopFromSpec (equal , & EqualSpec ) < 0 ) {
163
- Py_DECREF (equal );
196
+ if (PyUFunc_AddLoopFromSpec (ufunc , & spec ) < 0 ) {
197
+ Py_DECREF (ufunc );
164
198
return -1 ;
165
199
}
166
200
167
- /*
168
- * Add promoter to ufunc, ensures operations that mix StringDType and
169
- * UnicodeDType cast the unicode argument to string.
170
- */
201
+ Py_DECREF (ufunc );
202
+ return 0 ;
203
+ }
171
204
172
- PyObject * DTypes [] = {
173
- PyTuple_Pack (3 , & StringDType , & PyArray_UnicodeDType ,
174
- & PyArray_BoolDType ),
175
- PyTuple_Pack (3 , & PyArray_UnicodeDType , & StringDType ,
176
- & PyArray_BoolDType ),
177
- };
205
+ int
206
+ add_promoter (PyObject * numpy , const char * ufunc_name ,
207
+ PyArray_DTypeMeta * * dtypes )
208
+ {
209
+ PyObject * ufunc = PyObject_GetAttrString (numpy , ufunc_name );
210
+ if (ufunc == NULL ) {
211
+ return -1 ;
212
+ }
178
213
179
- if ((DTypes [0 ] == NULL ) || (DTypes [1 ] == NULL )) {
180
- Py_DECREF (equal );
214
+ PyObject * DType_tuple = PyTuple_Pack (3 , dtypes [0 ], dtypes [1 ], dtypes [2 ]);
215
+ if (DType_tuple == NULL ) {
216
+ Py_DECREF (ufunc );
181
217
return -1 ;
182
218
}
183
219
184
220
PyObject * promoter_capsule = PyCapsule_New ((void * )& default_ufunc_promoter ,
185
221
"numpy._ufunc_promoter" , NULL );
186
222
187
- for (int i = 0 ; i < 2 ; i ++ ) {
188
- if (PyUFunc_AddPromoter (equal , DTypes [i ], promoter_capsule ) < 0 ) {
189
- Py_DECREF (promoter_capsule );
190
- Py_DECREF (DTypes [0 ]);
191
- Py_DECREF (DTypes [1 ]);
192
- Py_DECREF (equal );
193
- return -1 ;
194
- }
223
+ if (PyUFunc_AddPromoter (ufunc , DType_tuple , promoter_capsule ) < 0 ) {
224
+ Py_DECREF (promoter_capsule );
225
+ Py_DECREF (DType_tuple );
226
+ Py_DECREF (ufunc );
227
+ return -1 ;
195
228
}
196
229
197
230
Py_DECREF (promoter_capsule );
198
- Py_DECREF (DTypes [0 ]);
199
- Py_DECREF (DTypes [1 ]);
200
- Py_DECREF (equal );
231
+ Py_DECREF (DType_tuple );
232
+ Py_DECREF (ufunc );
201
233
202
234
return 0 ;
203
235
}
@@ -210,7 +242,35 @@ init_ufuncs(void)
210
242
return -1 ;
211
243
}
212
244
213
- if (init_equal_ufunc (numpy ) < 0 ) {
245
+ PyArray_DTypeMeta * eq_dtypes [] = {& StringDType , & StringDType ,
246
+ & PyArray_BoolDType };
247
+
248
+ if (init_ufunc (numpy , "equal" , eq_dtypes ,
249
+ & string_equal_resolve_descriptors ,
250
+ & string_equal_strided_loop , "string_equal" , 2 , 1 ,
251
+ NPY_NO_CASTING , 0 ) < 0 ) {
252
+ goto error ;
253
+ }
254
+
255
+ PyArray_DTypeMeta * promoter_dtypes [2 ][3 ] = {
256
+ {& StringDType , & PyArray_UnicodeDType , & PyArray_BoolDType },
257
+ {& PyArray_UnicodeDType , & StringDType , & PyArray_BoolDType },
258
+ };
259
+
260
+ if (add_promoter (numpy , "equal" , promoter_dtypes [0 ]) < 0 ) {
261
+ goto error ;
262
+ }
263
+
264
+ if (add_promoter (numpy , "equal" , promoter_dtypes [1 ]) < 0 ) {
265
+ goto error ;
266
+ }
267
+
268
+ PyArray_DTypeMeta * isnan_dtypes [] = {& StringDType , & PyArray_BoolDType };
269
+
270
+ if (init_ufunc (numpy , "isnan" , isnan_dtypes ,
271
+ & string_isnan_resolve_descriptors ,
272
+ & string_isnan_strided_loop , "string_isnan" , 1 , 1 ,
273
+ NPY_NO_CASTING , 0 ) < 0 ) {
214
274
goto error ;
215
275
}
216
276
0 commit comments