Skip to content

Commit e808a0f

Browse files
authored
Merge pull request #27 from ngoldbaum/add-promoter
Add ufunc promoter for equal
2 parents 8ffd81c + cb5d0d3 commit e808a0f

File tree

3 files changed

+133
-23
lines changed

3 files changed

+133
-23
lines changed

stringdtype/stringdtype/src/dtype.c

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -34,13 +34,23 @@ common_instance(StringDTypeObject *dtype1, StringDTypeObject *dtype2)
3434
return dtype1;
3535
}
3636

37+
/*
38+
* Used to determine the correct "common" dtype for promotion.
39+
* cls is always StringDType, other is an arbitrary other DType
40+
*/
3741
static PyArray_DTypeMeta *
3842
common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
3943
{
40-
// for now always raise an error here until we can figure out
41-
// how to deal with strings here
42-
PyErr_SetString(PyExc_RuntimeError, "common_dtype called in StringDType");
43-
return NULL;
44+
if (other->type_num == NPY_UNICODE) {
45+
/*
46+
* We have a cast from unicode, so allow unicode to promote
47+
* to StringDType
48+
*/
49+
Py_INCREF(cls);
50+
return cls;
51+
}
52+
Py_INCREF(Py_NotImplemented);
53+
return (PyArray_DTypeMeta *)Py_NotImplemented;
4454
}
4555

4656
// For a given python object, this function returns a borrowed reference

stringdtype/stringdtype/src/umath.c

Lines changed: 111 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,71 @@ string_equal_resolve_descriptors(PyObject *NPY_UNUSED(self),
6060
return NPY_SAFE_CASTING;
6161
}
6262

63-
static char *equal_name = "string_equal";
63+
/*
64+
* Copied from NumPy, because NumPy doesn't always use it :)
65+
*/
66+
static int
67+
default_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
68+
PyArray_DTypeMeta *signature[],
69+
PyArray_DTypeMeta *new_op_dtypes[])
70+
{
71+
/* If nin < 2 promotion is a no-op, so it should not be registered */
72+
assert(ufunc->nin > 1);
73+
if (op_dtypes[0] == NULL) {
74+
assert(ufunc->nin == 2 && ufunc->nout == 1); /* must be reduction */
75+
Py_INCREF(op_dtypes[1]);
76+
new_op_dtypes[0] = op_dtypes[1];
77+
Py_INCREF(op_dtypes[1]);
78+
new_op_dtypes[1] = op_dtypes[1];
79+
Py_INCREF(op_dtypes[1]);
80+
new_op_dtypes[2] = op_dtypes[1];
81+
return 0;
82+
}
83+
PyArray_DTypeMeta *common = NULL;
84+
/*
85+
* If a signature is used and homogeneous in its outputs use that
86+
* (Could/should likely be rather applied to inputs also, although outs
87+
* only could have some advantage and input dtypes are rarely enforced.)
88+
*/
89+
for (int i = ufunc->nin; i < ufunc->nargs; i++) {
90+
if (signature[i] != NULL) {
91+
if (common == NULL) {
92+
Py_INCREF(signature[i]);
93+
common = signature[i];
94+
}
95+
else if (common != signature[i]) {
96+
Py_CLEAR(common); /* Not homogeneous, unset common */
97+
break;
98+
}
99+
}
100+
}
101+
/* Otherwise, use the common DType of all input operands */
102+
if (common == NULL) {
103+
common = PyArray_PromoteDTypeSequence(ufunc->nin, op_dtypes);
104+
if (common == NULL) {
105+
if (PyErr_ExceptionMatches(PyExc_TypeError)) {
106+
PyErr_Clear(); /* Do not propagate normal promotion errors */
107+
}
108+
return -1;
109+
}
110+
}
111+
112+
for (int i = 0; i < ufunc->nargs; i++) {
113+
PyArray_DTypeMeta *tmp = common;
114+
if (signature[i]) {
115+
tmp = signature[i]; /* never replace a fixed one. */
116+
}
117+
Py_INCREF(tmp);
118+
new_op_dtypes[i] = tmp;
119+
}
120+
for (int i = ufunc->nin; i < ufunc->nargs; i++) {
121+
Py_XINCREF(op_dtypes[i]);
122+
new_op_dtypes[i] = op_dtypes[i];
123+
}
124+
125+
Py_DECREF(common);
126+
return 0;
127+
}
64128

