diff --git a/doc/source/whatsnew/v2.3.2.rst b/doc/source/whatsnew/v2.3.2.rst index 03244c808ad03..478bd945a3735 100644 --- a/doc/source/whatsnew/v2.3.2.rst +++ b/doc/source/whatsnew/v2.3.2.rst @@ -25,7 +25,8 @@ Bug fixes - Fix :meth:`~DataFrame.to_json` with ``orient="table"`` to correctly use the "string" type in the JSON Table Schema for :class:`StringDtype` columns (:issue:`61889`) - +- Fixed ``~Series.str.match`` and ``~Series.str.fullmatch`` with compiled regex + for the Arrow-backed string dtype (:issue:`61964`) .. --------------------------------------------------------------------------- .. _whatsnew_232.contributors: diff --git a/pandas/core/arrays/_arrow_string_mixins.py b/pandas/core/arrays/_arrow_string_mixins.py index e136b4f92031d..7e9b084330111 100644 --- a/pandas/core/arrays/_arrow_string_mixins.py +++ b/pandas/core/arrays/_arrow_string_mixins.py @@ -301,23 +301,29 @@ def _str_contains( def _str_match( self, - pat: str, + pat: str | re.Pattern, case: bool = True, flags: int = 0, na: Scalar | lib.NoDefault = lib.no_default, ): - if not pat.startswith("^"): + if isinstance(pat, re.Pattern): + # GH#61952 + pat = pat.pattern + if isinstance(pat, str) and not pat.startswith("^"): pat = f"^{pat}" return self._str_contains(pat, case, flags, na, regex=True) def _str_fullmatch( self, - pat, + pat: str | re.Pattern, case: bool = True, flags: int = 0, na: Scalar | lib.NoDefault = lib.no_default, ): - if not pat.endswith("$") or pat.endswith("\\$"): + if isinstance(pat, re.Pattern): + # GH#61952 + pat = pat.pattern + if isinstance(pat, str) and (not pat.endswith("$") or pat.endswith("\\$")): pat = f"{pat}$" return self._str_match(pat, case, flags, na) diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index c0e458f7968e7..da17543f8470d 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -1353,8 +1353,8 @@ def match(self, pat: str, case: bool = True, flags: int = 0, na=lib.no_default): Parameters ---------- - pat : str - Character sequence. + pat : str or compiled regex + Character sequence or regular expression. case : bool, default True If True, case sensitive. flags : int, default 0 (no flags) diff --git a/pandas/core/strings/object_array.py b/pandas/core/strings/object_array.py index e82c6c20e86d9..30696141d552b 100644 --- a/pandas/core/strings/object_array.py +++ b/pandas/core/strings/object_array.py @@ -245,14 +245,15 @@ def rep(x, r): def _str_match( self, - pat: str, + pat: str | re.Pattern, case: bool = True, flags: int = 0, na: Scalar | lib.NoDefault = lib.no_default, ): if not case: flags |= re.IGNORECASE - + if isinstance(pat, re.Pattern): + pat = pat.pattern regex = re.compile(pat, flags=flags) f = lambda x: regex.match(x) is not None @@ -267,7 +268,8 @@ def _str_fullmatch( ): if not case: flags |= re.IGNORECASE - + if isinstance(pat, re.Pattern): + pat = pat.pattern regex = re.compile(pat, flags=flags) f = lambda x: regex.fullmatch(x) is not None diff --git a/pandas/tests/strings/test_find_replace.py b/pandas/tests/strings/test_find_replace.py index dfa9a36995480..2bda4b8fba434 100644 --- a/pandas/tests/strings/test_find_replace.py +++ b/pandas/tests/strings/test_find_replace.py @@ -822,6 +822,17 @@ def test_match_case_kwarg(any_string_dtype): tm.assert_series_equal(result, expected) +def test_match_compiled_regex(any_string_dtype): + # GH#61952 + values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype) + result = values.str.match(re.compile(r"ab"), case=False) + expected_dtype = ( + np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean" + ) + expected = Series([True, True, True, True], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + # -------------------------------------------------------------------------------------- # str.fullmatch # -------------------------------------------------------------------------------------- @@ -891,6 +902,17 @@ def test_fullmatch_case_kwarg(any_string_dtype): tm.assert_series_equal(result, expected) +def test_fullmatch_compiled_regex(any_string_dtype): + # GH#61952 + values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype) + result = values.str.fullmatch(re.compile(r"ab"), case=False) + expected_dtype = ( + np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean" + ) + expected = Series([True, True, False, False], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + # -------------------------------------------------------------------------------------- # str.findall # --------------------------------------------------------------------------------------