Skip to content

Commit e86c581

Browse files
authored
Merge pull request numpy#26392 from ngoldbaum/strip-null-support
BUG: support nan-like null strings in [l,r]strip
2 parents 2a9b913 + e438a86 commit e86c581

File tree

2 files changed

+94
-30
lines changed

2 files changed

+94
-30
lines changed

numpy/_core/src/umath/stringdtype_ufuncs.cpp

Lines changed: 60 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1046,6 +1046,7 @@ string_lrstrip_chars_strided_loop(
10461046
PyArray_StringDTypeObject *s1descr = (PyArray_StringDTypeObject *)context->descriptors[0];
10471047
int has_null = s1descr->na_object != NULL;
10481048
int has_string_na = s1descr->has_string_na;
1049+
int has_nan_na = s1descr->has_nan_na;
10491050

10501051
const npy_static_string *default_string = &s1descr->default_string;
10511052
npy_intp N = dimensions[0];
@@ -1072,28 +1073,47 @@ string_lrstrip_chars_strided_loop(
10721073
s2 = *default_string;
10731074
}
10741075
}
1076+
else if (has_nan_na) {
1077+
if (s2_isnull) {
1078+
npy_gil_error(PyExc_ValueError,
1079+
"Cannot use a null string that is not a "
1080+
"string as the %s delimiter", ufunc_name);
1081+
}
1082+
if (s1_isnull) {
1083+
if (NpyString_pack_null(oallocator, ops) < 0) {
1084+
npy_gil_error(PyExc_MemoryError,
1085+
"Failed to deallocate string in %s",
1086+
ufunc_name);
1087+
goto fail;
1088+
}
1089+
goto next_step;
1090+
}
1091+
}
10751092
else {
10761093
npy_gil_error(PyExc_ValueError,
1077-
"Cannot strip null values that are not strings");
1094+
"Can only strip null values that are strings "
1095+
"or NaN-like values");
10781096
goto fail;
10791097
}
10801098
}
1099+
{
1100+
char *new_buf = (char *)PyMem_RawCalloc(s1.size, 1);
1101+
Buffer<ENCODING::UTF8> buf1((char *)s1.buf, s1.size);
1102+
Buffer<ENCODING::UTF8> buf2((char *)s2.buf, s2.size);
1103+
Buffer<ENCODING::UTF8> outbuf(new_buf, s1.size);
1104+
size_t new_buf_size = string_lrstrip_chars
1105+
(buf1, buf2, outbuf, striptype);
10811106

1107+
if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) {
1108+
npy_gil_error(PyExc_MemoryError, "Failed to pack string in %s",
1109+
ufunc_name);
1110+
PyMem_RawFree(new_buf);
1111+
goto fail;
1112+
}
10821113

1083-
char *new_buf = (char *)PyMem_RawCalloc(s1.size, 1);
1084-
Buffer<ENCODING::UTF8> buf1((char *)s1.buf, s1.size);
1085-
Buffer<ENCODING::UTF8> buf2((char *)s2.buf, s2.size);
1086-
Buffer<ENCODING::UTF8> outbuf(new_buf, s1.size);
1087-
size_t new_buf_size = string_lrstrip_chars
1088-
(buf1, buf2, outbuf, striptype);
1089-
1090-
if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) {
1091-
npy_gil_error(PyExc_MemoryError, "Failed to pack string in %s",
1092-
ufunc_name);
1093-
goto fail;
1114+
PyMem_RawFree(new_buf);
10941115
}
1095-
1096-
PyMem_RawFree(new_buf);
1116+
next_step:
10971117

10981118
in1 += strides[0];
10991119
in2 += strides[1];
@@ -1150,8 +1170,9 @@ string_lrstrip_whitespace_strided_loop(
11501170
const char *ufunc_name = ((PyUFuncObject *)context->caller)->name;
11511171
STRIPTYPE striptype = *(STRIPTYPE *)context->method->static_data;
11521172
PyArray_StringDTypeObject *descr = (PyArray_StringDTypeObject *)context->descriptors[0];
1153-
int has_string_na = descr->has_string_na;
11541173
int has_null = descr->na_object != NULL;
1174+
int has_string_na = descr->has_string_na;
1175+
int has_nan_na = descr->has_nan_na;
11551176
const npy_static_string *default_string = &descr->default_string;
11561177

11571178
npy_string_allocator *allocators[2] = {};
@@ -1181,26 +1202,39 @@ string_lrstrip_whitespace_strided_loop(
11811202
if (has_string_na || !has_null) {
11821203
s = *default_string;
11831204
}
1205+
else if (has_nan_na) {
1206+
if (NpyString_pack_null(oallocator, ops) < 0) {
1207+
npy_gil_error(PyExc_MemoryError,
1208+
"Failed to deallocate string in %s",
1209+
ufunc_name);
1210+
goto fail;
1211+
}
1212+
goto next_step;
1213+
}
11841214
else {
11851215
npy_gil_error(PyExc_ValueError,
1186-
"Cannot strip null values that are not strings");
1216+
"Can only strip null values that are strings or "
1217+
"NaN-like values");
11871218
goto fail;
11881219
}
11891220
}
1221+
{
1222+
char *new_buf = (char *)PyMem_RawCalloc(s.size, 1);
1223+
Buffer<ENCODING::UTF8> buf((char *)s.buf, s.size);
1224+
Buffer<ENCODING::UTF8> outbuf(new_buf, s.size);
1225+
size_t new_buf_size = string_lrstrip_whitespace(
1226+
buf, outbuf, striptype);
11901227

1191-
char *new_buf = (char *)PyMem_RawCalloc(s.size, 1);
1192-
Buffer<ENCODING::UTF8> buf((char *)s.buf, s.size);
1193-
Buffer<ENCODING::UTF8> outbuf(new_buf, s.size);
1194-
size_t new_buf_size = string_lrstrip_whitespace(
1195-
buf, outbuf, striptype);
1228+
if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) {
1229+
npy_gil_error(PyExc_MemoryError, "Failed to pack string in %s",
1230+
ufunc_name);
1231+
goto fail;
1232+
}
11961233

