Skip to content

Commit 3c09f16

Browse files
authored
Merge pull request numpy#26261 from ngoldbaum/compatible-stringdtype-ufuncs
ENH: introduce a notion of "compatible" stringdtype instances
2 parents 75b5bf1 + 019ae02 commit 3c09f16

File tree

6 files changed

+167
-143
lines changed

6 files changed

+167
-143
lines changed

doc/neps/nep-0055-string_dtype.rst

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -534,11 +534,33 @@ future NumPy or a downstream library may add locale-aware sorting, case folding,
534534
and normalization for NumPy unicode strings arrays, but we are not proposing
535535
adding these features at this time.
536536

537-
Two ``StringDType`` instances are considered identical if they are created with
538-
the same ``na_object`` and ``coerce`` parameter. We propose checking for unequal
539-
``StringDType`` instances in the ``resolve_descriptors`` function of binary
540-
ufuncs that take two string arrays and raising an error if an operation is
541-
performed with unequal ``StringDType`` instances.
537+
Two ``StringDType`` instances are considered equal if they are created with the
538+
same ``na_object`` and ``coerce`` parameter. For ufuncs that accept more than
539+
one string argument we also introduce the concept of "compatible"
540+
``StringDType`` instances. We allow distinct DType instances to be used in ufunc
541+
operations together if they have the same ``na_object`` or if only one
542+
or the other DType has an ``na_object`` explicitly set. We do not consider
543+
string coercion for determining whether instances are compatible, although if
544+
the result of the operation is a string, the result will inherit the stricter
545+
string coercion setting of the original operands.
546+
547+
This notion of "compatible" instances will be enforced in the
548+
``resolve_descriptors`` function of binary ufuncs. This choice makes it easier
549+
to work with non-default ``StringDType`` instances, because python strings are
550+
coerced to the default ``StringDType`` instance, so the following idiomatic
551+
expression is allowed::
552+
553+
>>> arr = np.array(["hello", "world"], dtype=StringDType(na_object=None))
554+
>>> arr + "!"
555+
array(['hello!', 'world!'], dtype=StringDType(na_object=None))
556+
557+
If we only considered equality of ``StringDType`` instances, this would
558+
be an error, making for an awkward user experience. If the operands have
559+
distinct ``na_object`` settings, NumPy will raise an error because the choice
560+
for the result DType is ambiguous::
561+
562+
>>> arr + np.array("!", dtype=StringDType(na_object=""))
563+
TypeError: Cannot find common instance for incompatible dtype instances
542564

543565
``np.strings`` namespace
544566
************************

numpy/_core/src/multiarray/stringdtype/dtype.c

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,30 @@ _eq_comparison(int scoerce, int ocoerce, PyObject *sna, PyObject *ona)
185185
return na_eq_cmp(sna, ona);
186186
}
187187

