Skip to content

Commit f141c69

Browse files
authored
Fix annotations of str methods that accept regular expressions (pandas-dev#1278)
1 parent f340905 commit f141c69

File tree

2 files changed

+78
-10
lines changed

2 files changed

+78
-10
lines changed

pandas-stubs/core/strings.pyi

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,17 @@ class StringMethods(
8383
) -> _T_STR: ...
8484
@overload
8585
def split(
86-
self, pat: str = ..., *, n: int = ..., expand: Literal[True], regex: bool = ...
86+
self,
87+
pat: str | re.Pattern[str] = ...,
88+
*,
89+
n: int = ...,
90+
expand: Literal[True],
91+
regex: bool = ...,
8792
) -> _T_EXPANDING: ...
8893
@overload
8994
def split(
9095
self,
91-
pat: str = ...,
96+
pat: str | re.Pattern[str] = ...,
9297
*,
9398
n: int = ...,
9499
expand: Literal[False] = ...,
@@ -133,11 +138,15 @@ class StringMethods(
133138
regex: bool = ...,
134139
) -> _T_BOOL: ...
135140
def match(
136-
self, pat: str, case: bool = ..., flags: int = ..., na: Any = ...
141+
self,
142+
pat: str | re.Pattern[str],
143+
case: bool = ...,
144+
flags: int = ...,
145+
na: Any = ...,
137146
) -> _T_BOOL: ...
138147
def replace(
139148
self,
140-
pat: str,
149+
pat: str | re.Pattern[str],
141150
repl: str | Callable[[re.Match[str]], str],
142151
n: int = ...,
143152
case: bool | None = ...,
@@ -180,18 +189,26 @@ class StringMethods(
180189
def count(self, pat: str, flags: int = ...) -> _T_INT: ...
181190
def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _T_BOOL: ...
182191
def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _T_BOOL: ...
183-
def findall(self, pat: str, flags: int = ...) -> _T_LIST_STR: ...
192+
def findall(self, pat: str | re.Pattern[str], flags: int = ...) -> _T_LIST_STR: ...
184193
@overload
185194
def extract(
186-
self, pat: str, flags: int = ..., *, expand: Literal[True] = ...
195+
self,
196+
pat: str | re.Pattern[str],
197+
flags: int = ...,
198+
*,
199+
expand: Literal[True] = ...,
187200
) -> pd.DataFrame: ...
188201
@overload
189-
def extract(self, pat: str, flags: int, expand: Literal[False]) -> _T_OBJECT: ...
202+
def extract(
203+
self, pat: str | re.Pattern[str], flags: int, expand: Literal[False]
204+
) -> _T_OBJECT: ...
190205
@overload
191206
def extract(
192-
self, pat: str, flags: int = ..., *, expand: Literal[False]
207+
self, pat: str | re.Pattern[str], flags: int = ..., *, expand: Literal[False]
193208
) -> _T_OBJECT: ...
194-
def extractall(self, pat: str, flags: int = ...) -> pd.DataFrame: ...
209+
def extractall(
210+
self, pat: str | re.Pattern[str], flags: int = ...
211+
) -> pd.DataFrame: ...
195212
def find(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ...
196213
def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ...
197214
def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> _T_STR: ...
@@ -214,7 +231,11 @@ class StringMethods(
214231
def isnumeric(self) -> _T_BOOL: ...
215232
def isdecimal(self) -> _T_BOOL: ...
216233
def fullmatch(
217-
self, pat: str, case: bool = ..., flags: int = ..., na: Any = ...
234+
self,
235+
pat: str | re.Pattern[str],
236+
case: bool = ...,
237+
flags: int = ...,
238+
na: Any = ...,
218239
) -> _T_BOOL: ...
219240
def removeprefix(self, prefix: str) -> _T_STR: ...
220241
def removesuffix(self, suffix: str) -> _T_STR: ...

tests/test_string_accessors.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
import pandas as pd
6+
import pytest
67
from typing_extensions import assert_type
78

89
from tests import (
@@ -44,6 +45,7 @@ def test_string_accessors_boolean_series():
4445
_check(assert_type(s.str.endswith("e"), "pd.Series[bool]"))
4546
_check(assert_type(s.str.endswith(("e", "f")), "pd.Series[bool]"))
4647
_check(assert_type(s.str.fullmatch("apple"), "pd.Series[bool]"))
48+
_check(assert_type(s.str.fullmatch(re.compile(r"apple")), "pd.Series[bool]"))
4749
_check(assert_type(s.str.isalnum(), "pd.Series[bool]"))
4850
_check(assert_type(s.str.isalpha(), "pd.Series[bool]"))
4951
_check(assert_type(s.str.isdecimal(), "pd.Series[bool]"))
@@ -54,6 +56,7 @@ def test_string_accessors_boolean_series():
5456
_check(assert_type(s.str.istitle(), "pd.Series[bool]"))
5557
_check(assert_type(s.str.isupper(), "pd.Series[bool]"))
5658
_check(assert_type(s.str.match("pp"), "pd.Series[bool]"))
59+
_check(assert_type(s.str.match(re.compile(r"pp")), "pd.Series[bool]"))
5760

5861

5962
def test_string_accessors_boolean_index():
@@ -72,6 +75,7 @@ def test_string_accessors_boolean_index():
7275
_check(assert_type(idx.str.endswith("e"), np_ndarray_bool))
7376
_check(assert_type(idx.str.endswith(("e", "f")), np_ndarray_bool))
7477
_check(assert_type(idx.str.fullmatch("apple"), np_ndarray_bool))
78+
_check(assert_type(idx.str.fullmatch(re.compile(r"apple")), np_ndarray_bool))
7579
_check(assert_type(idx.str.isalnum(), np_ndarray_bool))
7680
_check(assert_type(idx.str.isalpha(), np_ndarray_bool))
7781
_check(assert_type(idx.str.isdecimal(), np_ndarray_bool))
@@ -82,6 +86,7 @@ def test_string_accessors_boolean_index():
8286
_check(assert_type(idx.str.istitle(), np_ndarray_bool))
8387
_check(assert_type(idx.str.isupper(), np_ndarray_bool))
8488
_check(assert_type(idx.str.match("pp"), np_ndarray_bool))
89+
_check(assert_type(idx.str.match(re.compile(r"pp")), np_ndarray_bool))
8590

8691

8792
def test_string_accessors_integer_series():
@@ -94,6 +99,10 @@ def test_string_accessors_integer_series():
9499
_check(assert_type(s.str.count("pp"), "pd.Series[int]"))
95100
_check(assert_type(s.str.len(), "pd.Series[int]"))
96101

102+
# unlike findall, find doesn't accept a compiled pattern
103+
with pytest.raises(TypeError):
104+
s.str.find(re.compile(r"p")) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
105+
97106

98107
def test_string_accessors_integer_index():
99108
idx = pd.Index(DATA)
@@ -105,6 +114,10 @@ def test_string_accessors_integer_index():
105114
_check(assert_type(idx.str.count("pp"), "pd.Index[int]"))
106115
_check(assert_type(idx.str.len(), "pd.Index[int]"))
107116

117+
# unlike findall, find doesn't accept a compiled pattern
118+
with pytest.raises(TypeError):
119+
idx.str.find(re.compile(r"p")) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
120+
108121

109122
def test_string_accessors_string_series():
110123
s = pd.Series(DATA)
@@ -123,6 +136,9 @@ def test_string_accessors_string_series():
123136
_check(assert_type(s.str.removesuffix("e"), "pd.Series[str]"))
124137
_check(assert_type(s.str.repeat(2), "pd.Series[str]"))
125138
_check(assert_type(s.str.replace("a", "X"), "pd.Series[str]"))
139+
_check(
140+
assert_type(s.str.replace(re.compile(r"a"), "X", regex=True), "pd.Series[str]")
141+
)
126142
_check(assert_type(s.str.rjust(80), "pd.Series[str]"))
127143
_check(assert_type(s.str.rstrip(), "pd.Series[str]"))
128144
_check(assert_type(s.str.slice_replace(0, 2, "XX"), "pd.Series[str]"))
@@ -158,6 +174,9 @@ def test_string_accessors_string_index():
158174
_check(assert_type(idx.str.removesuffix("e"), "pd.Index[str]"))
159175
_check(assert_type(idx.str.repeat(2), "pd.Index[str]"))
160176
_check(assert_type(idx.str.replace("a", "X"), "pd.Index[str]"))
177+
_check(
178+
assert_type(idx.str.replace(re.compile(r"a"), "X", regex=True), "pd.Index[str]")
179+
)
161180
_check(assert_type(idx.str.rjust(80), "pd.Index[str]"))
162181
_check(assert_type(idx.str.rstrip(), "pd.Index[str]"))
163182
_check(assert_type(idx.str.slice_replace(0, 2, "XX"), "pd.Index[str]"))
@@ -190,29 +209,49 @@ def test_string_accessors_list_series():
190209
s = pd.Series(DATA)
191210
_check = functools.partial(check, klass=pd.Series, dtype=list)
192211
_check(assert_type(s.str.findall("pp"), "pd.Series[list[str]]"))
212+
_check(assert_type(s.str.findall(re.compile(r"pp")), "pd.Series[list[str]]"))
193213
_check(assert_type(s.str.split("a"), "pd.Series[list[str]]"))
214+
_check(assert_type(s.str.split(re.compile(r"a")), "pd.Series[list[str]]"))
194215
# GH 194
195216
_check(assert_type(s.str.split("a", expand=False), "pd.Series[list[str]]"))
196217
_check(assert_type(s.str.rsplit("a"), "pd.Series[list[str]]"))
197218
_check(assert_type(s.str.rsplit("a", expand=False), "pd.Series[list[str]]"))
198219

220+
# rsplit doesn't accept compiled pattern
221+
# it doesn't raise at runtime but produces a nan
222+
bad_rsplit_result = s.str.rsplit(
223+
re.compile(r"a") # type: ignore[call-overload] # pyright: ignore[reportArgumentType]
224+
)
225+
assert bad_rsplit_result.isna().all()
226+
199227

200228
def test_string_accessors_list_index():
201229
idx = pd.Index(DATA)
202230
_check = functools.partial(check, klass=pd.Index, dtype=list)
203231
_check(assert_type(idx.str.findall("pp"), "pd.Index[list[str]]"))
232+
_check(assert_type(idx.str.findall(re.compile(r"pp")), "pd.Index[list[str]]"))
204233
_check(assert_type(idx.str.split("a"), "pd.Index[list[str]]"))
234+
_check(assert_type(idx.str.split(re.compile(r"a")), "pd.Index[list[str]]"))
205235
# GH 194
206236
_check(assert_type(idx.str.split("a", expand=False), "pd.Index[list[str]]"))
207237
_check(assert_type(idx.str.rsplit("a"), "pd.Index[list[str]]"))
208238
_check(assert_type(idx.str.rsplit("a", expand=False), "pd.Index[list[str]]"))
209239

240+
# rsplit doesn't accept compiled pattern
241+
# it doesn't raise at runtime but produces a nan
242+
bad_rsplit_result = idx.str.rsplit(
243+
re.compile(r"a") # type: ignore[call-overload] # pyright: ignore[reportArgumentType]
244+
)
245+
assert bad_rsplit_result.isna().all()
246+
210247

211248
def test_string_accessors_expanding_series():
212249
s = pd.Series(["a1", "b2", "c3"])
213250
_check = functools.partial(check, klass=pd.DataFrame)
214251
_check(assert_type(s.str.extract(r"([ab])?(\d)"), pd.DataFrame))
252+
_check(assert_type(s.str.extract(re.compile(r"([ab])?(\d)")), pd.DataFrame))
215253
_check(assert_type(s.str.extractall(r"([ab])?(\d)"), pd.DataFrame))
254+
_check(assert_type(s.str.extractall(re.compile(r"([ab])?(\d)")), pd.DataFrame))
216255
_check(assert_type(s.str.get_dummies(), pd.DataFrame))
217256
_check(assert_type(s.str.partition("p"), pd.DataFrame))
218257
_check(assert_type(s.str.rpartition("p"), pd.DataFrame))
@@ -231,7 +270,15 @@ def test_string_accessors_expanding_index():
231270

232271
# These ones are the odd ones out?
233272
check(assert_type(idx.str.extractall(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame)
273+
check(
274+
assert_type(idx.str.extractall(re.compile(r"([ab])?(\d)")), pd.DataFrame),
275+
pd.DataFrame,
276+
)
234277
check(assert_type(idx.str.extract(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame)
278+
check(
279+
assert_type(idx.str.extract(re.compile(r"([ab])?(\d)")), pd.DataFrame),
280+
pd.DataFrame,
281+
)
235282

236283

237284
def test_series_overloads_partition():

0 commit comments

Comments
 (0)