Skip to content

Commit b9612fc

Browse files
test and fix startswith/endswith
1 parent 7d2a746 commit b9612fc

File tree

3 files changed

+42
-20
lines changed

3 files changed

+42
-20
lines changed

pandas/core/arrays/string_.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,8 @@ def _str_map_nan_semantics(
453453
if is_integer_dtype(dtype):
454454
na_value = 0
455455
else:
456-
na_value = True
456+
# NaN propagates as False
457+
na_value = False
457458

458459
result = lib.map_infer_mask(
459460
arr,
@@ -463,15 +464,13 @@ def _str_map_nan_semantics(
463464
na_value=na_value,
464465
dtype=np.dtype(cast(type, dtype)),
465466
)
466-
if na_value_is_na and mask.any():
467+
if na_value_is_na and is_integer_dtype(dtype) and mask.any():
467468
# TODO: we could alternatively do this check before map_infer_mask
468469
# and adjust the dtype/na_value we pass there. Which is more
469470
# performant?
470-
if is_integer_dtype(dtype):
471-
result = result.astype("float64")
472-
else:
473-
result = result.astype("object")
471+
result = result.astype("float64")
474472
result[mask] = np.nan
473+
475474
return result
476475

477476
else:

pandas/core/arrays/string_arrow.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -223,14 +223,16 @@ def insert(self, loc: int, item) -> ArrowStringArray:
223223

224224
def _predicate_result_converter(self, values, na=lib.no_default):
225225
if self.dtype.na_value is np.nan:
226-
na_value: bool | float | lib.NoDefault
226+
na_value: bool | lib.NoDefault
227227
if na is lib.no_default:
228228
na_value = False
229-
elif not isna(na):
230-
values = values.fill_null(bool(na))
229+
elif isna(na):
230+
# NaN propagates as False
231+
values = values.fill_null(False)
231232
na_value = lib.no_default
232233
else:
233-
na_value = np.nan
234+
values = values.fill_null(bool(na))
235+
na_value = lib.no_default
234236
return ArrowExtensionArray(values).to_numpy(na_value=na_value)
235237
return BooleanDtype().__from_arrow__(values)
236238

@@ -336,7 +338,7 @@ def _str_startswith(
336338
and not isna(na)
337339
):
338340
result = result.fill_null(na)
339-
return self._predicate_result_converter(result)
341+
return self._predicate_result_converter(result, na=na)
340342

341343
def _str_endswith(
342344
self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
@@ -361,7 +363,7 @@ def _str_endswith(
361363
and not isna(na)
362364
):
363365
result = result.fill_null(na)
364-
return self._predicate_result_converter(result)
366+
return self._predicate_result_converter(result, na=na)
365367

366368
def _str_replace(
367369
self,

pandas/tests/strings/test_find_replace.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -311,20 +311,31 @@ def test_startswith(pat, dtype, null_value, na):
311311

312312

313313
@pytest.mark.parametrize("na", [None, True, False])
314-
def test_startswith_nullable_string_dtype(nullable_string_dtype, na):
314+
def test_startswith_string_dtype(any_string_dtype, na):
315315
values = Series(
316316
["om", None, "foo_nom", "nom", "bar_foo", None, "foo", "regex", "rege."],
317-
dtype=nullable_string_dtype,
317+
dtype=any_string_dtype,
318318
)
319319
result = values.str.startswith("foo", na=na)
320+
321+
expected_dtype = (
322+
(object if na is None else bool)
323+
if is_object_or_nan_string_dtype(any_string_dtype)
324+
else "boolean"
325+
)
326+
if any_string_dtype == "str":
327+
# NaN propagates as False
328+
expected_dtype = bool
329+
if na is None:
330+
na = False
320331
exp = Series(
321-
[False, na, True, False, False, na, True, False, False], dtype="boolean"
332+
[False, na, True, False, False, na, True, False, False], dtype=expected_dtype
322333
)
323334
tm.assert_series_equal(result, exp)
324335

325336
result = values.str.startswith("rege.", na=na)
326337
exp = Series(
327-
[False, na, False, False, False, na, False, False, True], dtype="boolean"
338+
[False, na, False, False, False, na, False, False, True], dtype=expected_dtype
328339
)
329340
tm.assert_series_equal(result, exp)
330341

@@ -369,20 +380,30 @@ def test_endswith(pat, dtype, null_value, na):
369380

370381

371382
@pytest.mark.parametrize("na", [None, True, False])
372-
def test_endswith_nullable_string_dtype(nullable_string_dtype, na):
383+
def test_endswith_string_dtype(any_string_dtype, na):
373384
values = Series(
374385
["om", None, "foo_nom", "nom", "bar_foo", None, "foo", "regex", "rege."],
375-
dtype=nullable_string_dtype,
386+
dtype=any_string_dtype,
376387
)
377388
result = values.str.endswith("foo", na=na)
389+
expected_dtype = (
390+
(object if na is None else bool)
391+
if is_object_or_nan_string_dtype(any_string_dtype)
392+
else "boolean"
393+
)
394+
if any_string_dtype == "str":
395+
# NaN propagates as False
396+
expected_dtype = bool
397+
if na is None:
398+
na = False
378399
exp = Series(
379-
[False, na, False, False, True, na, True, False, False], dtype="boolean"
400+
[False, na, False, False, True, na, True, False, False], dtype=expected_dtype
380401
)
381402
tm.assert_series_equal(result, exp)
382403

383404
result = values.str.endswith("rege.", na=na)
384405
exp = Series(
385-
[False, na, False, False, False, na, False, False, True], dtype="boolean"
406+
[False, na, False, False, False, na, False, False, True], dtype=expected_dtype
386407
)
387408
tm.assert_series_equal(result, exp)
388409

0 commit comments

Comments
 (0)