1197-
if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) {
1198-
npy_gil_error(PyExc_MemoryError, "Failed to pack string in %s",
1199-
ufunc_name);
1200-
goto fail;
1234+
PyMem_RawFree(new_buf);
12011235
}
12021236

1203-
PyMem_RawFree(new_buf);
1237+
next_step:
12041238

12051239
in += strides[0];
12061240
out += strides[1];

numpy/_core/tests/test_stringdtype.py

Lines changed: 34 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1080,7 +1080,13 @@ def unicode_array():
10801080
"capitalize",
10811081
"expandtabs",
10821082
"lower",
1083-
"splitlines" "swapcase" "title" "upper",
1083+
"lstrip",
1084+
"rstrip",
1085+
"splitlines",
1086+
"strip",
1087+
"swapcase",
1088+
"title",
1089+
"upper",
10841090
]
10851091

10861092
BOOL_OUTPUT_FUNCTIONS = [
@@ -1107,7 +1113,10 @@ def unicode_array():
11071113
"istitle",
11081114
"isupper",
11091115
"lower",
1116+
"lstrip",
1117+
"rstrip",
11101118
"splitlines",
1119+
"strip",
11111120
"swapcase",
11121121
"title",
11131122
"upper",
@@ -1129,10 +1138,20 @@ def unicode_array():
11291138
"upper",
11301139
]
11311140

1141+
ONLY_IN_NP_CHAR = [
1142+
"join",
1143+
"split",
1144+
"rsplit",
1145+
"splitlines"
1146+
]
1147+
11321148

11331149
@pytest.mark.parametrize("function_name", UNARY_FUNCTIONS)
11341150
def test_unary(string_array, unicode_array, function_name):
1135-
func = getattr(np.char, function_name)
1151+
if function_name in ONLY_IN_NP_CHAR:
1152+
func = getattr(np.char, function_name)
1153+
else:
1154+
func = getattr(np.strings, function_name)
11361155
dtype = string_array.dtype
11371156
sres = func(string_array)
11381157
ures = func(unicode_array)
@@ -1173,6 +1192,10 @@ def test_unary(string_array, unicode_array, function_name):
11731192
with pytest.raises(ValueError):
11741193
func(na_arr)
11751194
return
1195+
if not (is_nan or is_str):
1196+
with pytest.raises(ValueError):
1197+
func(na_arr)
1198+
return
11761199
res = func(na_arr)
11771200
if is_nan and function_name in NAN_PRESERVING_FUNCTIONS:
11781201
assert res[0] is dtype.na_object
@@ -1197,13 +1220,17 @@ def test_unary(string_array, unicode_array, function_name):
11971220
("index", (None, "e")),
11981221
("join", ("-", None)),
11991222
("ljust", (None, 12)),
1223+
("lstrip", (None, "A")),
12001224
("partition", (None, "A")),
12011225
("replace", (None, "A", "B")),
12021226
("rfind", (None, "A")),
12031227
("rindex", (None, "e")),
12041228
("rjust", (None, 12)),
1229+
("rsplit", (None, "A")),
1230+
("rstrip", (None, "A")),
12051231
("rpartition", (None, "A")),
12061232
("split", (None, "A")),
1233+
("strip", (None, "A")),
12071234
("startswith", (None, "A")),
12081235
("zfill", (None, 12)),
12091236
]
@@ -1260,10 +1287,13 @@ def call_func(func, args, array, sanitize=True):
12601287

12611288
@pytest.mark.parametrize("function_name, args", BINARY_FUNCTIONS)
12621289
def test_binary(string_array, unicode_array, function_name, args):
1263-
func = getattr(np.char, function_name)
1290+
if function_name in ONLY_IN_NP_CHAR:
1291+
func = getattr(np.char, function_name)
1292+
else:
1293+
func = getattr(np.strings, function_name)
12641294
sres = call_func(func, args, string_array)
12651295
ures = call_func(func, args, unicode_array, sanitize=False)
1266-
if sres.dtype == StringDType():
1296+
if not isinstance(sres, tuple) and sres.dtype == StringDType():
12671297
ures = ures.astype(StringDType())
12681298
assert_array_equal(sres, ures)
12691299

0 commit comments

Comments
 (0)