Skip to content

Commit 66c2292

Browse files
committed
Add defaults to str accessor methods
1 parent 9a389ec commit 66c2292

File tree

2 files changed

+144
-101
lines changed

2 files changed

+144
-101
lines changed

pandas-stubs/core/strings/accessor.pyi

Lines changed: 94 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
from builtins import slice as _slice
33
from collections.abc import (
44
Callable,
5+
Hashable,
6+
Mapping,
57
Sequence,
68
)
79
import re
810
from typing import (
9-
Any,
1011
Generic,
1112
Literal,
1213
TypeVar,
@@ -27,6 +28,7 @@ from pandas.core.base import NoNewAttributesMixin
2728
from pandas._libs.tslibs.nattype import NaTType
2829
from pandas._typing import (
2930
AlignJoin,
31+
DtypeObj,
3032
Scalar,
3133
T,
3234
np_ndarray_bool,
@@ -57,163 +59,163 @@ class StringMethods(
5759
@overload
5860
def cat(
5961
self,
60-
*,
61-
sep: str,
62-
na_rep: str | None = ...,
63-
join: AlignJoin = ...,
64-
) -> str: ...
65-
@overload
66-
def cat(
67-
self,
68-
others: Literal[None] = ...,
69-
*,
70-
sep: str,
71-
na_rep: str | None = ...,
72-
join: AlignJoin = ...,
62+
others: None = None,
63+
sep: str | None = None,
64+
na_rep: str | None = None,
65+
join: AlignJoin = "left",
7366
) -> str: ...
7467
@overload
7568
def cat(
7669
self,
7770
others: (
7871
Series[str] | Index[str] | pd.DataFrame | npt.NDArray[np.str_] | list[str]
7972
),
80-
sep: str = ...,
81-
na_rep: str | None = ...,
82-
join: AlignJoin = ...,
73+
sep: str | None = None,
74+
na_rep: str | None = None,
75+
join: AlignJoin = "left",
8376
) -> _T_STR: ...
8477
@overload
8578
def split(
8679
self,
87-
pat: str | re.Pattern[str] = ...,
80+
pat: str | re.Pattern[str] | None = None,
8881
*,
89-
n: int = ...,
82+
n: int = -1,
9083
expand: Literal[True],
91-
regex: bool = ...,
84+
regex: bool | None = None,
9285
) -> _T_EXPANDING: ...
9386
@overload
9487
def split(
9588
self,
96-
pat: str | re.Pattern[str] = ...,
89+
pat: str | re.Pattern[str] | None = None,
9790
*,
98-
n: int = ...,
99-
expand: Literal[False] = ...,
100-
regex: bool = ...,
91+
n: int = -1,
92+
expand: Literal[False] = False,
93+
regex: bool | None = None,
10194
) -> _T_LIST_STR: ...
10295
@overload
10396
def rsplit(
104-
self, pat: str = ..., *, n: int = ..., expand: Literal[True]
97+
self, pat: str | None = None, *, n: int = -1, expand: Literal[True]
10598
) -> _T_EXPANDING: ...
10699
@overload
107100
def rsplit(
108-
self, pat: str = ..., *, n: int = ..., expand: Literal[False] = ...
101+
self, pat: str | None = None, *, n: int = -1, expand: Literal[False] = False
109102
) -> _T_LIST_STR: ...
110-
@overload
111-
def partition(self, sep: str = ...) -> _T_EXPANDING: ...
112-
@overload
113-
def partition(self, *, expand: Literal[True]) -> _T_EXPANDING: ...
114-
@overload
115-
def partition(self, sep: str, expand: Literal[True]) -> _T_EXPANDING: ...
116-
@overload
103+
@overload # expand=True
104+
def partition(
105+
self, sep: str = " ", expand: Literal[True] = True
106+
) -> _T_EXPANDING: ...
107+
@overload # expand=False (positional argument)
117108
def partition(self, sep: str, expand: Literal[False]) -> _T_OBJECT: ...
118-
@overload
119-
def partition(self, *, expand: Literal[False]) -> _T_OBJECT: ...
120-
@overload
121-
def rpartition(self, sep: str = ...) -> _T_EXPANDING: ...
122-
@overload
123-
def rpartition(self, *, expand: Literal[True]) -> _T_EXPANDING: ...
124-
@overload
125-
def rpartition(self, sep: str, expand: Literal[True]) -> _T_EXPANDING: ...
126-
@overload
109+
@overload # expand=False (keyword argument)
110+
def partition(self, sep: str = " ", *, expand: Literal[False]) -> _T_OBJECT: ...
111+
@overload # expand=True
112+
def rpartition(
113+
self, sep: str = " ", expand: Literal[True] = True
114+
) -> _T_EXPANDING: ...
115+
@overload # expand=False (positional argument)
127116
def rpartition(self, sep: str, expand: Literal[False]) -> _T_OBJECT: ...
128-
@overload
129-
def rpartition(self, *, expand: Literal[False]) -> _T_OBJECT: ...
130-
def get(self, i: int) -> _T_STR: ...
117+
@overload # expand=False (keyword argument)
118+
def rpartition(self, sep: str = " ", *, expand: Literal[False]) -> _T_OBJECT: ...
119+
def get(self, i: int | Hashable) -> _T_STR: ...
131120
def join(self, sep: str) -> _T_STR: ...
132121
def contains(
133122
self,
134123
pat: str | re.Pattern[str],
135-
case: bool = ...,
136-
flags: int = ...,
124+
case: bool = True,
125+
flags: int = 0,
137126
na: Scalar | NaTType | None = ...,
138-
regex: bool = ...,
127+
regex: bool = True,
139128
) -> _T_BOOL: ...
140129
def match(
141130
self,
142131
pat: str | re.Pattern[str],
143-
case: bool = ...,
144-
flags: int = ...,
145-
na: Any = ...,
132+
case: bool = True,
133+
flags: int = 0,
134+
na: Scalar | NaTType | None = ...,
135+
) -> _T_BOOL: ...
136+
def fullmatch(
137+
self,
138+
pat: str | re.Pattern[str],
139+
case: bool = True,
140+
flags: int = 0,
141+
na: Scalar | NaTType | None = ...,
146142
) -> _T_BOOL: ...
147143
def replace(
148144
self,
149145
pat: str | re.Pattern[str],
150146
repl: str | Callable[[re.Match[str]], str],
151-
n: int = ...,
152-
case: bool | None = ...,
153-
flags: int = ...,
154-
regex: bool = ...,
147+
n: int = -1,
148+
case: bool | None = None,
149+
flags: int = 0,
150+
regex: bool = False,
155151
) -> _T_STR: ...
156152
def repeat(self, repeats: int | Sequence[int]) -> _T_STR: ...
157153
def pad(
158154
self,
159155
width: int,
160-
side: Literal["left", "right", "both"] = ...,
161-
fillchar: str = ...,
156+
side: Literal["left", "right", "both"] = "left",
157+
fillchar: str = " ",
162158
) -> _T_STR: ...
163-
def center(self, width: int, fillchar: str = ...) -> _T_STR: ...
164-
def ljust(self, width: int, fillchar: str = ...) -> _T_STR: ...
165-
def rjust(self, width: int, fillchar: str = ...) -> _T_STR: ...
159+
def center(self, width: int, fillchar: str = " ") -> _T_STR: ...
160+
def ljust(self, width: int, fillchar: str = " ") -> _T_STR: ...
161+
def rjust(self, width: int, fillchar: str = " ") -> _T_STR: ...
166162
def zfill(self, width: int) -> _T_STR: ...
167163
def slice(
168-
self, start: int | None = ..., stop: int | None = ..., step: int | None = ...
164+
self, start: int | None = None, stop: int | None = None, step: int | None = None
169165
) -> T: ...
170166
def slice_replace(
171-
self, start: int | None = ..., stop: int | None = ..., repl: str | None = ...
167+
self, start: int | None = None, stop: int | None = None, repl: str | None = None
172168
) -> _T_STR: ...
173-
def decode(self, encoding: str, errors: str = ...) -> _T_STR: ...
174-
def encode(self, encoding: str, errors: str = ...) -> _T_BYTES: ...
175-
def strip(self, to_strip: str | None = ...) -> _T_STR: ...
176-
def lstrip(self, to_strip: str | None = ...) -> _T_STR: ...
177-
def rstrip(self, to_strip: str | None = ...) -> _T_STR: ...
169+
def decode(
170+
self, encoding: str, errors: str = "strict", dtype: str | DtypeObj | None = None
171+
) -> _T_STR: ...
172+
def encode(self, encoding: str, errors: str = "strict") -> _T_BYTES: ...
173+
def strip(self, to_strip: str | None = None) -> _T_STR: ...
174+
def lstrip(self, to_strip: str | None = None) -> _T_STR: ...
175+
def rstrip(self, to_strip: str | None = None) -> _T_STR: ...
176+
def removeprefix(self, prefix: str) -> _T_STR: ...
177+
def removesuffix(self, suffix: str) -> _T_STR: ...
178178
def wrap(
179179
self,
180180
width: int,
181-
expand_tabs: bool | None = ...,
182-
replace_whitespace: bool | None = ...,
183-
drop_whitespace: bool | None = ...,
184-
break_long_words: bool | None = ...,
185-
break_on_hyphens: bool | None = ...,
181+
*,
182+
# kwargs passed to textwrap.TextWrapper
183+
expand_tabs: bool = True,
184+
replace_whitespace: bool = True,
185+
drop_whitespace: bool = True,
186+
break_long_words: bool = True,
187+
break_on_hyphens: bool = True,
186188
) -> _T_STR: ...
187-
def get_dummies(self, sep: str = ...) -> _T_EXPANDING: ...
188-
def translate(self, table: dict[int, int | str | None] | None) -> _T_STR: ...
189-
def count(self, pat: str, flags: int = ...) -> _T_INT: ...
190-
def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _T_BOOL: ...
191-
def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _T_BOOL: ...
192-
def findall(self, pat: str | re.Pattern[str], flags: int = ...) -> _T_LIST_STR: ...
193-
@overload
189+
def get_dummies(self, sep: str = "|") -> _T_EXPANDING: ...
190+
def translate(self, table: Mapping[int, int | str | None] | None) -> _T_STR: ...
191+
def count(self, pat: str, flags: int = 0) -> _T_INT: ...
192+
def startswith(
193+
self, pat: str | tuple[str, ...], na: Scalar | NaTType | None = ...
194+
) -> _T_BOOL: ...
195+
def endswith(
196+
self, pat: str | tuple[str, ...], na: Scalar | NaTType | None = ...
197+
) -> _T_BOOL: ...
198+
def findall(self, pat: str | re.Pattern[str], flags: int = 0) -> _T_LIST_STR: ...
199+
@overload # expand=True
194200
def extract(
195-
self,
196-
pat: str | re.Pattern[str],
197-
flags: int = ...,
198-
*,
199-
expand: Literal[True] = ...,
201+
self, pat: str | re.Pattern[str], flags: int = 0, expand: Literal[True] = True
200202
) -> pd.DataFrame: ...
201-
@overload
203+
@overload # expand=False (positional argument)
202204
def extract(
203205
self, pat: str | re.Pattern[str], flags: int, expand: Literal[False]
204206
) -> _T_OBJECT: ...
205-
@overload
207+
@overload # expand=False (keyword argument)
206208
def extract(
207-
self, pat: str | re.Pattern[str], flags: int = ..., *, expand: Literal[False]
209+
self, pat: str | re.Pattern[str], flags: int = 0, *, expand: Literal[False]
208210
) -> _T_OBJECT: ...
209211
def extractall(
210-
self, pat: str | re.Pattern[str], flags: int = ...
212+
self, pat: str | re.Pattern[str], flags: int = 0
211213
) -> pd.DataFrame: ...
212-
def find(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ...
213-
def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ...
214+
def find(self, sub: str, start: int = 0, end: int | None = None) -> _T_INT: ...
215+
def rfind(self, sub: str, start: int = 0, end: int | None = None) -> _T_INT: ...
214216
def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> _T_STR: ...
215-
def index(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ...
216-
def rindex(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ...
217+
def index(self, sub: str, start: int = 0, end: int | None = None) -> _T_INT: ...
218+
def rindex(self, sub: str, start: int = 0, end: int | None = None) -> _T_INT: ...
217219
def len(self) -> _T_INT: ...
218220
def lower(self) -> _T_STR: ...
219221
def upper(self) -> _T_STR: ...
@@ -230,12 +232,3 @@ class StringMethods(
230232
def istitle(self) -> _T_BOOL: ...
231233
def isnumeric(self) -> _T_BOOL: ...
232234
def isdecimal(self) -> _T_BOOL: ...
233-
def fullmatch(
234-
self,
235-
pat: str | re.Pattern[str],
236-
case: bool = ...,
237-
flags: int = ...,
238-
na: Any = ...,
239-
) -> _T_BOOL: ...
240-
def removeprefix(self, prefix: str) -> _T_STR: ...
241-
def removesuffix(self, suffix: str) -> _T_STR: ...

tests/test_string_accessors.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,10 @@ def test_string_accessors_string_series():
145145
check(assert_type(s.str.cat(sep="X"), str), str)
146146
_check(assert_type(s.str.center(10), "pd.Series[str]"))
147147
_check(assert_type(s.str.get(2), "pd.Series[str]"))
148+
s_dict = pd.Series( # example from the doc of str.get
149+
[{"name": "Hello", "value": "World"}, {"name": "Goodbye", "value": "Planet"}]
150+
)
151+
_check(assert_type(s_dict.str.get("name"), "pd.Series[str]"))
148152
_check(assert_type(s.str.ljust(80), "pd.Series[str]"))
149153
_check(assert_type(s.str.lower(), "pd.Series[str]"))
150154
_check(assert_type(s.str.lstrip("a"), "pd.Series[str]"))
@@ -166,14 +170,30 @@ def test_string_accessors_string_series():
166170
_check(
167171
assert_type(s.str.translate({241: "n"}), "pd.Series[str]"),
168172
)
173+
_check(
174+
assert_type(s.str.translate({241: 240}), "pd.Series[str]"),
175+
)
176+
trans_table: dict[int, int] = {ord("a"): ord("b")}
177+
_check( # tests covariance of table values (table is read-only)
178+
assert_type(s.str.translate(trans_table), "pd.Series[str]"),
179+
)
169180
_check(assert_type(s.str.upper(), "pd.Series[str]"))
170181
_check(assert_type(s.str.wrap(80), "pd.Series[str]"))
171182
_check(assert_type(s.str.zfill(10), "pd.Series[str]"))
172183
s_bytes = pd.Series([b"a1", b"b2", b"c3"])
173184
_check(assert_type(s_bytes.str.decode("utf-8"), "pd.Series[str]"))
185+
_check(
186+
assert_type(
187+
s_bytes.str.decode("utf-8", dtype=pd.StringDtype()), "pd.Series[str]"
188+
)
189+
)
174190
s_list = pd.Series([["apple", "banana"], ["cherry", "date"], ["one", "eggplant"]])
175191
_check(assert_type(s_list.str.join("-"), "pd.Series[str]"))
176192

193+
# wrap doesn't accept positional arguments other than width
194+
with pytest.raises(TypeError):
195+
s.str.wrap(80, False) # type: ignore[misc] # pyright: ignore[reportCallIssue]
196+
177197

178198
def test_string_accessors_string_index():
179199
idx = pd.Index(DATA)
@@ -183,6 +203,10 @@ def test_string_accessors_string_index():
183203
check(assert_type(idx.str.cat(sep="X"), str), str)
184204
_check(assert_type(idx.str.center(10), "pd.Index[str]"))
185205
_check(assert_type(idx.str.get(2), "pd.Index[str]"))
206+
idx_dict = pd.Index(
207+
[{"name": "Hello", "value": "World"}, {"name": "Goodbye", "value": "Planet"}]
208+
)
209+
_check(assert_type(idx_dict.str.get("name"), "pd.Index[str]"))
186210
_check(assert_type(idx.str.ljust(80), "pd.Index[str]"))
187211
_check(assert_type(idx.str.lower(), "pd.Index[str]"))
188212
_check(assert_type(idx.str.lstrip("a"), "pd.Index[str]"))
@@ -204,14 +228,30 @@ def test_string_accessors_string_index():
204228
_check(
205229
assert_type(idx.str.translate({241: "n"}), "pd.Index[str]"),
206230
)
231+
_check(
232+
assert_type(idx.str.translate({241: 240}), "pd.Index[str]"),
233+
)
234+
trans_table: dict[int, int] = {ord("a"): ord("b")}
235+
_check( # tests covariance of table values (table is read-only)
236+
assert_type(idx.str.translate(trans_table), "pd.Index[str]"),
237+
)
207238
_check(assert_type(idx.str.upper(), "pd.Index[str]"))
208239
_check(assert_type(idx.str.wrap(80), "pd.Index[str]"))
209240
_check(assert_type(idx.str.zfill(10), "pd.Index[str]"))
210241
idx_bytes = pd.Index([b"a1", b"b2", b"c3"])
211242
_check(assert_type(idx_bytes.str.decode("utf-8"), "pd.Index[str]"))
243+
_check(
244+
assert_type(
245+
idx_bytes.str.decode("utf-8", dtype=pd.StringDtype()), "pd.Index[str]"
246+
)
247+
)
212248
idx_list = pd.Index([["apple", "banana"], ["cherry", "date"], ["one", "eggplant"]])
213249
_check(assert_type(idx_list.str.join("-"), "pd.Index[str]"))
214250

251+
# wrap doesn't accept positional arguments other than width
252+
with pytest.raises(TypeError):
253+
idx.str.wrap(80, False) # type: ignore[misc] # pyright: ignore[reportCallIssue]
254+
215255

216256
def test_string_accessors_bytes_series():
217257
s = pd.Series(["a1", "b2", "c3"])
@@ -320,6 +360,11 @@ def test_series_overloads_partition():
320360
pd.Series,
321361
object,
322362
)
363+
check(
364+
assert_type(s.str.partition(expand=False), "pd.Series[type[object]]"),
365+
pd.Series,
366+
object,
367+
)
323368

324369
check(assert_type(s.str.rpartition(sep=";"), pd.DataFrame), pd.DataFrame)
325370
check(
@@ -330,6 +375,11 @@ def test_series_overloads_partition():
330375
pd.Series,
331376
object,
332377
)
378+
check(
379+
assert_type(s.str.rpartition(expand=False), "pd.Series[type[object]]"),
380+
pd.Series,
381+
object,
382+
)
333383

334384

335385
def test_index_overloads_partition():

0 commit comments

Comments
 (0)