Skip to content

Commit 9817861

Browse files
committed
ENH: introduce 'compatible' stringdtype instances
1 parent 229de79 commit 9817861

File tree

5 files changed

+145
-136
lines changed

5 files changed

+145
-136
lines changed

doc/neps/nep-0055-string_dtype.rst

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -534,11 +534,33 @@ future NumPy or a downstream library may add locale-aware sorting, case folding,
534534
and normalization for NumPy unicode strings arrays, but we are not proposing
535535
adding these features at this time.
536536

537-
Two ``StringDType`` instances are considered identical if they are created with
538-
the same ``na_object`` and ``coerce`` parameter. We propose checking for unequal
539-
``StringDType`` instances in the ``resolve_descriptors`` function of binary
540-
ufuncs that take two string arrays and raising an error if an operation is
541-
performed with unequal ``StringDType`` instances.
537+
Two ``StringDType`` instances are considered equal if they are created with the
538+
same ``na_object`` and ``coerce`` parameter. For ufuncs that accept more than
539+
one string argument we also introduce the concept of "compatible"
540+
``StringDType`` instances. We allow distinct DType instances to be used in ufunc
541+
operations together if have the same ``na_object`` or if only one
542+
or the other DType has an ``na_object`` explicitly set. We do not consider
543+
string coercion for determining whether instances are compatible, although if
544+
the result of the operation is a string, the result will inherit the stricter
545+
string coercion setting of the original operands.
546+
547+
This notion of "compatible" instances will be enforced in the
548+
``resolve_descriptors`` function of binary ufuncs. This choice makes it easier
549+
to work with non-default ``StringDType`` instances, because python strings are
550+
coerced to the default ``StringDType`` instance, so the following idiomatic
551+
expression is allowed::
552+
553+
>>> arr = np.array(["hello", "world"], dtype=StringDType(na_object=None))
554+
>>> arr + "!"
555+
array(['hello!', 'world!'], dtype=StringDType(na_object=None))
556+
557+
If we only considered equality of ``StringDType`` instances, this would
558+
be an error, making for an awkward user experience. If the operands have
559+
distinct ``na_object`` settings, NumPy will raise an error because the choice
560+
for the result DType is ambiguous::
561+
562+
>>> arr + np.array("!", dtype=StringDType(na_object=""))
563+
TypeError: Cannot find common instance for incompatible dtype instances
542564

543565
``np.strings`` namespace
544566
************************

numpy/_core/src/multiarray/stringdtype/dtype.c

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,21 @@ new_stringdtype_instance(PyObject *na_object, int coerce)
144144
return NULL;
145145
}
146146

147-
static int
148-
na_eq_cmp(PyObject *a, PyObject *b) {
147+
NPY_NO_EXPORT int
148+
na_eq_cmp(PyObject *a, PyObject *b, int coerce_nulls) {
149149
if (a == b) {
150150
// catches None and other singletons like Pandas.NA
151151
return 1;
152152
}
153153
if (a == NULL || b == NULL) {
154-
return 0;
154+
if (coerce_nulls) {
155+
// an object with an explictly set NA object is considered
156+
// compatible for binary operations to one with no explicitly set NA
157+
return 1;
158+
}
159+
else {
160+
return 0;
161+
}
155162
}
156163
if (PyFloat_Check(a) && PyFloat_Check(b)) {
157164
// nan check catches np.nan and float('nan')
@@ -182,29 +189,29 @@ _eq_comparison(int scoerce, int ocoerce, PyObject *sna, PyObject *ona)
182189
if (scoerce != ocoerce) {
183190
return 0;
184191
}
185-
return na_eq_cmp(sna, ona);
192+
return na_eq_cmp(sna, ona, 0);
186193
}
187194

188195
/*
189196
* This is used to determine the correct dtype to return when dealing
190197
* with a mix of different dtypes (for example when creating an array
191198
* from a list of scalars).
192199
*/
193-
static PyArray_StringDTypeObject *
200+
NPY_NO_EXPORT PyArray_StringDTypeObject *
194201
common_instance(PyArray_StringDTypeObject *dtype1, PyArray_StringDTypeObject *dtype2)
195202
{
196-
int eq = _eq_comparison(dtype1->coerce, dtype2->coerce, dtype1->na_object,
197-
dtype2->na_object);
203+
int eq = na_eq_cmp(dtype1->na_object, dtype2->na_object, 1);
198204

199205
if (eq <= 0) {
200206
PyErr_SetString(
201-
PyExc_ValueError,
202-
"Cannot find common instance for unequal dtype instances");
207+
PyExc_TypeError,
208+
"Cannot find common instance for incompatible dtype instances");
203209
return NULL;
204210
}
205211

206212
return (PyArray_StringDTypeObject *)new_stringdtype_instance(
207-
dtype1->na_object, dtype1->coerce);
213+
dtype1->na_object != NULL ? dtype1->na_object : dtype2->na_object,
214+
!((dtype1->coerce == 0) || (dtype2->coerce == 0)));
208215
}
209216

