Skip to content

Commit 4e6d2bf

Browse files
authored
ENH: add support for nan-like null strings in string replace (numpy#26355)
This fixes an issue similar to the one fixed by numpy#26353. In particular, right now np.strings.replace calls the count ufunc to get the number of replacements. This is necessary for fixed-width strings, but it turns out to make it impossible to support null strings in replace. I went ahead and instead found the replacement counts inline in the ufunc loop. This lets me add support for nan-like null strings, which it turns out pandas needs.
1 parent 05f8351 commit 4e6d2bf

File tree

3 files changed

+72
-33
lines changed

3 files changed

+72
-33
lines changed

numpy/_core/src/umath/stringdtype_ufuncs.cpp

Lines changed: 67 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,7 +1300,9 @@ string_replace_strided_loop(
13001300

13011301
PyArray_StringDTypeObject *descr0 =
13021302
(PyArray_StringDTypeObject *)context->descriptors[0];
1303+
int has_null = descr0->na_object != NULL;
13031304
int has_string_na = descr0->has_string_na;
1305+
int has_nan_na = descr0->has_nan_na;
13041306
const npy_static_string *default_string = &descr0->default_string;
13051307

13061308

@@ -1330,11 +1332,29 @@ string_replace_strided_loop(
13301332
goto fail;
13311333
}
13321334
else if (i1_isnull || i2_isnull || i3_isnull) {
1333-
if (!has_string_na) {
1334-
npy_gil_error(PyExc_ValueError,
1335-
"Null values are not supported as replacement arguments "
1336-
"for replace");
1337-
goto fail;
1335+
if (has_null && !has_string_na) {
1336+
if (i2_isnull || i3_isnull) {
1337+
npy_gil_error(PyExc_ValueError,
1338+
"Null values are not supported as search "
1339+
"patterns or replacement strings for "
1340+
"replace");
1341+
goto fail;
1342+
}
1343+
else if (i1_isnull) {
1344+
if (has_nan_na) {
1345+
if (NpyString_pack_null(oallocator, ops) < 0) {
1346+
npy_gil_error(PyExc_MemoryError,
1347+
"Failed to deallocate string in replace");
1348+
goto fail;
1349+
}
1350+
goto next_step;
1351+
}
1352+
else {
1353+
npy_gil_error(PyExc_ValueError,
1354+
"Only string or NaN-like null strings can "
1355+
"be used as search strings for replace");
1356+
}
1357+
}
13381358
}
13391359
else {
13401360
if (i1_isnull) {
@@ -1349,32 +1369,51 @@ string_replace_strided_loop(
13491369
}
13501370
}
13511371

1352-
// conservatively overallocate
1353-
// TODO check overflow
1354-
size_t max_size;
1355-
if (i2s.size == 0) {
1356-
// interleaving
1357-
max_size = i1s.size + (i1s.size + 1)*(i3s.size);
1358-
}
1359-
else {
1360-
// replace i2 with i3
1361-
max_size = i1s.size * (i3s.size/i2s.size + 1);
1362-
}
1363-
char *new_buf = (char *)PyMem_RawCalloc(max_size, 1);
1364-
Buffer<ENCODING::UTF8> buf1((char *)i1s.buf, i1s.size);
1365-
Buffer<ENCODING::UTF8> buf2((char *)i2s.buf, i2s.size);
1366-
Buffer<ENCODING::UTF8> buf3((char *)i3s.buf, i3s.size);
1367-
Buffer<ENCODING::UTF8> outbuf(new_buf, max_size);
1372+
{
1373+
Buffer<ENCODING::UTF8> buf1((char *)i1s.buf, i1s.size);
1374+
Buffer<ENCODING::UTF8> buf2((char *)i2s.buf, i2s.size);
13681375

1369-
size_t new_buf_size = string_replace(
1370-
buf1, buf2, buf3, *(npy_int64 *)in4, outbuf);
1376+
npy_int64 in_count = *(npy_int64*)in4;
1377+
if (in_count == -1) {
1378+
in_count = NPY_MAX_INT64;
1379+
}
13711380

1372-
if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) {
1373-
npy_gil_error(PyExc_MemoryError, "Failed to pack string in replace");
1374-
goto fail;
1375-
}
1381+
npy_int64 found_count = string_count<ENCODING::UTF8>(
1382+
buf1, buf2, 0, NPY_MAX_INT64);
1383+
if (found_count < 0) {
1384+
goto fail;
1385+
}
13761386

1377-
PyMem_RawFree(new_buf);
1387+
npy_intp count = Py_MIN(in_count, found_count);
1388+
1389+
Buffer<ENCODING::UTF8> buf3((char *)i3s.buf, i3s.size);
1390+
1391+
// conservatively overallocate
1392+
// TODO check overflow
1393+
size_t max_size;
1394+
if (i2s.size == 0) {
1395+
// interleaving
1396+
max_size = i1s.size + (i1s.size + 1)*(i3s.size);
1397+
}
1398+
else {
1399+
// replace i2 with i3
1400+
size_t change = i2s.size >= i3s.size ? 0 : i3s.size - i2s.size;
1401+
max_size = i1s.size + count * change;
1402+
}
1403+
char *new_buf = (char *)PyMem_RawCalloc(max_size, 1);
1404+
Buffer<ENCODING::UTF8> outbuf(new_buf, max_size);
1405+
1406+
size_t new_buf_size = string_replace(
1407+
buf1, buf2, buf3, count, outbuf);
1408+
1409+
if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) {
1410+
npy_gil_error(PyExc_MemoryError, "Failed to pack string in replace");
1411+
goto fail;
1412+
}
1413+
1414+
PyMem_RawFree(new_buf);
1415+
}
1416+
next_step:
13781417

13791418
in1 += strides[0];
13801419
in2 += strides[1];

numpy/_core/strings.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,15 +1153,15 @@ def replace(a, old, new, count=-1):
11531153
a_dt = arr.dtype
11541154
old = np.asanyarray(old, dtype=getattr(old, 'dtype', a_dt))
11551155
new = np.asanyarray(new, dtype=getattr(new, 'dtype', a_dt))
1156+
count = np.asanyarray(count)
1157+
1158+
if arr.dtype.char == "T":
1159+
return _replace(arr, old, new, count)
11561160

11571161
max_int64 = np.iinfo(np.int64).max
11581162
counts = _count_ufunc(arr, old, 0, max_int64)
1159-
count = np.asanyarray(count)
11601163
counts = np.where(count < 0, counts, np.minimum(counts, count))
11611164

1162-
if arr.dtype.char == "T":
1163-
return _replace(arr, old, new, counts)
1164-
11651165
buffersizes = str_len(arr) + counts * (str_len(new) - str_len(old))
11661166
out_dtype = f"{arr.dtype.char}{buffersizes.max()}"
11671167
out = np.empty_like(arr, shape=buffersizes.shape, dtype=out_dtype)

numpy/_core/tests/test_stringdtype.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1218,6 +1218,7 @@ def test_unary(string_array, unicode_array, function_name):
12181218
"strip",
12191219
"lstrip",
12201220
"rstrip",
1221+
"replace"
12211222
"zfill",
12221223
]
12231224

@@ -1230,7 +1231,6 @@ def test_unary(string_array, unicode_array, function_name):
12301231
"count",
12311232
"find",
12321233
"rfind",
1233-
"replace",
12341234
]
12351235

12361236
SUPPORTS_NULLS = (

0 commit comments

Comments
 (0)