Skip to content

Commit 2646d7f

Browse files
committed
Add ufunc promoter for equal
1 parent 8ffd81c commit 2646d7f

File tree

3 files changed

+143
-23
lines changed

3 files changed

+143
-23
lines changed

stringdtype/stringdtype/src/dtype.c

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,23 @@ common_instance(StringDTypeObject *dtype1, StringDTypeObject *dtype2)
3434
return dtype1;
3535
}
3636

37+
/*
38+
* Used to determine the correct "common" dtype for promotion.
39+
* cls is always StringDType, other is an arbitrary other DType
40+
*/
3741
static PyArray_DTypeMeta *
3842
common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
3943
{
40-
// for now always raise an error here until we can figure out
41-
// how to deal with strings here
42-
PyErr_SetString(PyExc_RuntimeError, "common_dtype called in StringDType");
43-
return NULL;
44+
if (other->type_num == NPY_UNICODE) {
45+
/*
46+
* We have a cast from unicode, so allow unicode to promote
47+
* to StringDType
48+
*/
49+
Py_INCREF(cls);
50+
return cls;
51+
}
52+
Py_INCREF(Py_NotImplemented);
53+
return (PyArray_DTypeMeta *)Py_NotImplemented;
4454
}
4555

4656
// For a given python object, this function returns a borrowed reference

stringdtype/stringdtype/src/umath.c

Lines changed: 121 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,71 @@ string_equal_resolve_descriptors(PyObject *NPY_UNUSED(self),
6060
return NPY_SAFE_CASTING;
6161
}
6262

63-
static char *equal_name = "string_equal";
63+
/*
64+
* Copied from NumPy, because NumPy doesn't always use it :)
65+
*/
66+
static int
67+
default_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
68+
PyArray_DTypeMeta *signature[],
69+
PyArray_DTypeMeta *new_op_dtypes[])
70+
{
71+
/* If nin < 2 promotion is a no-op, so it should not be registered */
72+
assert(ufunc->nin > 1);
73+
if (op_dtypes[0] == NULL) {
74+
assert(ufunc->nin == 2 && ufunc->nout == 1); /* must be reduction */
75+
Py_INCREF(op_dtypes[1]);
76+
new_op_dtypes[0] = op_dtypes[1];
77+
Py_INCREF(op_dtypes[1]);
78+
new_op_dtypes[1] = op_dtypes[1];
79+
Py_INCREF(op_dtypes[1]);
80+
new_op_dtypes[2] = op_dtypes[1];
81+
return 0;
82+
}
83+
PyArray_DTypeMeta *common = NULL;
84+
/*
85+
* If a signature is used and homogeneous in its outputs use that
86+
* (Could/should likely be rather applied to inputs also, although outs
87+
* only could have some advantage and input dtypes are rarely enforced.)
88+
*/
89+
for (int i = ufunc->nin; i < ufunc->nargs; i++) {
90+
if (signature[i] != NULL) {
91+
if (common == NULL) {
92+
Py_INCREF(signature[i]);
93+
common = signature[i];
94+
}
95+
else if (common != signature[i]) {
96+
Py_CLEAR(common); /* Not homogeneous, unset common */
97+
break;
98+
}
99+
}
100+
}
101+
/* Otherwise, use the common DType of all input operands */
102+
if (common == NULL) {
103+
common = PyArray_PromoteDTypeSequence(ufunc->nin, op_dtypes);
104+
if (common == NULL) {
105+
if (PyErr_ExceptionMatches(PyExc_TypeError)) {
106+
PyErr_Clear(); /* Do not propagate normal promotion errors */
107+
}
108+
return -1;
109+
}
110+
}
111+
112+
for (int i = 0; i < ufunc->nargs; i++) {
113+
PyArray_DTypeMeta *tmp = common;
114+
if (signature[i]) {
115+
tmp = signature[i]; /* never replace a fixed one. */
116+
}
117+
Py_INCREF(tmp);
118+
new_op_dtypes[i] = tmp;
119+
}
120+
for (int i = ufunc->nin; i < ufunc->nargs; i++) {
121+
Py_XINCREF(op_dtypes[i]);
122+
new_op_dtypes[i] = op_dtypes[i];
123+
}
124+
125+
Py_DECREF(common);
126+
return 0;
127+
}
64128

65129
int
66130
init_equal_ufunc(PyObject *numpy)
@@ -73,36 +137,74 @@ init_equal_ufunc(PyObject *numpy)
73137
/*
74138
* Initialize spec for equality
75139
*/
76-
PyArray_DTypeMeta **eq_dtypes = malloc(3 * sizeof(PyArray_DTypeMeta *));
77-
eq_dtypes[0] = &StringDType;
78-
eq_dtypes[1] = &StringDType;
79-
eq_dtypes[2] = &PyArray_BoolDType;
140+
PyArray_DTypeMeta *eq_dtypes[3] = {&StringDType, &StringDType,
141+
&PyArray_BoolDType};
80142

