Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/source/whatsnew/v2.3.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ become the default string dtype in pandas 3.0. See

Bug fixes
^^^^^^^^^
-
- Fix regression in :meth:`~Series.str.contains`, :meth:`~Series.str.match` and :meth:`~Series.str.fullmatch`
  with a compiled regex and custom flags (:issue:`62240`)

.. ---------------------------------------------------------------------------
.. _whatsnew_233.contributors:
Expand Down
14 changes: 4 additions & 10 deletions pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,29 +301,23 @@ def _str_contains(

def _str_match(
self,
pat: str | re.Pattern,
pat: str,
case: bool = True,
flags: int = 0,
na: Scalar | lib.NoDefault = lib.no_default,
):
if isinstance(pat, re.Pattern):
# GH#61952
pat = pat.pattern
if isinstance(pat, str) and not pat.startswith("^"):
if not pat.startswith("^"):
pat = f"^{pat}"
return self._str_contains(pat, case, flags, na, regex=True)

def _str_fullmatch(
self,
pat: str | re.Pattern,
pat: str,
case: bool = True,
flags: int = 0,
na: Scalar | lib.NoDefault = lib.no_default,
):
if isinstance(pat, re.Pattern):
# GH#61952
pat = pat.pattern
if isinstance(pat, str) and (not pat.endswith("$") or pat.endswith("\\$")):
if not pat.endswith("$") or pat.endswith("\\$"):
pat = f"{pat}$"
return self._str_match(pat, case, flags, na)

Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def to_pyarrow_type(
return None


class ArrowExtensionArray(
class ArrowExtensionArray( # type: ignore[misc]
OpsMixin,
ExtensionArraySupportsAnyAll,
ArrowStringArrayMixin,
Expand Down
58 changes: 54 additions & 4 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from pandas._typing import (
ArrayLike,
Dtype,
Scalar,
Self,
npt,
)
Expand Down Expand Up @@ -329,8 +330,6 @@ def _data(self):
_str_startswith = ArrowStringArrayMixin._str_startswith
_str_endswith = ArrowStringArrayMixin._str_endswith
_str_pad = ArrowStringArrayMixin._str_pad
_str_match = ArrowStringArrayMixin._str_match
_str_fullmatch = ArrowStringArrayMixin._str_fullmatch
_str_lower = ArrowStringArrayMixin._str_lower
_str_upper = ArrowStringArrayMixin._str_upper
_str_strip = ArrowStringArrayMixin._str_strip
Expand All @@ -345,6 +344,28 @@ def _data(self):
_str_len = ArrowStringArrayMixin._str_len
_str_slice = ArrowStringArrayMixin._str_slice

@staticmethod
def _is_re_pattern_with_flags(pat: str | re.Pattern) -> bool:
    """
    Check whether ``pat`` is a compiled regex carrying unsupported flags.

    Returns True only for an ``re.Pattern`` that has at least one flag set
    besides ``re.IGNORECASE`` and ``re.UNICODE``; anything that is not a
    compiled pattern (e.g. a plain string) gives False.
    """
    if not isinstance(pat, re.Pattern):
        return False
    # any flag other than these two is not supported by pyarrow
    extra_flags = pat.flags & ~(re.IGNORECASE | re.UNICODE)
    return extra_flags != 0

@staticmethod
def _preprocess_re_pattern(pat: re.Pattern, case: bool) -> tuple[str, bool, int]:
    """
    Split a compiled regex into ``(pattern, case, flags)``.

    ``re.IGNORECASE`` is translated into ``case=False`` and dropped from the
    returned flags; ``re.UNICODE`` is dropped as well. Any other flags are
    returned unchanged.
    """
    # flags are not supported by pyarrow, but `case` is -> an IGNORECASE flag
    # on the pattern overrides the `case` argument
    effective_case = False if pat.flags & re.IGNORECASE else case
    # re.compile on a str pattern always adds UNICODE, while pyarrow assumes
    # unicode for strings anyway, so that flag carries no information here
    remaining_flags = pat.flags & ~re.IGNORECASE & ~re.UNICODE
    return pat.pattern, effective_case, remaining_flags

def _str_contains(
self,
pat,
Expand All @@ -353,13 +374,42 @@ def _str_contains(
na=lib.no_default,
regex: bool = True,
):
if flags:
if flags or self._is_re_pattern_with_flags(pat):
return super()._str_contains(pat, case, flags, na, regex)
if isinstance(pat, re.Pattern):
pat = pat.pattern
# TODO flags passed separately by user are ignored
pat, case, flags = self._preprocess_re_pattern(pat, case)

return ArrowStringArrayMixin._str_contains(self, pat, case, flags, na, regex)

def _str_match(
    self,
    pat: str | re.Pattern,
    case: bool = True,
    flags: int = 0,
    na: Scalar | lib.NoDefault = lib.no_default,
):
    """
    Dispatch ``str.match`` to the pyarrow-based mixin or the parent class.

    The pyarrow path is used when no ``flags`` are passed and ``pat`` is not
    a compiled pattern with flags unsupported by pyarrow; otherwise the call
    is forwarded unchanged to the parent-class implementation.
    """
    if not flags and not self._is_re_pattern_with_flags(pat):
        if isinstance(pat, re.Pattern):
            # unpack the compiled pattern into a string pattern plus `case`
            pat, case, flags = self._preprocess_re_pattern(pat, case)
        return ArrowStringArrayMixin._str_match(self, pat, case, flags, na)

    return super()._str_match(pat, case, flags, na)

def _str_fullmatch(
    self,
    pat: str | re.Pattern,
    case: bool = True,
    flags: int = 0,
    na: Scalar | lib.NoDefault = lib.no_default,
):
    """
    Dispatch ``str.fullmatch`` to the pyarrow-based mixin or the parent class.

    Explicit ``flags``, or a compiled pattern with flags unsupported by
    pyarrow, are forwarded unchanged to the parent-class implementation.
    Everything else goes through ``ArrowStringArrayMixin._str_fullmatch``.
    """
    use_fallback = bool(flags) or self._is_re_pattern_with_flags(pat)
    if use_fallback:
        return super()._str_fullmatch(pat, case, flags, na)

    if isinstance(pat, re.Pattern):
        # unpack the compiled pattern into a string pattern plus `case`
        pat, case, flags = self._preprocess_re_pattern(pat, case)
    return ArrowStringArrayMixin._str_fullmatch(self, pat, case, flags, na)

def _str_replace(
self,
pat: str | re.Pattern,
Expand Down
6 changes: 2 additions & 4 deletions pandas/core/strings/object_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,7 @@ def _str_match(
):
if not case:
flags |= re.IGNORECASE
if isinstance(pat, re.Pattern):
pat = pat.pattern

regex = re.compile(pat, flags=flags)

f = lambda x: regex.match(x) is not None
Expand All @@ -268,8 +267,7 @@ def _str_fullmatch(
):
if not case:
flags |= re.IGNORECASE
if isinstance(pat, re.Pattern):
pat = pat.pattern

regex = re.compile(pat, flags=flags)

f = lambda x: regex.fullmatch(x) is not None
Expand Down
101 changes: 96 additions & 5 deletions pandas/tests/strings/test_find_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,13 +292,60 @@ def test_contains_nan(any_string_dtype):

def test_contains_compiled_regex(any_string_dtype):
# GH#61942
ser = Series(["foo", "bar", "baz"], dtype=any_string_dtype)
expected_dtype = (
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
)

ser = Series(["foo", "bar", "Baz"], dtype=any_string_dtype)

pat = re.compile("ba.")
result = ser.str.contains(pat)
expected = Series([False, True, False], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

# TODO this currently works for pyarrow-backed dtypes but raises for python
if any_string_dtype == "string" and any_string_dtype.storage == "pyarrow":
result = ser.str.contains(pat, case=False)
expected = Series([False, True, True], dtype=expected_dtype)
tm.assert_series_equal(result, expected)
else:
with pytest.raises(
ValueError, match="cannot process flags argument with a compiled pattern"
):
ser.str.contains(pat, case=False)

pat = re.compile("ba.", flags=re.IGNORECASE)
result = ser.str.contains(pat)
expected = Series([False, True, True], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

# TODO should this be supported?
with pytest.raises(
ValueError, match="cannot process flags argument with a compiled pattern"
):
ser.str.contains(pat, flags=re.IGNORECASE)


def test_contains_compiled_regex_flags(any_string_dtype):
    # flags other than IGNORECASE on a compiled pattern must be respected
    if is_object_or_nan_string_dtype(any_string_dtype):
        expected_dtype = np.bool_
    else:
        expected_dtype = "boolean"

    ser = Series(["foobar", "foo\nbar", "Baz"], dtype=any_string_dtype)

    # (compiled pattern, expected result per element of `ser`)
    cases = [
        (re.compile("^ba"), [False, False, False]),
        (re.compile("^ba", flags=re.MULTILINE), [False, True, False]),
        (re.compile("^ba", flags=re.MULTILINE | re.IGNORECASE), [False, True, True]),
    ]
    for pat, expected_values in cases:
        result = ser.str.contains(pat)
        expected = Series(expected_values, dtype=expected_dtype)
        tm.assert_series_equal(result, expected)

Expand Down Expand Up @@ -837,14 +884,36 @@ def test_match_case_kwarg(any_string_dtype):

def test_match_compiled_regex(any_string_dtype):
# GH#61952
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
result = values.str.match(re.compile(r"ab"), case=False)
expected_dtype = (
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
)

values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)

result = values.str.match(re.compile("ab"))
expected = Series([True, False, True, False], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

# TODO this currently works for pyarrow-backed dtypes but raises for python
if any_string_dtype == "string" and any_string_dtype.storage == "pyarrow":
result = values.str.match(re.compile("ab"), case=False)
expected = Series([True, True, True, True], dtype=expected_dtype)
tm.assert_series_equal(result, expected)
else:
with pytest.raises(
ValueError, match="cannot process flags argument with a compiled pattern"
):
values.str.match(re.compile("ab"), case=False)

result = values.str.match(re.compile("ab", flags=re.IGNORECASE))
expected = Series([True, True, True, True], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

with pytest.raises(
ValueError, match="cannot process flags argument with a compiled pattern"
):
values.str.match(re.compile("ab"), flags=re.IGNORECASE)


# --------------------------------------------------------------------------------------
# str.fullmatch
Expand Down Expand Up @@ -917,14 +986,36 @@ def test_fullmatch_case_kwarg(any_string_dtype):

def test_fullmatch_compiled_regex(any_string_dtype):
# GH#61952
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
result = values.str.fullmatch(re.compile(r"ab"), case=False)
expected_dtype = (
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
)

values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)

result = values.str.fullmatch(re.compile("ab"))
expected = Series([True, False, False, False], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

# TODO this currently works for pyarrow-backed dtypes but raises for python
if any_string_dtype == "string" and any_string_dtype.storage == "pyarrow":
result = values.str.fullmatch(re.compile("ab"), case=False)
expected = Series([True, True, False, False], dtype=expected_dtype)
tm.assert_series_equal(result, expected)
else:
with pytest.raises(
ValueError, match="cannot process flags argument with a compiled pattern"
):
values.str.fullmatch(re.compile("ab"), case=False)

result = values.str.fullmatch(re.compile("ab", flags=re.IGNORECASE))
expected = Series([True, True, False, False], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

with pytest.raises(
ValueError, match="cannot process flags argument with a compiled pattern"
):
values.str.fullmatch(re.compile("ab"), flags=re.IGNORECASE)


# --------------------------------------------------------------------------------------
# str.findall
Expand Down
Loading