188+
// Check whether two na_object values can be combined in one operation.
189+
// Returns 0 on success or -1 with the error indicator set. If out_na is
190+
// non-NULL, it receives na1 when na1 is non-NULL and na2 otherwise.
191+
NPY_NO_EXPORT int
192+
stringdtype_compatible_na(PyObject *na1, PyObject *na2, PyObject **out_na) {
193+
if ((na1 != NULL) && (na2 != NULL)) {  // only compare when both are set
194+
int na_eq = na_eq_cmp(na1, na2);
195+
196+
if (na_eq < 0) {
197+
return -1;  // presumably na_eq_cmp already set the error -- confirm
198+
}
199+
else if (na_eq == 0) {
200+
PyErr_Format(PyExc_TypeError,
201+
"Cannot find a compatible null string value for "
202+
"null strings '%R' and '%R'", na1, na2);
203+
return -1;
204+
}
205+
}
206+
if (out_na != NULL) {
207+
*out_na = na1 ? na1 : na2;  // pointer copied as-is; no INCREF here
208+
}
209+
return 0;
210+
}
211+
188212
/*
189213
* This is used to determine the correct dtype to return when dealing
190214
* with a mix of different dtypes (for example when creating an array
@@ -193,18 +217,18 @@ _eq_comparison(int scoerce, int ocoerce, PyObject *sna, PyObject *ona)
193217
static PyArray_StringDTypeObject *
194218
common_instance(PyArray_StringDTypeObject *dtype1, PyArray_StringDTypeObject *dtype2)
195219
{
196-
int eq = _eq_comparison(dtype1->coerce, dtype2->coerce, dtype1->na_object,
197-
dtype2->na_object);
220+
PyObject *out_na_object = NULL;
198221

199-
if (eq <= 0) {
200-
PyErr_SetString(
201-
PyExc_ValueError,
202-
"Cannot find common instance for unequal dtype instances");
222+
if (stringdtype_compatible_na(
223+
dtype1->na_object, dtype2->na_object, &out_na_object) == -1) {
224+
PyErr_Format(PyExc_TypeError,
225+
"Cannot find common instance for incompatible dtypes "
226+
"'%R' and '%R'", (PyObject *)dtype1, (PyObject *)dtype2);
203227
return NULL;
204228
}
205229

206230
return (PyArray_StringDTypeObject *)new_stringdtype_instance(
207-
dtype1->na_object, dtype1->coerce);
231+
out_na_object, dtype1->coerce && dtype2->coerce);  // result coerces only when both operands do
208232
}
209233

210234
/*
@@ -280,30 +304,22 @@ stringdtype_setitem(PyArray_StringDTypeObject *descr, PyObject *obj, char **data
280304
{
281305
npy_packed_static_string *sdata = (npy_packed_static_string *)dataptr;
282306

283-
int is_cmp = 0;
284-
285307
// borrow reference
286308
PyObject *na_object = descr->na_object;
287309

288-
// Note there are two different na_object != NULL checks here.
289-
//
290-
// Do not refactor this!
291-
//
292310
// We need the result of the comparison after acquiring the allocator, but
293311
// cannot use functions requiring the GIL when the allocator is acquired,
294312
// so we do the comparison before acquiring the allocator.
295313

296-
if (na_object != NULL) {
297-
is_cmp = na_eq_cmp(obj, na_object);
298-
if (is_cmp == -1) {
299-
return -1;
300-
}
314+
int na_cmp = na_eq_cmp(obj, na_object);
315+
if (na_cmp == -1) {
316+
return -1;
301317
}
302318

303319
npy_string_allocator *allocator = NpyString_acquire_allocator(descr);
304320

305321
if (na_object != NULL) {
306-
if (is_cmp) {
322+
if (na_cmp) {
307323
if (NpyString_pack_null(allocator, sdata) < 0) {
308324
PyErr_SetString(PyExc_MemoryError,
309325
"Failed to pack null string during StringDType "

numpy/_core/src/multiarray/stringdtype/dtype.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ stringdtype_finalize_descr(PyArray_Descr *dtype);
4949
NPY_NO_EXPORT int
5050
_eq_comparison(int scoerce, int ocoerce, PyObject *sna, PyObject *ona);
5151

52+
NPY_NO_EXPORT int
53+
stringdtype_compatible_na(PyObject *na1, PyObject *na2, PyObject **out_na);
54+
5255
#ifdef __cplusplus
5356
}
5457
#endif

numpy/_core/src/umath/stringdtype_ufuncs.cpp

Lines changed: 44 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -246,20 +246,11 @@ binary_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
246246
{
247247
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
248248
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
249+
int out_coerce = descr1->coerce && descr1->coerce;
250+
PyObject *out_na_object = NULL;
249251

250-
// _eq_comparison has a short-circuit pointer comparison fast path,
251-
// so no need to check here
252-
int eq_res = _eq_comparison(descr1->coerce, descr2->coerce,
253-
descr1->na_object, descr2->na_object);
254-
255-
if (eq_res < 0) {
256-
return (NPY_CASTING)-1;
257-
}
258-
259-
if (eq_res != 1) {
260-
PyErr_SetString(PyExc_TypeError,
261-
"Can only do binary operations with equal StringDType "
262-
"instances.");
252+
if (stringdtype_compatible_na(
253+
descr1->na_object, descr2->na_object, &out_na_object) == -1) {
263254
return (NPY_CASTING)-1;
264255
}
265256

@@ -272,8 +263,7 @@ binary_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
272263

273264
if (given_descrs[2] == NULL) {
274265
out_descr = (PyArray_Descr *)new_stringdtype_instance(
275-
((PyArray_StringDTypeObject *)given_descrs[1])->na_object,
276-
((PyArray_StringDTypeObject *)given_descrs[1])->coerce);
266+
out_na_object, out_coerce);
277267

278268
if (out_descr == NULL) {
279269
return (NPY_CASTING)-1;
@@ -562,6 +552,13 @@ string_comparison_resolve_descriptors(
562552
PyArray_Descr *const given_descrs[],
563553
PyArray_Descr *loop_descrs[], npy_intp *NPY_UNUSED(view_offset))
564554
{
555+
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
556+
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
557+
558+
if (stringdtype_compatible_na(descr1->na_object, descr2->na_object, NULL) == -1) {
559+
return (NPY_CASTING)-1;
560+
}
561+
565562
Py_INCREF(given_descrs[0]);
566563
loop_descrs[0] = given_descrs[0];
567564
Py_INCREF(given_descrs[1]);
@@ -789,19 +786,7 @@ string_findlike_resolve_descriptors(
789786
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
790787
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
791788

792-
// _eq_comparison has a short-circuit pointer comparison fast path,
793-
// so no need to check here
794-
int eq_res = _eq_comparison(descr1->coerce, descr2->coerce,
795-
descr1->na_object, descr2->na_object);
796-
797-
if (eq_res < 0) {
798-
return (NPY_CASTING)-1;
799-
}
800-
801-
if (eq_res != 1) {
802-
PyErr_SetString(PyExc_TypeError,
803-
"Can only do binary operations with equal StringDType "
804-
"instances.");
789+
if (stringdtype_compatible_na(descr1->na_object, descr2->na_object, NULL) == -1) {
805790
return (NPY_CASTING)-1;
806791
}
807792

@@ -850,19 +835,7 @@ string_startswith_endswith_resolve_descriptors(
850835
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
851836
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
852837

853-
// _eq_comparison has a short-circuit pointer comparison fast path, so
854-
// no need to do it here
855-
int eq_res = _eq_comparison(descr1->coerce, descr2->coerce,
856-
descr1->na_object, descr2->na_object);
857-
858-
if (eq_res < 0) {
859-
return (NPY_CASTING)-1;
860-
}
861-
862-
if (eq_res != 1) {
863-
PyErr_SetString(PyExc_TypeError,
864-
"Can only do binary operations with equal StringDType "
865-
"instances.");
838+
if (stringdtype_compatible_na(descr1->na_object, descr2->na_object, NULL) == -1) {
866839
return (NPY_CASTING)-1;
867840
}
868841

@@ -1061,46 +1034,6 @@ all_strings_promoter(PyObject *NPY_UNUSED(ufunc),
10611034
return 0;
10621035
}
10631036

1064-
static NPY_CASTING
1065-
strip_chars_resolve_descriptors(
1066-
struct PyArrayMethodObject_tag *NPY_UNUSED(method),
1067-
PyArray_DTypeMeta *const NPY_UNUSED(dtypes[]),
1068-
PyArray_Descr *const given_descrs[],
1069-
PyArray_Descr *loop_descrs[],
1070-
npy_intp *NPY_UNUSED(view_offset))
1071-
{
1072-
Py_INCREF(given_descrs[0]);
1073-
loop_descrs[0] = given_descrs[0];
1074-
1075-
// we don't actually care about the null behavior of the second argument,
1076-
// so no need to check if the first two descrs are equal like in
1077-
// binary_resolve_descriptors
1078-
1079-
Py_INCREF(given_descrs[1]);
1080-
loop_descrs[1] = given_descrs[1];
1081-
1082-
PyArray_Descr *out_descr = NULL;
1083-
1084-
if (given_descrs[2] == NULL) {
1085-
out_descr = (PyArray_Descr *)new_stringdtype_instance(
1086-
((PyArray_StringDTypeObject *)given_descrs[0])->na_object,
1087-
((PyArray_StringDTypeObject *)given_descrs[0])->coerce);
1088-
1089-
if (out_descr == NULL) {
1090-
return (NPY_CASTING)-1;
1091-
}
1092-
}
1093-
else {
1094-
Py_INCREF(given_descrs[2]);
1095-
out_descr = given_descrs[2];
1096-
}
1097-
1098-
loop_descrs[2] = out_descr;
1099-
1100-
return NPY_NO_CASTING;
1101-
}
1102-
1103-
11041037
NPY_NO_EXPORT int
11051038
string_lrstrip_chars_strided_loop(
11061039
PyArrayMethod_Context *context, char *const data[],
@@ -1308,22 +1241,16 @@ replace_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
13081241
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
13091242
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
13101243
PyArray_StringDTypeObject *descr3 = (PyArray_StringDTypeObject *)given_descrs[2];
1244+
int out_coerce = descr1->coerce && descr2->coerce && descr3->coerce;
1245+
PyObject *out_na_object = NULL;
13111246

1312-
// _eq_comparison has a short-circuit pointer comparison fast path, so
1313-
// no need to do it here
1314-
int eq_res = (_eq_comparison(descr1->coerce, descr2->coerce,
1315-
descr1->na_object, descr2->na_object) &&
1316-
_eq_comparison(descr1->coerce, descr3->coerce,
1317-
descr1->na_object, descr3->na_object));
1318-
1319-
if (eq_res < 0) {
1247+
if (stringdtype_compatible_na(
1248+
descr1->na_object, descr2->na_object, &out_na_object) == -1) {
13201249
return (NPY_CASTING)-1;
13211250
}
13221251

1323-
if (eq_res != 1) {
1324-
PyErr_SetString(PyExc_TypeError,
1325-
"String replace is only supported with equal StringDType "
1326-
"instances.");
1252+
if (stringdtype_compatible_na(
1253+
out_na_object, descr3->na_object, &out_na_object) == -1) {
13271254
return (NPY_CASTING)-1;
13281255
}
13291256

@@ -1340,8 +1267,7 @@ replace_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
13401267

13411268
if (given_descrs[4] == NULL) {
13421269
out_descr = (PyArray_Descr *)new_stringdtype_instance(
1343-
((PyArray_StringDTypeObject *)given_descrs[0])->na_object,
1344-
((PyArray_StringDTypeObject *)given_descrs[0])->coerce);
1270+
out_na_object, out_coerce);
13451271

13461272
if (out_descr == NULL) {
13471273
return (NPY_CASTING)-1;
@@ -1588,18 +1514,11 @@ center_ljust_rjust_resolve_descriptors(
15881514
{
15891515
PyArray_StringDTypeObject *input_descr = (PyArray_StringDTypeObject *)given_descrs[0];
15901516
PyArray_StringDTypeObject *fill_descr = (PyArray_StringDTypeObject *)given_descrs[2];
1517+
int out_coerce = input_descr->coerce && fill_descr->coerce;
1518+
PyObject *out_na_object = NULL;
15911519

1592-
int eq_res = _eq_comparison(input_descr->coerce, fill_descr->coerce,
1593-
input_descr->na_object, fill_descr->na_object);
1594-
1595-
if (eq_res < 0) {
1596-
return (NPY_CASTING)-1;
1597-
}
1598-
1599-
if (eq_res != 1) {
1600-
PyErr_SetString(PyExc_TypeError,
1601-
"Can only do text justification operations with equal"
1602-
"StringDType instances.");
1520+
if (stringdtype_compatible_na(
1521+
input_descr->na_object, fill_descr->na_object, &out_na_object) == -1) {
16031522
return (NPY_CASTING)-1;
16041523
}
16051524

@@ -1614,8 +1533,7 @@ center_ljust_rjust_resolve_descriptors(
16141533

16151534
if (given_descrs[3] == NULL) {
16161535
out_descr = (PyArray_Descr *)new_stringdtype_instance(
1617-
((PyArray_StringDTypeObject *)given_descrs[1])->na_object,
1618-
((PyArray_StringDTypeObject *)given_descrs[1])->coerce);
1536+
out_na_object, out_coerce);
16191537

16201538
if (out_descr == NULL) {
16211539
return (NPY_CASTING)-1;
@@ -1888,6 +1806,7 @@ zfill_strided_loop(PyArrayMethod_Context *context,
18881806
return -1;
18891807
}
18901808

1809+
18911810
static NPY_CASTING
18921811
string_partition_resolve_descriptors(
18931812
PyArrayMethodObject *self,
@@ -1901,14 +1820,25 @@ string_partition_resolve_descriptors(
19011820
"currently support the 'out' keyword", self->name);
19021821
return (NPY_CASTING)-1;
19031822
}
1904-
for (int i=0; i<2; i++) {
1905-
Py_INCREF(given_descrs[i]);
1906-
loop_descrs[i] = given_descrs[i];
1823+
1824+
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
1825+
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
1826+
int out_coerce = descr1->coerce && descr2->coerce;
1827+
PyObject *out_na_object = NULL;
1828+
1829+
if (stringdtype_compatible_na(
1830+
descr1->na_object, descr2->na_object, &out_na_object) == -1) {
1831+
return (NPY_CASTING)-1;
19071832
}
1908-
PyArray_StringDTypeObject *adescr = (PyArray_StringDTypeObject *)given_descrs[0];
1833+
1834+
Py_INCREF(given_descrs[0]);
1835+
loop_descrs[0] = given_descrs[0];
1836+
Py_INCREF(given_descrs[1]);
1837+
loop_descrs[1] = given_descrs[1];
1838+
19091839
for (int i=2; i<5; i++) {
19101840
loop_descrs[i] = (PyArray_Descr *)new_stringdtype_instance(
1911-
adescr->na_object, adescr->coerce);
1841+
out_na_object, out_coerce);
19121842
if (loop_descrs[i] == NULL) {
19131843
return (NPY_CASTING)-1;
19141844
}
@@ -2655,7 +2585,7 @@ init_stringdtype_ufuncs(PyObject *umath)
26552585

26562586
for (int i=0; i<3; i++) {
26572587
if (init_ufunc(umath, strip_chars_names[i], strip_chars_dtypes,
2658-
&strip_chars_resolve_descriptors,
2588+
&binary_resolve_descriptors,
26592589
&string_lrstrip_chars_strided_loop,
26602590
2, 1, NPY_NO_CASTING, (NPY_ARRAYMETHOD_FLAGS) 0,
26612591
&strip_types[i]) < 0) {

0 commit comments

Comments
 (0)