81143
static PyType_Slot eq_slots[] = {
82144
{NPY_METH_resolve_descriptors, &string_equal_resolve_descriptors},
83145
{NPY_METH_strided_loop, &string_equal_strided_loop},
84146
{0, NULL}};
85147

86-
PyArrayMethod_Spec *EqualSpec = malloc(sizeof(PyArrayMethod_Spec));
148+
PyArrayMethod_Spec EqualSpec = {
149+
.name = "string_equal",
150+
.nin = 2,
151+
.nout = 1,
152+
.casting = NPY_NO_CASTING,
153+
.flags = 0,
154+
.dtypes = eq_dtypes,
155+
.slots = eq_slots,
156+
};
157+
158+
if (PyUFunc_AddLoopFromSpec(equal, &EqualSpec) < 0) {
159+
Py_DECREF(equal);
160+
return -1;
161+
}
87162

88-
EqualSpec->name = equal_name;
89-
EqualSpec->nin = 2;
90-
EqualSpec->nout = 1;
91-
EqualSpec->casting = NPY_SAFE_CASTING;
92-
EqualSpec->flags = 0;
93-
EqualSpec->dtypes = eq_dtypes;
94-
EqualSpec->slots = eq_slots;
163+
/*
164+
* This might interfere with NumPy at this time.
165+
*/
166+
PyObject *promoter_capsule1 = PyCapsule_New(
167+
(void *)&default_ufunc_promoter, "numpy._ufunc_promoter", NULL);
168+
if (promoter_capsule1 == NULL) {
169+
return -1;
170+
}
95171

96-
if (PyUFunc_AddLoopFromSpec(equal, EqualSpec) < 0) {
97-
Py_DECREF(equal);
98-
free(eq_dtypes);
99-
free(EqualSpec);
172+
PyObject *DTypes1 = PyTuple_Pack(3, &StringDType, &PyArray_UnicodeDType,
173+
&PyArrayDescr_Type);
174+
if (DTypes1 == 0) {
175+
Py_DECREF(promoter_capsule1);
176+
return -1;
177+
}
178+
179+
if (PyUFunc_AddPromoter(equal, DTypes1, promoter_capsule1) < 0) {
180+
Py_DECREF(promoter_capsule1);
181+
Py_DECREF(DTypes1);
182+
return -1;
183+
}
184+
Py_DECREF(promoter_capsule1);
185+
Py_DECREF(DTypes1);
186+
187+
PyObject *promoter_capsule2 = PyCapsule_New(
188+
(void *)&default_ufunc_promoter, "numpy._ufunc_promoter", NULL);
189+
if (promoter_capsule2 == NULL) {
190+
return -1;
191+
}
192+
PyObject *DTypes2 = PyTuple_Pack(3, &PyArray_UnicodeDType, &StringDType,
193+
&PyArrayDescr_Type);
194+
if (DTypes2 == 0) {
195+
Py_DECREF(promoter_capsule2);
196+
return -1;
197+
}
198+
199+
if (PyUFunc_AddPromoter(equal, DTypes2, promoter_capsule2) < 0) {
200+
Py_DECREF(promoter_capsule2);
201+
Py_DECREF(DTypes2);
100202
return -1;
101203
}
204+
Py_DECREF(promoter_capsule2);
205+
Py_DECREF(DTypes2);
102206

103207
Py_DECREF(equal);
104-
free(eq_dtypes);
105-
free(EqualSpec);
106208
return 0;
107209
}
108210

stringdtype/tests/test_stringdtype.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,11 @@ def test_insert_scalar(string_list):
9696
arr = np.array(string_list, dtype=dtype)
9797
arr[1] = StringScalar("what", dtype=dtype)
9898
assert repr(arr) == repr(np.array(["abc", "what", "ghi"], dtype=dtype))
99+
100+
def test_equality_promotion(string_list):
    """A StringDType array and a NumPy unicode array built from the same
    strings compare equal, whichever operand comes first."""
    string_arr = np.array(string_list, dtype=StringDType())
    unicode_arr = np.array(string_list, dtype=np.str_)

    # Same comparison in both operand orders: (string, unicode) first,
    # then (unicode, string).
    for left, right in ((string_arr, unicode_arr), (unicode_arr, string_arr)):
        np.testing.assert_array_equal(left, right)

0 commit comments

Comments
 (0)