210217
/*
@@ -280,7 +287,7 @@ stringdtype_setitem(PyArray_StringDTypeObject *descr, PyObject *obj, char **data
280287
{
281288
npy_packed_static_string *sdata = (npy_packed_static_string *)dataptr;
282289

283-
int is_cmp = 0;
290+
int na_cmp = 0;
284291

285292
// borrow reference
286293
PyObject *na_object = descr->na_object;
@@ -294,16 +301,16 @@ stringdtype_setitem(PyArray_StringDTypeObject *descr, PyObject *obj, char **data
294301
// so we do the comparison before acquiring the allocator.
295302

296303
if (na_object != NULL) {
297-
is_cmp = na_eq_cmp(obj, na_object);
298-
if (is_cmp == -1) {
304+
na_cmp = na_eq_cmp(obj, na_object, 1);
305+
if (na_cmp == -1) {
299306
return -1;
300307
}
301308
}
302309

303310
npy_string_allocator *allocator = NpyString_acquire_allocator(descr);
304311

305312
if (na_object != NULL) {
306-
if (is_cmp) {
313+
if (na_cmp) {
307314
if (NpyString_pack_null(allocator, sdata) < 0) {
308315
PyErr_SetString(PyExc_MemoryError,
309316
"Failed to pack null string during StringDType "

numpy/_core/src/multiarray/stringdtype/dtype.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ stringdtype_finalize_descr(PyArray_Descr *dtype);
4949
NPY_NO_EXPORT int
5050
_eq_comparison(int scoerce, int ocoerce, PyObject *sna, PyObject *ona);
5151

52+
NPY_NO_EXPORT PyArray_StringDTypeObject *
53+
common_instance(PyArray_StringDTypeObject *dtype1, PyArray_StringDTypeObject *dtype2);
54+
5255
#ifdef __cplusplus
5356
}
5457
#endif

numpy/_core/src/umath/stringdtype_ufuncs.cpp

Lines changed: 38 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -246,20 +246,9 @@ binary_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
246246
{
247247
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
248248
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
249+
PyArray_StringDTypeObject *common_descr = common_instance(descr1, descr2);
249250

250-
// _eq_comparison has a short-circuit pointer comparison fast path,
251-
// so no need to check here
252-
int eq_res = _eq_comparison(descr1->coerce, descr2->coerce,
253-
descr1->na_object, descr2->na_object);
254-
255-
if (eq_res < 0) {
256-
return (NPY_CASTING)-1;
257-
}
258-
259-
if (eq_res != 1) {
260-
PyErr_SetString(PyExc_TypeError,
261-
"Can only do binary operations with equal StringDType "
262-
"instances.");
251+
if (common_descr == NULL) {
263252
return (NPY_CASTING)-1;
264253
}
265254

@@ -272,8 +261,7 @@ binary_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
272261

273262
if (given_descrs[2] == NULL) {
274263
out_descr = (PyArray_Descr *)new_stringdtype_instance(
275-
((PyArray_StringDTypeObject *)given_descrs[1])->na_object,
276-
((PyArray_StringDTypeObject *)given_descrs[1])->coerce);
264+
common_descr->na_object, common_descr->coerce);
277265

278266
if (out_descr == NULL) {
279267
return (NPY_CASTING)-1;
@@ -562,6 +550,14 @@ string_comparison_resolve_descriptors(
562550
PyArray_Descr *const given_descrs[],
563551
PyArray_Descr *loop_descrs[], npy_intp *NPY_UNUSED(view_offset))
564552
{
553+
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
554+
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
555+
PyArray_StringDTypeObject *common_descr = common_instance(descr1, descr2);
556+
557+
if (common_descr == NULL) {
558+
return (NPY_CASTING)-1;
559+
}
560+
565561
Py_INCREF(given_descrs[0]);
566562
loop_descrs[0] = given_descrs[0];
567563
Py_INCREF(given_descrs[1]);
@@ -788,20 +784,9 @@ string_findlike_resolve_descriptors(
788784
{
789785
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
790786
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
787+
PyArray_StringDTypeObject *common_descr = common_instance(descr1, descr2);
791788

792-
// _eq_comparison has a short-circuit pointer comparison fast path,
793-
// so no need to check here
794-
int eq_res = _eq_comparison(descr1->coerce, descr2->coerce,
795-
descr1->na_object, descr2->na_object);
796-
797-
if (eq_res < 0) {
798-
return (NPY_CASTING)-1;
799-
}
800-
801-
if (eq_res != 1) {
802-
PyErr_SetString(PyExc_TypeError,
803-
"Can only do binary operations with equal StringDType "
804-
"instances.");
789+
if (common_descr == NULL) {
805790
return (NPY_CASTING)-1;
806791
}
807792

@@ -849,20 +834,9 @@ string_startswith_endswith_resolve_descriptors(
849834
{
850835
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
851836
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
837+
PyArray_StringDTypeObject *common_descr = common_instance(descr1, descr2);
852838

853-
// _eq_comparison has a short-circuit pointer comparison fast path, so
854-
// no need to do it here
855-
int eq_res = _eq_comparison(descr1->coerce, descr2->coerce,
856-
descr1->na_object, descr2->na_object);
857-
858-
if (eq_res < 0) {
859-
return (NPY_CASTING)-1;
860-
}
861-
862-
if (eq_res != 1) {
863-
PyErr_SetString(PyExc_TypeError,
864-
"Can only do binary operations with equal StringDType "
865-
"instances.");
839+
if (common_descr == NULL) {
866840
return (NPY_CASTING)-1;
867841
}
868842

@@ -1061,46 +1035,6 @@ all_strings_promoter(PyObject *NPY_UNUSED(ufunc),
10611035
return 0;
10621036
}
10631037

1064-
static NPY_CASTING
1065-
strip_chars_resolve_descriptors(
1066-
struct PyArrayMethodObject_tag *NPY_UNUSED(method),
1067-
PyArray_DTypeMeta *const NPY_UNUSED(dtypes[]),
1068-
PyArray_Descr *const given_descrs[],
1069-
PyArray_Descr *loop_descrs[],
1070-
npy_intp *NPY_UNUSED(view_offset))
1071-
{
1072-
Py_INCREF(given_descrs[0]);
1073-
loop_descrs[0] = given_descrs[0];
1074-
1075-
// we don't actually care about the null behavior of the second argument,
1076-
// so no need to check if the first two descrs are equal like in
1077-
// binary_resolve_descriptors
1078-
1079-
Py_INCREF(given_descrs[1]);
1080-
loop_descrs[1] = given_descrs[1];
1081-
1082-
PyArray_Descr *out_descr = NULL;
1083-
1084-
if (given_descrs[2] == NULL) {
1085-
out_descr = (PyArray_Descr *)new_stringdtype_instance(
1086-
((PyArray_StringDTypeObject *)given_descrs[0])->na_object,
1087-
((PyArray_StringDTypeObject *)given_descrs[0])->coerce);
1088-
1089-
if (out_descr == NULL) {
1090-
return (NPY_CASTING)-1;
1091-
}
1092-
}
1093-
else {
1094-
Py_INCREF(given_descrs[2]);
1095-
out_descr = given_descrs[2];
1096-
}
1097-
1098-
loop_descrs[2] = out_descr;
1099-
1100-
return NPY_NO_CASTING;
1101-
}
1102-
1103-
11041038
NPY_NO_EXPORT int
11051039
string_lrstrip_chars_strided_loop(
11061040
PyArrayMethod_Context *context, char *const data[],
@@ -1309,21 +1243,10 @@ replace_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
13091243
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
13101244
PyArray_StringDTypeObject *descr3 = (PyArray_StringDTypeObject *)given_descrs[2];
13111245

1312-
// _eq_comparison has a short-circuit pointer comparison fast path, so
1313-
// no need to do it here
1314-
int eq_res = (_eq_comparison(descr1->coerce, descr2->coerce,
1315-
descr1->na_object, descr2->na_object) &&
1316-
_eq_comparison(descr1->coerce, descr3->coerce,
1317-
descr1->na_object, descr3->na_object));
1318-
1319-
if (eq_res < 0) {
1320-
return (NPY_CASTING)-1;
1321-
}
1246+
PyArray_StringDTypeObject *common_descr = common_instance(
1247+
common_instance(descr1, descr2), descr3);
13221248

1323-
if (eq_res != 1) {
1324-
PyErr_SetString(PyExc_TypeError,
1325-
"String replace is only supported with equal StringDType "
1326-
"instances.");
1249+
if (common_descr == NULL) {
13271250
return (NPY_CASTING)-1;
13281251
}
13291252

@@ -1340,8 +1263,7 @@ replace_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
13401263

13411264
if (given_descrs[4] == NULL) {
13421265
out_descr = (PyArray_Descr *)new_stringdtype_instance(
1343-
((PyArray_StringDTypeObject *)given_descrs[0])->na_object,
1344-
((PyArray_StringDTypeObject *)given_descrs[0])->coerce);
1266+
common_descr->na_object, common_descr->coerce);
13451267

13461268
if (out_descr == NULL) {
13471269
return (NPY_CASTING)-1;
@@ -1588,18 +1510,9 @@ center_ljust_rjust_resolve_descriptors(
15881510
{
15891511
PyArray_StringDTypeObject *input_descr = (PyArray_StringDTypeObject *)given_descrs[0];
15901512
PyArray_StringDTypeObject *fill_descr = (PyArray_StringDTypeObject *)given_descrs[2];
1513+
PyArray_StringDTypeObject *common_descr = common_instance(input_descr, fill_descr);
15911514

1592-
int eq_res = _eq_comparison(input_descr->coerce, fill_descr->coerce,
1593-
input_descr->na_object, fill_descr->na_object);
1594-
1595-
if (eq_res < 0) {
1596-
return (NPY_CASTING)-1;
1597-
}
1598-
1599-
if (eq_res != 1) {
1600-
PyErr_SetString(PyExc_TypeError,
1601-
"Can only do text justification operations with equal"
1602-
"StringDType instances.");
1515+
if (common_descr == NULL) {
16031516
return (NPY_CASTING)-1;
16041517
}
16051518

@@ -1614,8 +1527,7 @@ center_ljust_rjust_resolve_descriptors(
16141527

16151528
if (given_descrs[3] == NULL) {
16161529
out_descr = (PyArray_Descr *)new_stringdtype_instance(
1617-
((PyArray_StringDTypeObject *)given_descrs[1])->na_object,
1618-
((PyArray_StringDTypeObject *)given_descrs[1])->coerce);
1530+
common_descr->na_object, common_descr->coerce);
16191531

16201532
if (out_descr == NULL) {
16211533
return (NPY_CASTING)-1;
@@ -1888,6 +1800,7 @@ zfill_strided_loop(PyArrayMethod_Context *context,
18881800
return -1;
18891801
}
18901802

1803+
18911804
static NPY_CASTING
18921805
string_partition_resolve_descriptors(
18931806
PyArrayMethodObject *self,
@@ -1901,14 +1814,23 @@ string_partition_resolve_descriptors(
19011814
"currently support the 'out' keyword", self->name);
19021815
return (NPY_CASTING)-1;
19031816
}
1904-
for (int i=0; i<2; i++) {
1905-
Py_INCREF(given_descrs[i]);
1906-
loop_descrs[i] = given_descrs[i];
1817+
1818+
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
1819+
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
1820+
PyArray_StringDTypeObject *common_descr = common_instance(descr1, descr2);
1821+
1822+
if (common_descr == NULL) {
1823+
return (NPY_CASTING)-1;
19071824
}
1908-
PyArray_StringDTypeObject *adescr = (PyArray_StringDTypeObject *)given_descrs[0];
1825+
1826+
Py_INCREF(given_descrs[0]);
1827+
loop_descrs[0] = given_descrs[0];
1828+
Py_INCREF(given_descrs[1]);
1829+
loop_descrs[1] = given_descrs[1];
1830+
19091831
for (int i=2; i<5; i++) {
19101832
loop_descrs[i] = (PyArray_Descr *)new_stringdtype_instance(
1911-
adescr->na_object, adescr->coerce);
1833+
common_descr->na_object, common_descr->coerce);
19121834
if (loop_descrs[i] == NULL) {
19131835
return (NPY_CASTING)-1;
19141836
}
@@ -2655,7 +2577,7 @@ init_stringdtype_ufuncs(PyObject *umath)
26552577

26562578
for (int i=0; i<3; i++) {
26572579
if (init_ufunc(umath, strip_chars_names[i], strip_chars_dtypes,
2658-
&strip_chars_resolve_descriptors,
2580+
&binary_resolve_descriptors,
26592581
&string_lrstrip_chars_strided_loop,
26602582
2, 1, NPY_NO_CASTING, (NPY_ARRAYMETHOD_FLAGS) 0,
26612583
&strip_types[i]) < 0) {

0 commit comments

Comments
 (0)