Skip to content

Commit 9620e00

Browse files
String dtype: propagate NaNs as False in predicate methods (eg .str.startswith)
1 parent 360597c commit 9620e00

File tree

8 files changed, with 206 additions and 100 deletions.

pandas/core/arrays/arrow/array.py

Lines changed: 16 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2291,7 +2291,12 @@ def _str_count(self, pat: str, flags: int = 0) -> Self:
22912291
return type(self)(pc.count_substring_regex(self._pa_array, pat))
22922292

22932293
def _str_contains(
2294-
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
2294+
self,
2295+
pat,
2296+
case: bool = True,
2297+
flags: int = 0,
2298+
na=lib.no_default,
2299+
regex: bool = True,
22952300
) -> Self:
22962301
if flags:
22972302
raise NotImplementedError(f"contains not implemented with {flags=}")
@@ -2301,11 +2306,11 @@ def _str_contains(
23012306
else:
23022307
pa_contains = pc.match_substring
23032308
result = pa_contains(self._pa_array, pat, ignore_case=not case)
2304-
if not isna(na):
2309+
if na is not lib.no_default and not isna(na):
23052310
result = result.fill_null(na)
23062311
return type(self)(result)
23072312

2308-
def _str_startswith(self, pat: str | tuple[str, ...], na=None) -> Self:
2313+
def _str_startswith(self, pat: str | tuple[str, ...], na=lib.no_default) -> Self:
23092314
if isinstance(pat, str):
23102315
result = pc.starts_with(self._pa_array, pattern=pat)
23112316
else:
@@ -2318,7 +2323,7 @@ def _str_startswith(self, pat: str | tuple[str, ...], na=None) -> Self:
23182323

23192324
for p in pat[1:]:
23202325
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
2321-
if not isna(na):
2326+
if na is not lib.no_default and not isna(na):
23222327
result = result.fill_null(na)
23232328
return type(self)(result)
23242329

@@ -2335,7 +2340,7 @@ def _str_endswith(self, pat: str | tuple[str, ...], na=None) -> Self:
23352340

23362341
for p in pat[1:]:
23372342
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
2338-
if not isna(na):
2343+
if na is not lib.no_default and not isna(na):
23392344
result = result.fill_null(na)
23402345
return type(self)(result)
23412346

@@ -2374,14 +2379,18 @@ def _str_repeat(self, repeats: int | Sequence[int]) -> Self:
23742379
return type(self)(pc.binary_repeat(self._pa_array, repeats))
23752380

23762381
def _str_match(
2377-
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
2382+
self,
2383+
pat: str,
2384+
case: bool = True,
2385+
flags: int = 0,
2386+
na: Scalar | None = lib.no_default,
23782387
) -> Self:
23792388
if not pat.startswith("^"):
23802389
pat = f"^{pat}"
23812390
return self._str_contains(pat, case, flags, na, regex=True)
23822391

23832392
def _str_fullmatch(
2384-
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
2393+
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = lib.no_default
23852394
) -> Self:
23862395
if not pat.endswith("$") or pat.endswith("\\$"):
23872396
pat = f"{pat}$"

pandas/core/arrays/string_.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -351,7 +351,11 @@ def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
351351
return cls._from_sequence(scalars, dtype=dtype)
352352

353353
def _str_map(
354-
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
354+
self,
355+
f,
356+
na_value=lib.no_default,
357+
dtype: Dtype | None = None,
358+
convert: bool = True,
355359
):
356360
if self.dtype.na_value is np.nan:
357361
return self._str_map_nan_semantics(f, na_value=na_value, dtype=dtype)
@@ -360,7 +364,7 @@ def _str_map(
360364

361365
if dtype is None:
362366
dtype = self.dtype
363-
if na_value is None:
367+
if na_value is lib.no_default:
364368
na_value = self.dtype.na_value
365369

366370
mask = isna(self)
@@ -429,11 +433,16 @@ def _str_map_str_or_object(
429433
# -> We don't know the result type. E.g. `.get` can return anything.
430434
return lib.map_infer_mask(arr, f, mask.view("uint8"))
431435

432-
def _str_map_nan_semantics(self, f, na_value=None, dtype: Dtype | None = None):
436+
def _str_map_nan_semantics(
437+
self, f, na_value=lib.no_default, dtype: Dtype | None = None
438+
):
433439
if dtype is None:
434440
dtype = self.dtype
435-
if na_value is None:
436-
na_value = self.dtype.na_value
441+
if na_value is lib.no_default:
442+
if is_bool_dtype(dtype):
443+
na_value = False
444+
else:
445+
na_value = self.dtype.na_value
437446

438447
mask = isna(self)
439448
arr = np.asarray(self)

pandas/core/arrays/string_arrow.py

Lines changed: 50 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -221,11 +221,16 @@ def insert(self, loc: int, item) -> ArrowStringArray:
221221
raise TypeError("Scalar must be NA or str")
222222
return super().insert(loc, item)
223223

224-
def _result_converter(self, values, na=None):
224+
def _predicate_result_converter(self, values, na=lib.no_default):
225225
if self.dtype.na_value is np.nan:
226-
if not isna(na):
226+
if na is lib.no_default:
227+
na_value = False
228+
elif not isna(na):
227229
values = values.fill_null(bool(na))
228-
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)
230+
na_value = lib.no_default
231+
else:
232+
na_value = np.nan
233+
return ArrowExtensionArray(values).to_numpy(na_value=na_value)
229234
return BooleanDtype().__from_arrow__(values)
230235

231236
def _maybe_convert_setitem_value(self, value):
@@ -282,7 +287,12 @@ def astype(self, dtype, copy: bool = True):
282287
_str_map = BaseStringArray._str_map
283288

284289
def _str_contains(
285-
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
290+
self,
291+
pat,
292+
case: bool = True,
293+
flags: int = 0,
294+
na=lib.no_default,
295+
regex: bool = True,
286296
):
287297
if flags:
288298
if get_option("mode.performance_warnings"):
@@ -293,12 +303,18 @@ def _str_contains(
293303
result = pc.match_substring_regex(self._pa_array, pat, ignore_case=not case)
294304
else:
295305
result = pc.match_substring(self._pa_array, pat, ignore_case=not case)
296-
result = self._result_converter(result, na=na)
297-
if not isna(na):
306+
result = self._predicate_result_converter(result, na=na)
307+
if (
308+
self.dtype.na_value is libmissing.NA
309+
and na is not lib.no_default
310+
and not isna(na)
311+
):
298312
result[isna(result)] = bool(na)
299313
return result
300314

301-
def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
315+
def _str_startswith(
316+
self, pat: str | tuple[str, ...], na: Scalar | None = lib.no_default
317+
):
302318
if isinstance(pat, str):
303319
result = pc.starts_with(self._pa_array, pattern=pat)
304320
else:
@@ -313,9 +329,13 @@ def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
313329

314330
for p in pat[1:]:
315331
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
316-
if not isna(na):
332+
if (
333+
self.dtype.na_value is libmissing.NA
334+
and na is not lib.no_default
335+
and not isna(na)
336+
):
317337
result = result.fill_null(na)
318-
return self._result_converter(result)
338+
return self._predicate_result_converter(result)
319339

320340
def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
321341
if isinstance(pat, str):
@@ -332,9 +352,13 @@ def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
332352

333353
for p in pat[1:]:
334354
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
335-
if not isna(na):
355+
if (
356+
self.dtype.na_value is libmissing.NA
357+
and na is not lib.no_default
358+
and not isna(na)
359+
):
336360
result = result.fill_null(na)
337-
return self._result_converter(result)
361+
return self._predicate_result_converter(result)
338362

339363
def _str_replace(
340364
self,
@@ -361,14 +385,18 @@ def _str_repeat(self, repeats: int | Sequence[int]):
361385
return type(self)(pc.binary_repeat(self._pa_array, repeats))
362386

363387
def _str_match(
364-
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
388+
self,
389+
pat: str,
390+
case: bool = True,
391+
flags: int = 0,
392+
na: Scalar | None = lib.no_default,
365393
):
366394
if not pat.startswith("^"):
367395
pat = f"^{pat}"
368396
return self._str_contains(pat, case, flags, na, regex=True)
369397

370398
def _str_fullmatch(
371-
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
399+
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = lib.no_default
372400
):
373401
if not pat.endswith("$") or pat.endswith("\\$"):
374402
pat = f"{pat}$"
@@ -389,39 +417,39 @@ def _str_slice(
389417

390418
def _str_isalnum(self):
391419
result = pc.utf8_is_alnum(self._pa_array)
392-
return self._result_converter(result)
420+
return self._predicate_result_converter(result)
393421

394422
def _str_isalpha(self):
395423
result = pc.utf8_is_alpha(self._pa_array)
396-
return self._result_converter(result)
424+
return self._predicate_result_converter(result)
397425

398426
def _str_isdecimal(self):
399427
result = pc.utf8_is_decimal(self._pa_array)
400-
return self._result_converter(result)
428+
return self._predicate_result_converter(result)
401429

402430
def _str_isdigit(self):
403431
result = pc.utf8_is_digit(self._pa_array)
404-
return self._result_converter(result)
432+
return self._predicate_result_converter(result)
405433

406434
def _str_islower(self):
407435
result = pc.utf8_is_lower(self._pa_array)
408-
return self._result_converter(result)
436+
return self._predicate_result_converter(result)
409437

410438
def _str_isnumeric(self):
411439
result = pc.utf8_is_numeric(self._pa_array)
412-
return self._result_converter(result)
440+
return self._predicate_result_converter(result)
413441

414442
def _str_isspace(self):
415443
result = pc.utf8_is_space(self._pa_array)
416-
return self._result_converter(result)
444+
return self._predicate_result_converter(result)
417445

418446
def _str_istitle(self):
419447
result = pc.utf8_is_title(self._pa_array)
420-
return self._result_converter(result)
448+
return self._predicate_result_converter(result)
421449

422450
def _str_isupper(self):
423451
result = pc.utf8_is_upper(self._pa_array)
424-
return self._result_converter(result)
452+
return self._predicate_result_converter(result)
425453

426454
def _str_len(self):
427455
result = pc.utf8_length(self._pa_array)

pandas/core/strings/accessor.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1222,7 +1222,12 @@ def join(self, sep: str):
12221222

12231223
@forbid_nonstring_types(["bytes"])
12241224
def contains(
1225-
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
1225+
self,
1226+
pat,
1227+
case: bool = True,
1228+
flags: int = 0,
1229+
na=lib.no_default,
1230+
regex: bool = True,
12261231
):
12271232
r"""
12281233
Test if pattern or regex is contained within a string of a Series or Index.
@@ -1359,7 +1364,7 @@ def contains(
13591364
return self._wrap_result(result, fill_value=na, returns_string=False)
13601365

13611366
@forbid_nonstring_types(["bytes"])
1362-
def match(self, pat: str, case: bool = True, flags: int = 0, na=None):
1367+
def match(self, pat: str, case: bool = True, flags: int = 0, na=lib.no_default):
13631368
"""
13641369
Determine if each string starts with a match of a regular expression.
13651370
@@ -1403,7 +1408,7 @@ def match(self, pat: str, case: bool = True, flags: int = 0, na=None):
14031408
return self._wrap_result(result, fill_value=na, returns_string=False)
14041409

14051410
@forbid_nonstring_types(["bytes"])
1406-
def fullmatch(self, pat, case: bool = True, flags: int = 0, na=None):
1411+
def fullmatch(self, pat, case: bool = True, flags: int = 0, na=lib.no_default):
14071412
"""
14081413
Determine if each string entirely matches a regular expression.
14091414
@@ -2581,7 +2586,7 @@ def count(self, pat, flags: int = 0):
25812586

25822587
@forbid_nonstring_types(["bytes"])
25832588
def startswith(
2584-
self, pat: str | tuple[str, ...], na: Scalar | None = None
2589+
self, pat: str | tuple[str, ...], na: Scalar | None = lib.no_default
25852590
) -> Series | Index:
25862591
"""
25872592
Test if the start of each string element matches a pattern.
@@ -2651,7 +2656,7 @@ def startswith(
26512656

26522657
@forbid_nonstring_types(["bytes"])
26532658
def endswith(
2654-
self, pat: str | tuple[str, ...], na: Scalar | None = None
2659+
self, pat: str | tuple[str, ...], na: Scalar | None = lib.no_default
26552660
) -> Series | Index:
26562661
"""
26572662
Test if the end of each string element matches a pattern.

pandas/core/strings/object_array.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,11 @@ def __len__(self) -> int:
4242
raise NotImplementedError
4343

4444
def _str_map(
45-
self, f, na_value=None, dtype: NpDtype | None = None, convert: bool = True
45+
self,
46+
f,
47+
na_value=lib.no_default,
48+
dtype: NpDtype | None = None,
49+
convert: bool = True,
4650
):
4751
"""
4852
Map a callable over valid elements of the array.
@@ -63,7 +67,7 @@ def _str_map(
6367
"""
6468
if dtype is None:
6569
dtype = np.dtype("object")
66-
if na_value is None:
70+
if na_value is lib.no_default:
6771
na_value = self.dtype.na_value # type: ignore[attr-defined]
6872

6973
if not len(self):
@@ -127,7 +131,12 @@ def _str_pad(
127131
return self._str_map(f)
128132

129133
def _str_contains(
130-
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
134+
self,
135+
pat,
136+
case: bool = True,
137+
flags: int = 0,
138+
na=lib.no_default,
139+
regex: bool = True,
131140
):
132141
if regex:
133142
if not case:
@@ -144,11 +153,11 @@ def _str_contains(
144153
f = lambda x: upper_pat in x.upper()
145154
return self._str_map(f, na, dtype=np.dtype("bool"))
146155

147-
def _str_startswith(self, pat, na=None):
156+
def _str_startswith(self, pat, na=lib.no_default):
148157
f = lambda x: x.startswith(pat)
149158
return self._str_map(f, na_value=na, dtype=np.dtype(bool))
150159

151-
def _str_endswith(self, pat, na=None):
160+
def _str_endswith(self, pat, na=lib.no_default):
152161
f = lambda x: x.endswith(pat)
153162
return self._str_map(f, na_value=na, dtype=np.dtype(bool))
154163

0 commit comments

Comments
 (0)