Skip to content

Commit 7a99bdb

Browse files
committed
BUG (string): ArrowStringArray.find corner cases
1 parent 3a45265 commit 7a99bdb

File tree

3 files changed

+8
-20
lines changed

3 files changed

+8
-20
lines changed

pandas/core/arrays/arrow/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2380,7 +2380,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
23802380
if sub == "":
23812381
# GH 56792
23822382
result = self._apply_elementwise(lambda val: val.find(sub, start, end))
2383-
return type(self)(pa.chunked_array(result))
2383+
return self._convert_int_result(pa.chunked_array(result))
23842384
if start is None:
23852385
start_offset = 0
23862386
start = 0
@@ -2394,7 +2394,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
23942394
found = pc.not_equal(result, pa.scalar(-1, type=result.type))
23952395
offset_result = pc.add(result, start_offset)
23962396
result = pc.if_else(found, offset_result, -1)
2397-
return type(self)(result)
2397+
return self._convert_int_result(result)
23982398

23992399
def _str_join(self, sep: str) -> Self:
24002400
if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string(

pandas/core/arrays/string_arrow.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ def astype(self, dtype, copy: bool = True):
293293
_str_startswith = ArrowStringArrayMixin._str_startswith
294294
_str_endswith = ArrowStringArrayMixin._str_endswith
295295
_str_pad = ArrowStringArrayMixin._str_pad
296+
_str_find = ArrowExtensionArray._str_find
296297

297298
def _str_contains(
298299
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
@@ -415,20 +416,6 @@ def _str_count(self, pat: str, flags: int = 0):
415416
result = pc.count_substring_regex(self._pa_array, pat)
416417
return self._convert_int_result(result)
417418

418-
def _str_find(self, sub: str, start: int = 0, end: int | None = None):
419-
if start != 0 and end is not None:
420-
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
421-
result = pc.find_substring(slices, sub)
422-
not_found = pc.equal(result, -1)
423-
offset_result = pc.add(result, end - start)
424-
result = pc.if_else(not_found, result, offset_result)
425-
elif start == 0 and end is None:
426-
slices = self._pa_array
427-
result = pc.find_substring(slices, sub)
428-
else:
429-
return super()._str_find(sub, start, end)
430-
return self._convert_int_result(result)
431-
432419
def _str_get_dummies(self, sep: str = "|"):
433420
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)
434421
if len(labels) == 0:

pandas/tests/extension/test_arrow.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@
3232
import numpy as np
3333
import pytest
3434

35-
from pandas._config import using_string_dtype
36-
3735
from pandas._libs import lib
3836
from pandas._libs.tslibs import timezones
3937
from pandas.compat import (
@@ -1978,7 +1976,6 @@ def test_str_find_large_start():
19781976
tm.assert_series_equal(result, expected)
19791977

19801978

1981-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
19821979
@pytest.mark.skipif(
19831980
pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311"
19841981
)
@@ -1990,11 +1987,15 @@ def test_str_find_e2e(start, end, sub):
19901987
["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""],
19911988
dtype=ArrowDtype(pa.string()),
19921989
)
1993-
object_series = s.astype(pd.StringDtype())
1990+
object_series = s.astype(pd.StringDtype(storage="python"))
19941991
result = s.str.find(sub, start, end)
19951992
expected = object_series.str.find(sub, start, end).astype(result.dtype)
19961993
tm.assert_series_equal(result, expected)
19971994

1995+
arrow_str_series = s.astype(pd.StringDtype(storage="pyarrow"))
1996+
result2 = arrow_str_series.str.find(sub, start, end).astype(result.dtype)
1997+
tm.assert_series_equal(result2, expected)
1998+
19981999

19992000
def test_str_find_negative_start_negative_end_no_match():
20002001
# GH 56791

0 commit comments

Comments
 (0)