Skip to content

Commit 1ce658b

Browse files
committed
WIP: adding promoter for equals and not_equals ufuncs
1 parent dc0d884 commit 1ce658b

File tree

2 files changed

+147
-54
lines changed

2 files changed

+147
-54
lines changed

stringdtype/stringdtype/src/umath.c

Lines changed: 140 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,47 @@ string_equal_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context),
164164
return 0;
165165
}
166166

167+
static int
168+
string_not_equal_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context),
169+
char *const data[], npy_intp const dimensions[],
170+
npy_intp const strides[],
171+
NpyAuxData *NPY_UNUSED(auxdata))
172+
{
173+
npy_intp N = dimensions[0];
174+
char *in1 = data[0];
175+
char *in2 = data[1];
176+
npy_bool *out = (npy_bool *)data[2];
177+
npy_intp in1_stride = strides[0];
178+
npy_intp in2_stride = strides[1];
179+
npy_intp out_stride = strides[2];
180+
181+
ss *s1 = NULL, *s2 = NULL;
182+
183+
while (N--) {
184+
s1 = (ss *)in1;
185+
s2 = (ss *)in2;
186+
if (ss_isnull(s1) || ss_isnull(s2)) {
187+
// s1 or s2 is NA
188+
*out = (npy_bool)0;
189+
}
190+
else if (s1->len == s2->len &&
191+
strncmp(s1->buf, s2->buf, s1->len) == 0) {
192+
*out = (npy_bool)0;
193+
}
194+
else {
195+
*out = (npy_bool)1;
196+
}
197+
198+
in1 += in1_stride;
199+
in2 += in2_stride;
200+
out += out_stride;
201+
}
202+
203+
return 0;
204+
}
205+
167206
static NPY_CASTING
168-
string_equal_resolve_descriptors(
207+
string_comparison_resolve_descriptors(
169208
struct PyArrayMethodObject_tag *NPY_UNUSED(method),
170209
PyArray_DTypeMeta *NPY_UNUSED(dtypes[]), PyArray_Descr *given_descrs[],
171210
PyArray_Descr *loop_descrs[], npy_intp *NPY_UNUSED(view_offset))
@@ -227,9 +266,10 @@ string_isnan_resolve_descriptors(
227266
* Copied from NumPy, because NumPy doesn't always use it :)
228267
*/
229268
static int
230-
default_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
231-
PyArray_DTypeMeta *signature[],
232-
PyArray_DTypeMeta *new_op_dtypes[])
269+
ufunc_promoter_internal(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
270+
PyArray_DTypeMeta *signature[],
271+
PyArray_DTypeMeta *new_op_dtypes[],
272+
PyArray_DTypeMeta *final_dtype)
233273
{
234274
/* If nin < 2 promotion is a no-op, so it should not be registered */
235275
assert(ufunc->nin > 1);
@@ -261,19 +301,11 @@ default_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
261301
}
262302
}
263303
}
264-
/* Otherwise, use the common DType of all input operands */
265-
if (common == NULL) {
266-
common = PyArray_PromoteDTypeSequence(ufunc->nin, op_dtypes);
267-
if (common == NULL) {
268-
if (PyErr_ExceptionMatches(PyExc_TypeError)) {
269-
PyErr_Clear(); /* Do not propagate normal promotion errors */
270-
}
271-
return -1;
272-
}
273-
}
304+
Py_XDECREF(common);
274305

306+
/* Otherwise, set all input operands to StringDType */
275307
for (int i = 0; i < ufunc->nargs; i++) {
276-
PyArray_DTypeMeta *tmp = common;
308+
PyArray_DTypeMeta *tmp = final_dtype;
277309
if (signature[i]) {
278310
tmp = signature[i]; /* never replace a fixed one. */
279311
}
@@ -285,10 +317,27 @@ default_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
285317
new_op_dtypes[i] = op_dtypes[i];
286318
}
287319

288-
Py_DECREF(common);
289320
return 0;
290321
}
291322

