Skip to content

Commit 4afee55

Browse files
authored
BUG: fix pyarrow string regex replacement (#62283)
1 parent c99e52e commit 4afee55

File tree

4 files changed

+86
-3
lines changed

4 files changed

+86
-3
lines changed

doc/source/whatsnew/v2.3.3.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ become the default string dtype in pandas 3.0. See
2222

2323
Bug fixes
2424
^^^^^^^^^
25+
- Fix bug in :meth:`Series.str.replace` where using group references in the replacement string (e.g., ``\g<name>`` or ``\1``) with the Arrow-backed string dtype would raise an error (:issue:`57636`)
2526
- Fix regression in :meth:`~Series.str.contains`, :meth:`~Series.str.match` and :meth:`~Series.str.fullmatch`
2627
with a compiled regex and custom flags (:issue:`62240`)
2728

pandas/core/arrays/_arrow_string_mixins.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,10 +167,20 @@ def _str_replace(
167167
flags: int = 0,
168168
regex: bool = True,
169169
) -> Self:
170-
if isinstance(pat, re.Pattern) or callable(repl) or not case or flags:
170+
if (
171+
isinstance(pat, re.Pattern)
172+
or callable(repl)
173+
or not case
174+
or flags
175+
or (
176+
isinstance(repl, str)
177+
and (r"\g<" in repl or re.search(r"\\\d", repl) is not None)
178+
)
179+
):
171180
raise NotImplementedError(
172181
"replace is not supported with a re.Pattern, callable repl, "
173-
"case=False, or flags!=0"
182+
"case=False, flags!=0, or when the replacement string contains "
183+
"named group references (\\g<...>, \\d+)"
174184
)
175185

176186
func = pc.replace_substring_regex if regex else pc.replace_substring

pandas/core/arrays/string_arrow.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,17 @@ def _str_replace(
427427
flags: int = 0,
428428
regex: bool = True,
429429
):
430-
if isinstance(pat, re.Pattern) or callable(repl) or not case or flags:
430+
if (
431+
isinstance(pat, re.Pattern)
432+
or callable(repl)
433+
or not case
434+
or flags
435+
or ( # substitution contains a named group pattern
436+
# https://docs.python.org/3/library/re.html
437+
isinstance(repl, str)
438+
and (r"\g<" in repl or re.search(r"\\\d", repl) is not None)
439+
)
440+
):
431441
return super()._str_replace(pat, repl, n, case, flags, regex)
432442

433443
return ArrowStringArrayMixin._str_replace(

pandas/tests/strings/test_find_replace.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,68 @@ def test_replace_callable_raises(any_string_dtype, repl):
592592
values.str.replace("a", repl, regex=True)
593593

594594

595+
@pytest.mark.parametrize(
596+
"repl, expected_list",
597+
[
598+
(
599+
r"\g<three> \g<two> \g<one>",
600+
["Three Two One", "Baz Bar Foo"],
601+
),
602+
(
603+
r"\g<3> \g<2> \g<1>",
604+
["Three Two One", "Baz Bar Foo"],
605+
),
606+
(
607+
r"\g<2>0",
608+
["Two0", "Bar0"],
609+
),
610+
(
611+
r"\g<2>0 \1",
612+
["Two0 One", "Bar0 Foo"],
613+
),
614+
],
615+
ids=[
616+
"named_groups_full_swap",
617+
"numbered_groups_full_swap",
618+
"single_group_with_literal",
619+
"mixed_group_reference_with_literal",
620+
],
621+
)
622+
@pytest.mark.parametrize("use_compile", [True, False])
623+
def test_replace_named_groups_regex_swap(
624+
any_string_dtype, use_compile, repl, expected_list
625+
):
626+
# GH#57636
627+
ser = Series(["One Two Three", "Foo Bar Baz"], dtype=any_string_dtype)
628+
pattern = r"(?P<one>\w+) (?P<two>\w+) (?P<three>\w+)"
629+
if use_compile:
630+
pattern = re.compile(pattern)
631+
result = ser.str.replace(pattern, repl, regex=True)
632+
expected = Series(expected_list, dtype=any_string_dtype)
633+
tm.assert_series_equal(result, expected)
634+
635+
636+
@pytest.mark.parametrize(
637+
"repl",
638+
[
639+
r"\g<20>",
640+
r"\20",
641+
],
642+
)
643+
@pytest.mark.parametrize("use_compile", [True, False])
644+
def test_replace_named_groups_regex_swap_expected_fail(
645+
any_string_dtype, repl, use_compile
646+
):
647+
# GH#57636
648+
pattern = r"(?P<one>\w+) (?P<two>\w+) (?P<three>\w+)"
649+
if use_compile:
650+
pattern = re.compile(pattern)
651+
ser = Series(["One Two Three", "Foo Bar Baz"], dtype=any_string_dtype)
652+
653+
with pytest.raises(re.error, match="invalid group reference"):
654+
ser.str.replace(pattern, repl, regex=True)
655+
656+
595657
def test_replace_callable_named_groups(any_string_dtype):
596658
# test regex named groups
597659
ser = Series(["Foo Bar Baz", np.nan], dtype=any_string_dtype)

0 commit comments

Comments
 (0)