65129
int
66130
init_equal_ufunc(PyObject *numpy)
@@ -71,38 +135,66 @@ init_equal_ufunc(PyObject *numpy)
71135
}
72136

73137
/*
74-
* Initialize spec for equality
138+
* Initialize spec for equality
75139
*/
76-
PyArray_DTypeMeta **eq_dtypes = malloc(3 * sizeof(PyArray_DTypeMeta *));
77-
eq_dtypes[0] = &StringDType;
78-
eq_dtypes[1] = &StringDType;
79-
eq_dtypes[2] = &PyArray_BoolDType;
140+
PyArray_DTypeMeta *eq_dtypes[3] = {&StringDType, &StringDType,
141+
&PyArray_BoolDType};
80142

81143
static PyType_Slot eq_slots[] = {
82144
{NPY_METH_resolve_descriptors, &string_equal_resolve_descriptors},
83145
{NPY_METH_strided_loop, &string_equal_strided_loop},
84146
{0, NULL}};
85147

86-
PyArrayMethod_Spec *EqualSpec = malloc(sizeof(PyArrayMethod_Spec));
148+
PyArrayMethod_Spec EqualSpec = {
149+
.name = "string_equal",
150+
.nin = 2,
151+
.nout = 1,
152+
.casting = NPY_NO_CASTING,
153+
.flags = 0,
154+
.dtypes = eq_dtypes,
155+
.slots = eq_slots,
156+
};
157+
158+
if (PyUFunc_AddLoopFromSpec(equal, &EqualSpec) < 0) {
159+
Py_DECREF(equal);
160+
return -1;
161+
}
87162

88-
EqualSpec->name = equal_name;
89-
EqualSpec->nin = 2;
90-
EqualSpec->nout = 1;
91-
EqualSpec->casting = NPY_SAFE_CASTING;
92-
EqualSpec->flags = 0;
93-
EqualSpec->dtypes = eq_dtypes;
94-
EqualSpec->slots = eq_slots;
163+
/*
164+
* Add promoter to ufunc, ensures operations that mix StringDType and
165+
* UnicodeDType cast the unicode argument to string.
166+
*/
95167

96-
if (PyUFunc_AddLoopFromSpec(equal, EqualSpec) < 0) {
168+
PyObject *DTypes[] = {
169+
PyTuple_Pack(3, &StringDType, &PyArray_UnicodeDType,
170+
&PyArray_BoolDType),
171+
PyTuple_Pack(3, &PyArray_UnicodeDType, &StringDType,
172+
&PyArray_BoolDType),
173+
};
174+
175+
if ((DTypes[0] == NULL) || (DTypes[1] == NULL)) {
97176
Py_DECREF(equal);
98-
free(eq_dtypes);
99-
free(EqualSpec);
100177
return -1;
101178
}
102179

180+
PyObject *promoter_capsule = PyCapsule_New((void *)&default_ufunc_promoter,
181+
"numpy._ufunc_promoter", NULL);
182+
183+
for (int i = 0; i < 2; i++) {
184+
if (PyUFunc_AddPromoter(equal, DTypes[i], promoter_capsule) < 0) {
185+
Py_DECREF(promoter_capsule);
186+
Py_DECREF(DTypes[0]);
187+
Py_DECREF(DTypes[1]);
188+
Py_DECREF(equal);
189+
return -1;
190+
}
191+
}
192+
193+
Py_DECREF(promoter_capsule);
194+
Py_DECREF(DTypes[0]);
195+
Py_DECREF(DTypes[1]);
103196
Py_DECREF(equal);
104-
free(eq_dtypes);
105-
free(EqualSpec);
197+
106198
return 0;
107199
}
108200

stringdtype/tests/test_stringdtype.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -96,3 +96,11 @@ def test_insert_scalar(string_list):
9696
arr = np.array(string_list, dtype=dtype)
9797
arr[1] = StringScalar("what", dtype=dtype)
9898
assert repr(arr) == repr(np.array(["abc", "what", "ghi"], dtype=dtype))
99+
100+
101+
def test_equality_promotion(string_list):
102+
sarr = np.array(string_list, dtype=StringDType())
103+
uarr = np.array(string_list, dtype=np.str_)
104+
105+
np.testing.assert_array_equal(sarr, uarr)
106+
np.testing.assert_array_equal(uarr, sarr)

0 commit comments

Comments
 (0)