@@ -1300,7 +1300,9 @@ string_replace_strided_loop(
1300
1300
1301
1301
PyArray_StringDTypeObject *descr0 =
1302
1302
(PyArray_StringDTypeObject *)context->descriptors [0 ];
1303
+ int has_null = descr0->na_object != NULL ;
1303
1304
int has_string_na = descr0->has_string_na ;
1305
+ int has_nan_na = descr0->has_nan_na ;
1304
1306
const npy_static_string *default_string = &descr0->default_string ;
1305
1307
1306
1308
@@ -1330,11 +1332,29 @@ string_replace_strided_loop(
1330
1332
goto fail;
1331
1333
}
1332
1334
else if (i1_isnull || i2_isnull || i3_isnull) {
1333
- if (!has_string_na) {
1334
- npy_gil_error (PyExc_ValueError,
1335
- " Null values are not supported as replacement arguments "
1336
- " for replace" );
1337
- goto fail;
1335
+ if (has_null && !has_string_na) {
1336
+ if (i2_isnull || i3_isnull) {
1337
+ npy_gil_error (PyExc_ValueError,
1338
+ " Null values are not supported as search "
1339
+ " patterns or replacement strings for "
1340
+ " replace" );
1341
+ goto fail;
1342
+ }
1343
+ else if (i1_isnull) {
1344
+ if (has_nan_na) {
1345
+ if (NpyString_pack_null (oallocator, ops) < 0 ) {
1346
+ npy_gil_error (PyExc_MemoryError,
1347
+ " Failed to deallocate string in replace" );
1348
+ goto fail;
1349
+ }
1350
+ goto next_step;
1351
+ }
1352
+ else {
1353
+ npy_gil_error (PyExc_ValueError,
1354
+ " Only string or NaN-like null strings can "
1355
+ " be used as search strings for replace" );
1356
+ }
1357
+ }
1338
1358
}
1339
1359
else {
1340
1360
if (i1_isnull) {
@@ -1349,32 +1369,51 @@ string_replace_strided_loop(
1349
1369
}
1350
1370
}
1351
1371
1352
- // conservatively overallocate
1353
- // TODO check overflow
1354
- size_t max_size;
1355
- if (i2s.size == 0 ) {
1356
- // interleaving
1357
- max_size = i1s.size + (i1s.size + 1 )*(i3s.size );
1358
- }
1359
- else {
1360
- // replace i2 with i3
1361
- max_size = i1s.size * (i3s.size /i2s.size + 1 );
1362
- }
1363
- char *new_buf = (char *)PyMem_RawCalloc (max_size, 1 );
1364
- Buffer<ENCODING::UTF8> buf1 ((char *)i1s.buf , i1s.size );
1365
- Buffer<ENCODING::UTF8> buf2 ((char *)i2s.buf , i2s.size );
1366
- Buffer<ENCODING::UTF8> buf3 ((char *)i3s.buf , i3s.size );
1367
- Buffer<ENCODING::UTF8> outbuf (new_buf, max_size);
1372
+ {
1373
+ Buffer<ENCODING::UTF8> buf1 ((char *)i1s.buf , i1s.size );
1374
+ Buffer<ENCODING::UTF8> buf2 ((char *)i2s.buf , i2s.size );
1368
1375
1369
- size_t new_buf_size = string_replace (
1370
- buf1, buf2, buf3, *(npy_int64 *)in4, outbuf);
1376
+ npy_int64 in_count = *(npy_int64*)in4;
1377
+ if (in_count == -1 ) {
1378
+ in_count = NPY_MAX_INT64;
1379
+ }
1371
1380
1372
- if (NpyString_pack (oallocator, ops, new_buf, new_buf_size) < 0 ) {
1373
- npy_gil_error (PyExc_MemoryError, " Failed to pack string in replace" );
1374
- goto fail;
1375
- }
1381
+ npy_int64 found_count = string_count<ENCODING::UTF8>(
1382
+ buf1, buf2, 0 , NPY_MAX_INT64);
1383
+ if (found_count < 0 ) {
1384
+ goto fail;
1385
+ }
1376
1386
1377
- PyMem_RawFree (new_buf);
1387
+ npy_intp count = Py_MIN (in_count, found_count);
1388
+
1389
+ Buffer<ENCODING::UTF8> buf3 ((char *)i3s.buf , i3s.size );
1390
+
1391
+ // conservatively overallocate
1392
+ // TODO check overflow
1393
+ size_t max_size;
1394
+ if (i2s.size == 0 ) {
1395
+ // interleaving
1396
+ max_size = i1s.size + (i1s.size + 1 )*(i3s.size );
1397
+ }
1398
+ else {
1399
+ // replace i2 with i3
1400
+ size_t change = i2s.size >= i3s.size ? 0 : i3s.size - i2s.size ;
1401
+ max_size = i1s.size + count * change;
1402
+ }
1403
+ char *new_buf = (char *)PyMem_RawCalloc (max_size, 1 );
1404
+ Buffer<ENCODING::UTF8> outbuf (new_buf, max_size);
1405
+
1406
+ size_t new_buf_size = string_replace (
1407
+ buf1, buf2, buf3, count, outbuf);
1408
+
1409
+ if (NpyString_pack (oallocator, ops, new_buf, new_buf_size) < 0 ) {
1410
+ npy_gil_error (PyExc_MemoryError, " Failed to pack string in replace" );
1411
+ goto fail;
1412
+ }
1413
+
1414
+ PyMem_RawFree (new_buf);
1415
+ }
1416
+ next_step:
1378
1417
1379
1418
in1 += strides[0 ];
1380
1419
in2 += strides[1 ];
0 commit comments