Skip to content

Commit 467ce1d

Browse files
committed
BUG (string): ArrowStringArray.find corner cases
1 parent db1b8ab commit 467ce1d

File tree

3 files changed

+8
-20
lines changed

3 files changed

+8
-20
lines changed

pandas/core/arrays/arrow/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2395,7 +2395,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
23952395
if sub == "":
23962396
# GH 56792
23972397
result = self._apply_elementwise(lambda val: val.find(sub, start, end))
2398-
return type(self)(pa.chunked_array(result))
2398+
return self._convert_int_result(pa.chunked_array(result))
23992399
if start is None:
24002400
start_offset = 0
24012401
start = 0
@@ -2409,7 +2409,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
24092409
found = pc.not_equal(result, pa.scalar(-1, type=result.type))
24102410
offset_result = pc.add(result, start_offset)
24112411
result = pc.if_else(found, offset_result, -1)
2412-
return type(self)(result)
2412+
return self._convert_int_result(result)
24132413

24142414
def _str_join(self, sep: str) -> Self:
24152415
if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string(

pandas/core/arrays/string_arrow.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ def astype(self, dtype, copy: bool = True):
285285
_str_startswith = ArrowStringArrayMixin._str_startswith
286286
_str_endswith = ArrowStringArrayMixin._str_endswith
287287
_str_pad = ArrowStringArrayMixin._str_pad
288+
_str_find = ArrowExtensionArray._str_find
288289

289290
def _str_contains(
290291
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
@@ -447,20 +448,6 @@ def _str_count(self, pat: str, flags: int = 0):
447448
result = pc.count_substring_regex(self._pa_array, pat)
448449
return self._convert_int_result(result)
449450

450-
def _str_find(self, sub: str, start: int = 0, end: int | None = None):
451-
if start != 0 and end is not None:
452-
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
453-
result = pc.find_substring(slices, sub)
454-
not_found = pc.equal(result, -1)
455-
offset_result = pc.add(result, end - start)
456-
result = pc.if_else(not_found, result, offset_result)
457-
elif start == 0 and end is None:
458-
slices = self._pa_array
459-
result = pc.find_substring(slices, sub)
460-
else:
461-
return super()._str_find(sub, start, end)
462-
return self._convert_int_result(result)
463-
464451
def _str_get_dummies(self, sep: str = "|"):
465452
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)
466453
if len(labels) == 0:

pandas/tests/extension/test_arrow.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@
3232
import numpy as np
3333
import pytest
3434

35-
from pandas._config import using_string_dtype
36-
3735
from pandas._libs import lib
3836
from pandas._libs.tslibs import timezones
3937
from pandas.compat import (
@@ -1978,7 +1976,6 @@ def test_str_find_large_start():
19781976
tm.assert_series_equal(result, expected)
19791977

19801978

1981-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
19821979
@pytest.mark.skipif(
19831980
pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311"
19841981
)
@@ -1990,11 +1987,15 @@ def test_str_find_e2e(start, end, sub):
19901987
["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""],
19911988
dtype=ArrowDtype(pa.string()),
19921989
)
1993-
object_series = s.astype(pd.StringDtype())
1990+
object_series = s.astype(pd.StringDtype(storage="python"))
19941991
result = s.str.find(sub, start, end)
19951992
expected = object_series.str.find(sub, start, end).astype(result.dtype)
19961993
tm.assert_series_equal(result, expected)
19971994

1995+
arrow_str_series = s.astype(pd.StringDtype(storage="pyarrow"))
1996+
result2 = arrow_str_series.str.find(sub, start, end).astype(result.dtype)
1997+
tm.assert_series_equal(result2, expected)
1998+
19981999

19992000
def test_str_find_negative_start_negative_end_no_match():
20002001
# GH 56791

0 commit comments

Comments
 (0)