Skip to content

Commit 0b723bc

Browse files
committed
BUG: Fix string addition promoter to work with dtype=
Also makes it reject forced unicode selection, since that doesn't work.
1 parent a4596d7 commit 0b723bc

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

numpy/_core/src/umath/stringdtype_ufuncs.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,25 @@ all_strings_promoter(PyObject *NPY_UNUSED(ufunc),
10281028
PyArray_DTypeMeta *const signature[],
10291029
PyArray_DTypeMeta *new_op_dtypes[])
10301030
{
1031+
if (op_dtypes[0] != &PyArray_StringDType &&
1032+
op_dtypes[1] != &PyArray_StringDType &&
1033+
op_dtypes[2] != &PyArray_StringDType) {
1034+
/*
1035+
* This promoter was triggered with only unicode arguments, so use
1036+
* unicode. This can happen due to `dtype=` support which sets the
1037+
* output DType/signature.
1038+
*/
1039+
new_op_dtypes[0] = NPY_DT_NewRef(&PyArray_UnicodeDType);
1040+
new_op_dtypes[1] = NPY_DT_NewRef(&PyArray_UnicodeDType);
1041+
new_op_dtypes[2] = NPY_DT_NewRef(&PyArray_UnicodeDType);
1042+
return 0;
1043+
}
1044+
if (signature[0] == &PyArray_UnicodeDType &&
1045+
signature[1] == &PyArray_UnicodeDType &&
1046+
signature[2] == &PyArray_UnicodeDType) {
1047+
/* Unicode forced, but didn't override a string input: invalid */
1048+
return -1;
1049+
}
10311050
new_op_dtypes[0] = NPY_DT_NewRef(&PyArray_StringDType);
10321051
new_op_dtypes[1] = NPY_DT_NewRef(&PyArray_StringDType);
10331052
new_op_dtypes[2] = NPY_DT_NewRef(&PyArray_StringDType);
@@ -2532,6 +2551,17 @@ init_stringdtype_ufuncs(PyObject *umath)
25322551
return -1;
25332552
}
25342553

2554+
PyArray_DTypeMeta *out_strings_promoter_dtypes[] = {
2555+
&PyArray_UnicodeDType,
2556+
&PyArray_UnicodeDType,
2557+
&PyArray_StringDType,
2558+
};
2559+
2560+
if (add_promoter(umath, "add", out_strings_promoter_dtypes, 3,
2561+
all_strings_promoter) < 0) {
2562+
return -1;
2563+
}
2564+
25352565
INIT_MULTIPLY(Int64, int64);
25362566
INIT_MULTIPLY(UInt64, uint64);
25372567

numpy/_core/tests/test_stringdtype.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,22 @@ def test_add_promoter(string_list):
828828
assert_array_equal(op + arr, lresult)
829829
assert_array_equal(arr + op, rresult)
830830

831+
# The promoter should be able to handle things if users pass `dtype=`
832+
res = np.add("hello", string_list, dtype=StringDType)
833+
assert res.dtype == StringDType()
834+
835+
# The promoter should not kick in if users override the input,
836+
# which means arr is cast; this fails because of the unknown length.
837+
with pytest.raises(TypeError, match="cannot cast dtype"):
838+
np.add(arr, "add", signature=("U", "U", None), casting="unsafe")
839+
840+
# But it must simply reject the following:
841+
with pytest.raises(TypeError, match=".*did not contain a loop"):
842+
np.add(arr, "add", signature=(None, "U", None))
843+
844+
with pytest.raises(TypeError, match=".*did not contain a loop"):
845+
np.add("a", "b", signature=("U", "U", StringDType))
846+
831847

832848
def test_add_promoter_reduce():
833849
# Exact TypeError could change, but ensure StringDtype doesn't match

0 commit comments

Comments
 (0)