323+
static int
324+
string_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
325+
PyArray_DTypeMeta *signature[],
326+
PyArray_DTypeMeta *new_op_dtypes[])
327+
{
328+
return ufunc_promoter_internal(ufunc, op_dtypes, signature, new_op_dtypes,
329+
(PyArray_DTypeMeta *)&StringDType);
330+
}
331+
332+
static int
333+
pandas_string_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
334+
PyArray_DTypeMeta *signature[],
335+
PyArray_DTypeMeta *new_op_dtypes[])
336+
{
337+
return ufunc_promoter_internal(ufunc, op_dtypes, signature, new_op_dtypes,
338+
(PyArray_DTypeMeta *)&PandasStringDType);
339+
}
340+
292341
// Register a ufunc.
293342
//
294343
// Pass NULL for resolve_func to use the default_resolve_descriptors.
@@ -334,23 +383,33 @@ init_ufunc(PyObject *numpy, const char *ufunc_name, PyArray_DTypeMeta **dtypes,
334383

335384
int
336385
add_promoter(PyObject *numpy, const char *ufunc_name,
337-
PyArray_DTypeMeta **dtypes)
386+
PyArray_DTypeMeta *ldtype, PyArray_DTypeMeta *rdtype,
387+
PyArray_DTypeMeta *edtype, int is_pandas)
338388
{
339389
PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
340390

341391
if (ufunc == NULL) {
342392
return -1;
343393
}
344394

345-
PyObject *DType_tuple = PyTuple_Pack(3, dtypes[0], dtypes[1], dtypes[2]);
395+
PyObject *DType_tuple = PyTuple_Pack(3, ldtype, rdtype, edtype);
346396

347397
if (DType_tuple == NULL) {
348398
Py_DECREF(ufunc);
349399
return -1;
350400
}
351401

352-
PyObject *promoter_capsule = PyCapsule_New((void *)&default_ufunc_promoter,
353-
"numpy._ufunc_promoter", NULL);
402+
PyObject *promoter_capsule = NULL;
403+
404+
if (is_pandas == 0) {
405+
promoter_capsule = PyCapsule_New((void *)&string_ufunc_promoter,
406+
"numpy._ufunc_promoter", NULL);
407+
}
408+
else {
409+
promoter_capsule = PyCapsule_New((void *)&pandas_string_ufunc_promoter,
410+
"numpy._ufunc_promoter", NULL);
411+
}
412+
354413

355414
if (promoter_capsule == NULL) {
356415
Py_DECREF(ufunc);
@@ -380,30 +439,46 @@ init_ufuncs(void)
380439
return -1;
381440
}
382441

383-
PyArray_DTypeMeta *eq_dtypes[] = {(PyArray_DTypeMeta *)&StringDType,
384-
(PyArray_DTypeMeta *)&StringDType,
385-
&PyArray_BoolDType};
442+
PyArray_DTypeMeta *comparison_dtypes[] = {(PyArray_DTypeMeta *)&StringDType,
443+
(PyArray_DTypeMeta *)&StringDType,
444+
&PyArray_BoolDType};
386445

387-
if (init_ufunc(numpy, "equal", eq_dtypes,
388-
&string_equal_resolve_descriptors,
446+
if (init_ufunc(numpy, "equal", comparison_dtypes,
447+
&string_comparison_resolve_descriptors,
389448
&string_equal_strided_loop, "string_equal", 2, 1,
390449
NPY_NO_CASTING, 0) < 0) {
391450
goto error;
392451
}
393452

394-
PyArray_DTypeMeta *promoter_dtypes[2][3] = {
395-
{(PyArray_DTypeMeta *)&StringDType, &PyArray_UnicodeDType,
396-
&PyArray_BoolDType},
397-
{&PyArray_UnicodeDType, (PyArray_DTypeMeta *)&StringDType,
398-
&PyArray_BoolDType},
399-
};
400-
401-
if (add_promoter(numpy, "equal", promoter_dtypes[0]) < 0) {
453+
if (init_ufunc(numpy, "not_equal", comparison_dtypes,
454+
&string_comparison_resolve_descriptors,
455+
&string_not_equal_strided_loop, "string_not_equal", 2, 1,
456+
NPY_NO_CASTING, 0) < 0) {
402457
goto error;
403458
}
404459

405-
if (add_promoter(numpy, "equal", promoter_dtypes[1]) < 0) {
406-
goto error;
460+
char *ufunc_names[2] = {"equal", "not_equal"};
461+
462+
for (int i = 0; i < 2; i++) {
463+
if (add_promoter(numpy, ufunc_names[i], (PyArray_DTypeMeta *)&StringDType,
464+
&PyArray_UnicodeDType, &PyArray_BoolDType, 0) < 0) {
465+
goto error;
466+
}
467+
468+
if (add_promoter(numpy, ufunc_names[i], &PyArray_UnicodeDType,
469+
(PyArray_DTypeMeta *)&StringDType, &PyArray_BoolDType, 0) < 0) {
470+
goto error;
471+
}
472+
473+
if (add_promoter(numpy, ufunc_names[i], &PyArray_ObjectDType,
474+
(PyArray_DTypeMeta *)&StringDType, &PyArray_BoolDType, 0) < 0) {
475+
goto error;
476+
}
477+
478+
if (add_promoter(numpy, ufunc_names[i], (PyArray_DTypeMeta *)&StringDType,
479+
&PyArray_ObjectDType, &PyArray_BoolDType, 0) < 0) {
480+
goto error;
481+
}
407482
}
408483

409484
PyArray_DTypeMeta *isnan_dtypes[] = {(PyArray_DTypeMeta *)&StringDType,
@@ -448,30 +523,45 @@ init_ufuncs(void)
448523
goto finish;
449524
}
450525

451-
PyArray_DTypeMeta *peq_dtypes[] = {(PyArray_DTypeMeta *)&PandasStringDType,
452-
(PyArray_DTypeMeta *)&PandasStringDType,
453-
&PyArray_BoolDType};
526+
PyArray_DTypeMeta *p_comparison_dtypes[] =
527+
{(PyArray_DTypeMeta *)&PandasStringDType,
528+
(PyArray_DTypeMeta *)&PandasStringDType,
529+
&PyArray_BoolDType};
454530

455-
if (init_ufunc(numpy, "equal", peq_dtypes,
456-
&string_equal_resolve_descriptors,
531+
if (init_ufunc(numpy, "equal", p_comparison_dtypes,
532+
&string_comparison_resolve_descriptors,
457533
&string_equal_strided_loop, "string_equal", 2, 1,
458534
NPY_NO_CASTING, 0) < 0) {
459535
goto error;
460536
}
461537

462-
PyArray_DTypeMeta *p_promoter_dtypes[2][3] = {
463-
{(PyArray_DTypeMeta *)&PandasStringDType, &PyArray_UnicodeDType,
464-
&PyArray_BoolDType},
465-
{&PyArray_UnicodeDType, (PyArray_DTypeMeta *)&PandasStringDType,
466-
&PyArray_BoolDType},
467-
};
468-
469-
if (add_promoter(numpy, "equal", p_promoter_dtypes[0]) < 0) {
538+
if (init_ufunc(numpy, "not_equal", p_comparison_dtypes,
539+
&string_comparison_resolve_descriptors,
540+
&string_not_equal_strided_loop, "string_not_equal", 2, 1,
541+
NPY_NO_CASTING, 0) < 0) {
470542
goto error;
471543
}
472544

473-
if (add_promoter(numpy, "equal", p_promoter_dtypes[1]) < 0) {
474-
goto error;
545+
for (int i = 0; i < 2; i++) {
546+
if (add_promoter(numpy, ufunc_names[i], (PyArray_DTypeMeta *)&PandasStringDType,
547+
&PyArray_UnicodeDType, &PyArray_BoolDType, 1) < 0) {
548+
goto error;
549+
}
550+
551+
if (add_promoter(numpy, ufunc_names[i], &PyArray_UnicodeDType,
552+
(PyArray_DTypeMeta *)&PandasStringDType, &PyArray_BoolDType, 1) < 0) {
553+
goto error;
554+
}
555+
556+
if (add_promoter(numpy, ufunc_names[i], &PyArray_ObjectDType,
557+
(PyArray_DTypeMeta *)&PandasStringDType, &PyArray_BoolDType, 1) < 0) {
558+
goto error;
559+
}
560+
561+
if (add_promoter(numpy, ufunc_names[i], (PyArray_DTypeMeta *)&PandasStringDType,
562+
&PyArray_ObjectDType, &PyArray_BoolDType, 1) < 0) {
563+
goto error;
564+
}
475565
}
476566

477567
PyArray_DTypeMeta *p_isnan_dtypes[] = {

stringdtype/tests/test_stringdtype.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,12 +142,15 @@ def test_insert_scalar(dtype, scalar, string_list):
142142
)
143143

144144

145-
def test_equality_promotion(dtype, string_list):
145+
@pytest.mark.parametrize("o_dtype", [np.str_, object])
146+
def test_equality_promotion(string_list, dtype, o_dtype):
146147
sarr = np.array(string_list, dtype=dtype)
147-
uarr = np.array(string_list, dtype=np.str_)
148+
oarr = np.array(string_list, dtype=o_dtype)
148149

149-
np.testing.assert_array_equal(sarr, uarr)
150-
np.testing.assert_array_equal(uarr, sarr)
150+
np.testing.assert_array_equal(sarr, oarr)
151+
np.testing.assert_array_equal(oarr, sarr)
152+
assert not np.any(sarr != oarr)
153+
assert not np.any(oarr != sarr)
151154

152155

153156
def test_isnan(dtype, string_list):

0 commit comments

Comments
 (0)