Skip to content

Commit cb5d0d3

Browse files
committed
simplify promoter setup by reusing capsule
1 parent 2646d7f commit cb5d0d3

File tree

1 file changed

+25
-35
lines changed

1 file changed

+25
-35
lines changed

stringdtype/stringdtype/src/umath.c

Lines changed: 25 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ init_equal_ufunc(PyObject *numpy)
135135
}
136136

137137
/*
138-
* Initialize spec for equality
138+
* Initialize spec for equality
139139
*/
140140
PyArray_DTypeMeta *eq_dtypes[3] = {&StringDType, &StringDType,
141141
&PyArray_BoolDType};
@@ -161,50 +161,40 @@ init_equal_ufunc(PyObject *numpy)
161161
}
162162

163163
/*
164-
* This might interfere with NumPy at this time.
164+
* Add promoter to ufunc, ensures operations that mix StringDType and
165+
* UnicodeDType cast the unicode argument to string.
165166
*/
166-
PyObject *promoter_capsule1 = PyCapsule_New(
167-
(void *)&default_ufunc_promoter, "numpy._ufunc_promoter", NULL);
168-
if (promoter_capsule1 == NULL) {
169-
return -1;
170-
}
171167

172-
PyObject *DTypes1 = PyTuple_Pack(3, &StringDType, &PyArray_UnicodeDType,
173-
&PyArrayDescr_Type);
174-
if (DTypes1 == 0) {
175-
Py_DECREF(promoter_capsule1);
176-
return -1;
177-
}
168+
PyObject *DTypes[] = {
169+
PyTuple_Pack(3, &StringDType, &PyArray_UnicodeDType,
170+
&PyArray_BoolDType),
171+
PyTuple_Pack(3, &PyArray_UnicodeDType, &StringDType,
172+
&PyArray_BoolDType),
173+
};
178174

179-
if (PyUFunc_AddPromoter(equal, DTypes1, promoter_capsule1) < 0) {
180-
Py_DECREF(promoter_capsule1);
181-
Py_DECREF(DTypes1);
175+
if ((DTypes[0] == NULL) || (DTypes[1] == NULL)) {
176+
Py_DECREF(equal);
182177
return -1;
183178
}
184-
Py_DECREF(promoter_capsule1);
185-
Py_DECREF(DTypes1);
186179

187-
PyObject *promoter_capsule2 = PyCapsule_New(
188-
(void *)&default_ufunc_promoter, "numpy._ufunc_promoter", NULL);
189-
if (promoter_capsule2 == NULL) {
190-
return -1;
191-
}
192-
PyObject *DTypes2 = PyTuple_Pack(3, &PyArray_UnicodeDType, &StringDType,
193-
&PyArrayDescr_Type);
194-
if (DTypes2 == 0) {
195-
Py_DECREF(promoter_capsule2);
196-
return -1;
197-
}
180+
PyObject *promoter_capsule = PyCapsule_New((void *)&default_ufunc_promoter,
181+
"numpy._ufunc_promoter", NULL);
198182

199-
if (PyUFunc_AddPromoter(equal, DTypes2, promoter_capsule2) < 0) {
200-
Py_DECREF(promoter_capsule2);
201-
Py_DECREF(DTypes2);
202-
return -1;
183+
for (int i = 0; i < 2; i++) {
184+
if (PyUFunc_AddPromoter(equal, DTypes[i], promoter_capsule) < 0) {
185+
Py_DECREF(promoter_capsule);
186+
Py_DECREF(DTypes[0]);
187+
Py_DECREF(DTypes[1]);
188+
Py_DECREF(equal);
189+
return -1;
190+
}
203191
}
204-
Py_DECREF(promoter_capsule2);
205-
Py_DECREF(DTypes2);
206192

193+
Py_DECREF(promoter_capsule);
194+
Py_DECREF(DTypes[0]);
195+
Py_DECREF(DTypes[1]);
207196
Py_DECREF(equal);
197+
208198
return 0;
209199
}
210200

0 commit comments

Comments
 (0)