Skip to content

Commit 3428b95

Browse files
committed
BUG (string): ArrowStringArray.find corner cases
1 parent: 220c18d; commit: 3428b95

File tree

3 files changed: 8 additions (+8) and 20 deletions (-20).

pandas/core/arrays/arrow/array.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2426,7 +2426,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
24262426
if sub == "":
24272427
# GH 56792
24282428
result = self._apply_elementwise(lambda val: val.find(sub, start, end))
2429-
return type(self)(pa.chunked_array(result))
2429+
return self._convert_int_result(pa.chunked_array(result))
24302430
if start is None:
24312431
start_offset = 0
24322432
start = 0
@@ -2440,7 +2440,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
24402440
found = pc.not_equal(result, pa.scalar(-1, type=result.type))
24412441
offset_result = pc.add(result, start_offset)
24422442
result = pc.if_else(found, offset_result, -1)
2443-
return type(self)(result)
2443+
return self._convert_int_result(result)
24442444

24452445
def _str_join(self, sep: str) -> Self:
24462446
if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string(

pandas/core/arrays/string_arrow.py

Lines changed: 1 addition & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -280,6 +280,7 @@ def astype(self, dtype, copy: bool = True):
280280
# String methods interface
281281

282282
_str_map = BaseStringArray._str_map
283+
_str_find = ArrowExtensionArray._str_find
283284

284285
def _str_contains(
285286
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
@@ -472,20 +473,6 @@ def _str_count(self, pat: str, flags: int = 0):
472473
result = pc.count_substring_regex(self._pa_array, pat)
473474
return self._convert_int_result(result)
474475

475-
def _str_find(self, sub: str, start: int = 0, end: int | None = None):
476-
if start != 0 and end is not None:
477-
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
478-
result = pc.find_substring(slices, sub)
479-
not_found = pc.equal(result, -1)
480-
offset_result = pc.add(result, end - start)
481-
result = pc.if_else(not_found, result, offset_result)
482-
elif start == 0 and end is None:
483-
slices = self._pa_array
484-
result = pc.find_substring(slices, sub)
485-
else:
486-
return super()._str_find(sub, start, end)
487-
return self._convert_int_result(result)
488-
489476
def _str_get_dummies(self, sep: str = "|"):
490477
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)
491478
if len(labels) == 0:

pandas/tests/extension/test_arrow.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -32,8 +32,6 @@
3232
import numpy as np
3333
import pytest
3434

35-
from pandas._config import using_string_dtype
36-
3735
from pandas._libs import lib
3836
from pandas._libs.tslibs import timezones
3937
from pandas.compat import (
@@ -1978,7 +1976,6 @@ def test_str_find_large_start():
19781976
tm.assert_series_equal(result, expected)
19791977

19801978

1981-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
19821979
@pytest.mark.skipif(
19831980
pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311"
19841981
)
@@ -1990,11 +1987,15 @@ def test_str_find_e2e(start, end, sub):
19901987
["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""],
19911988
dtype=ArrowDtype(pa.string()),
19921989
)
1993-
object_series = s.astype(pd.StringDtype())
1990+
object_series = s.astype(pd.StringDtype(storage="python"))
19941991
result = s.str.find(sub, start, end)
19951992
expected = object_series.str.find(sub, start, end).astype(result.dtype)
19961993
tm.assert_series_equal(result, expected)
19971994

1995+
arrow_str_series = s.astype(pd.StringDtype(storage="pyarrow"))
1996+
result2 = arrow_str_series.str.find(sub, start, end).astype(result.dtype)
1997+
tm.assert_series_equal(result2, expected)
1998+
19981999

19992000
def test_str_find_negative_start_negative_end_no_match():
20002001
# GH 56791

0 commit comments

Comments
 (0)