diff --git a/doc/source/whatsnew/v2.3.3.rst b/doc/source/whatsnew/v2.3.3.rst index e31ae4a8a647b..cbde6f52d4472 100644 --- a/doc/source/whatsnew/v2.3.3.rst +++ b/doc/source/whatsnew/v2.3.3.rst @@ -22,7 +22,8 @@ become the default string dtype in pandas 3.0. See Bug fixes ^^^^^^^^^ -- +- Fix regression in :meth:`Series.str.contains`, :meth:`Series.str.match` and :meth:`Series.str.fullmatch` + with a compiled regex and custom flags (:issue:`62240`) .. --------------------------------------------------------------------------- .. _whatsnew_233.contributors: diff --git a/pandas/core/arrays/_arrow_string_mixins.py b/pandas/core/arrays/_arrow_string_mixins.py index 7e9b084330111..90de41ffb63fa 100644 --- a/pandas/core/arrays/_arrow_string_mixins.py +++ b/pandas/core/arrays/_arrow_string_mixins.py @@ -301,29 +301,23 @@ def _str_contains( def _str_match( self, - pat: str | re.Pattern, + pat: str, case: bool = True, flags: int = 0, na: Scalar | lib.NoDefault = lib.no_default, ): - if isinstance(pat, re.Pattern): - # GH#61952 - pat = pat.pattern - if isinstance(pat, str) and not pat.startswith("^"): + if not pat.startswith("^"): pat = f"^{pat}" return self._str_contains(pat, case, flags, na, regex=True) def _str_fullmatch( self, - pat: str | re.Pattern, + pat: str, case: bool = True, flags: int = 0, na: Scalar | lib.NoDefault = lib.no_default, ): - if isinstance(pat, re.Pattern): - # GH#61952 - pat = pat.pattern - if isinstance(pat, str) and (not pat.endswith("$") or pat.endswith("\\$")): + if not pat.endswith("$") or pat.endswith("\\$"): pat = f"{pat}$" return self._str_match(pat, case, flags, na) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index feb5ff111c8c9..52177300cc3ea 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -238,7 +238,7 @@ def to_pyarrow_type( return None -class ArrowExtensionArray( +class ArrowExtensionArray( # type: ignore[misc] OpsMixin, ExtensionArraySupportsAnyAll, 
ArrowStringArrayMixin, diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index a777befe25dc0..15d001699064a 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -51,6 +51,7 @@ from pandas._typing import ( ArrayLike, Dtype, + Scalar, Self, npt, ) @@ -329,8 +330,6 @@ def _data(self): _str_startswith = ArrowStringArrayMixin._str_startswith _str_endswith = ArrowStringArrayMixin._str_endswith _str_pad = ArrowStringArrayMixin._str_pad - _str_match = ArrowStringArrayMixin._str_match - _str_fullmatch = ArrowStringArrayMixin._str_fullmatch _str_lower = ArrowStringArrayMixin._str_lower _str_upper = ArrowStringArrayMixin._str_upper _str_strip = ArrowStringArrayMixin._str_strip @@ -345,6 +344,28 @@ def _data(self): _str_len = ArrowStringArrayMixin._str_len _str_slice = ArrowStringArrayMixin._str_slice + @staticmethod + def _is_re_pattern_with_flags(pat: str | re.Pattern) -> bool: + # check if `pat` is a compiled regex pattern with flags that are not + # supported by pyarrow + return ( + isinstance(pat, re.Pattern) + and (pat.flags & ~(re.IGNORECASE | re.UNICODE)) != 0 + ) + + @staticmethod + def _preprocess_re_pattern(pat: re.Pattern, case: bool) -> tuple[str, bool, int]: + pattern = pat.pattern + flags = pat.flags + # flags is not supported by pyarrow, but `case` is -> extract and remove + if flags & re.IGNORECASE: + case = False + flags = flags & ~re.IGNORECASE + # when creating a pattern with re.compile and a string, it automatically + # gets a UNICODE flag, while pyarrow assumes unicode for strings anyway + flags = flags & ~re.UNICODE + return pattern, case, flags + def _str_contains( self, pat, @@ -353,13 +374,42 @@ def _str_contains( na=lib.no_default, regex: bool = True, ): - if flags: + if flags or self._is_re_pattern_with_flags(pat): return super()._str_contains(pat, case, flags, na, regex) if isinstance(pat, re.Pattern): - pat = pat.pattern + # TODO flags passed separately by user are ignored + 
pat, case, flags = self._preprocess_re_pattern(pat, case) return ArrowStringArrayMixin._str_contains(self, pat, case, flags, na, regex) + def _str_match( + self, + pat: str | re.Pattern, + case: bool = True, + flags: int = 0, + na: Scalar | lib.NoDefault = lib.no_default, + ): + if flags or self._is_re_pattern_with_flags(pat): + return super()._str_match(pat, case, flags, na) + if isinstance(pat, re.Pattern): + pat, case, flags = self._preprocess_re_pattern(pat, case) + + return ArrowStringArrayMixin._str_match(self, pat, case, flags, na) + + def _str_fullmatch( + self, + pat: str | re.Pattern, + case: bool = True, + flags: int = 0, + na: Scalar | lib.NoDefault = lib.no_default, + ): + if flags or self._is_re_pattern_with_flags(pat): + return super()._str_fullmatch(pat, case, flags, na) + if isinstance(pat, re.Pattern): + pat, case, flags = self._preprocess_re_pattern(pat, case) + + return ArrowStringArrayMixin._str_fullmatch(self, pat, case, flags, na) + def _str_replace( self, pat: str | re.Pattern, diff --git a/pandas/core/strings/object_array.py b/pandas/core/strings/object_array.py index 30696141d552b..05ffee662dbac 100644 --- a/pandas/core/strings/object_array.py +++ b/pandas/core/strings/object_array.py @@ -252,8 +252,7 @@ def _str_match( ): if not case: flags |= re.IGNORECASE - if isinstance(pat, re.Pattern): - pat = pat.pattern + regex = re.compile(pat, flags=flags) f = lambda x: regex.match(x) is not None @@ -268,8 +267,7 @@ def _str_fullmatch( ): if not case: flags |= re.IGNORECASE - if isinstance(pat, re.Pattern): - pat = pat.pattern + regex = re.compile(pat, flags=flags) f = lambda x: regex.fullmatch(x) is not None diff --git a/pandas/tests/strings/test_find_replace.py b/pandas/tests/strings/test_find_replace.py index 48ac3250b6060..3f57754af6e79 100644 --- a/pandas/tests/strings/test_find_replace.py +++ b/pandas/tests/strings/test_find_replace.py @@ -292,13 +292,60 @@ def test_contains_nan(any_string_dtype): def 
test_contains_compiled_regex(any_string_dtype): # GH#61942 - ser = Series(["foo", "bar", "baz"], dtype=any_string_dtype) + expected_dtype = ( + np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean" + ) + + ser = Series(["foo", "bar", "Baz"], dtype=any_string_dtype) + pat = re.compile("ba.") result = ser.str.contains(pat) + expected = Series([False, True, False], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + # TODO this currently works for pyarrow-backed dtypes but raises for python + if any_string_dtype == "string" and any_string_dtype.storage == "pyarrow": + result = ser.str.contains(pat, case=False) + expected = Series([False, True, True], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + else: + with pytest.raises( + ValueError, match="cannot process flags argument with a compiled pattern" + ): + ser.str.contains(pat, case=False) + + pat = re.compile("ba.", flags=re.IGNORECASE) + result = ser.str.contains(pat) + expected = Series([False, True, True], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + # TODO should this be supported? 
+ with pytest.raises( + ValueError, match="cannot process flags argument with a compiled pattern" + ): + ser.str.contains(pat, flags=re.IGNORECASE) + + +def test_contains_compiled_regex_flags(any_string_dtype): + # ensure other (than ignorecase) flags are respected expected_dtype = ( np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean" ) + + ser = Series(["foobar", "foo\nbar", "Baz"], dtype=any_string_dtype) + + pat = re.compile("^ba") + result = ser.str.contains(pat) + expected = Series([False, False, False], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + pat = re.compile("^ba", flags=re.MULTILINE) + result = ser.str.contains(pat) + expected = Series([False, True, False], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + pat = re.compile("^ba", flags=re.MULTILINE | re.IGNORECASE) + result = ser.str.contains(pat) expected = Series([False, True, True], dtype=expected_dtype) tm.assert_series_equal(result, expected) @@ -837,14 +884,36 @@ def test_match_case_kwarg(any_string_dtype): def test_match_compiled_regex(any_string_dtype): # GH#61952 - values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype) - result = values.str.match(re.compile(r"ab"), case=False) expected_dtype = ( np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean" ) + + values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype) + + result = values.str.match(re.compile("ab")) + expected = Series([True, False, True, False], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + # TODO this currently works for pyarrow-backed dtypes but raises for python + if any_string_dtype == "string" and any_string_dtype.storage == "pyarrow": + result = values.str.match(re.compile("ab"), case=False) + expected = Series([True, True, True, True], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + else: + with pytest.raises( + ValueError, match="cannot process flags argument with a compiled 
pattern" + ): + values.str.match(re.compile("ab"), case=False) + + result = values.str.match(re.compile("ab", flags=re.IGNORECASE)) expected = Series([True, True, True, True], dtype=expected_dtype) tm.assert_series_equal(result, expected) + with pytest.raises( + ValueError, match="cannot process flags argument with a compiled pattern" + ): + values.str.match(re.compile("ab"), flags=re.IGNORECASE) + # -------------------------------------------------------------------------------------- # str.fullmatch @@ -917,14 +986,36 @@ def test_fullmatch_case_kwarg(any_string_dtype): def test_fullmatch_compiled_regex(any_string_dtype): # GH#61952 - values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype) - result = values.str.fullmatch(re.compile(r"ab"), case=False) expected_dtype = ( np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean" ) + + values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype) + + result = values.str.fullmatch(re.compile("ab")) + expected = Series([True, False, False, False], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + # TODO this currently works for pyarrow-backed dtypes but raises for python + if any_string_dtype == "string" and any_string_dtype.storage == "pyarrow": + result = values.str.fullmatch(re.compile("ab"), case=False) + expected = Series([True, True, False, False], dtype=expected_dtype) + tm.assert_series_equal(result, expected) + else: + with pytest.raises( + ValueError, match="cannot process flags argument with a compiled pattern" + ): + values.str.fullmatch(re.compile("ab"), case=False) + + result = values.str.fullmatch(re.compile("ab", flags=re.IGNORECASE)) expected = Series([True, True, False, False], dtype=expected_dtype) tm.assert_series_equal(result, expected) + with pytest.raises( + ValueError, match="cannot process flags argument with a compiled pattern" + ): + values.str.fullmatch(re.compile("ab"), flags=re.IGNORECASE) + # 
-------------------------------------------------------------------------------------- # str.findall