# Patch summary -- pandas-stubs type stubs (three stub files in this part):
#   * core/indexes/base.pyi and core/series.pyi: in the union shown in each hunk,
#     the last entry changes from Index[type[object]] / Series[type[object]] to the
#     unparameterized Index / Series.
#   * core/strings/accessor.pyi (class StringMethods):
#       - imports: adds Hashable, Mapping and DtypeObj; drops typing.Any.
#       - _T_OBJECT is re-bound to `Series | Index`.
#       - `= ...` placeholder defaults are replaced by explicit values
#         (e.g. n=-1, flags=0, fillchar=" ", errors="strict"); the `na`
#         parameters keep `...` as their default.
#       - cat: the two str-returning overloads are merged into a single
#         `others: None = None` overload; `sep` is no longer keyword-only and
#         accepts None in both overloads.
#       - split / rsplit: `pat` accepts None; split's `regex` accepts None.
#       - partition / rpartition: reduced to three overloads each
#         (expand=True, expand=False positional, expand=False keyword).
#       - get accepts `int | Hashable`; translate accepts a Mapping instead of
#         a dict; decode gains a `dtype` parameter.
#       - match / fullmatch / startswith / endswith: `na` is typed
#         `Scalar | NaTType | None` instead of Any.
#       - wrap: every option after `width` becomes keyword-only, typed `bool`
#         with default True.
#       - extract: in the expand=True overload `expand` is no longer keyword-only.
#       - fullmatch, removeprefix and removesuffix are moved up from the end of
#         the class.
# NOTE(review): the line breaks inside the hunks look flattened to spaces in this
# copy -- presumably they must be restored before the patch can be applied; confirm
# against the original patch.
diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index 921a8cbd0..ed8d79461 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -275,7 +275,7 @@ class Index(IndexOpsMixin[S1]): Index[int], Index[bytes], Index[_str], - Index[type[object]], + Index, ]: ... @final def is_(self, other) -> bool: ... diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index dcc6b89de..eedb84b02 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -1181,7 +1181,7 @@ class Series(IndexOpsMixin[S1], NDFrame): Series[int], Series[bytes], Series[_str], - Series[type[object]], + Series, ]: ... @property def dt(self) -> CombinedDatetimelikeProperties: ... diff --git a/pandas-stubs/core/strings/accessor.pyi b/pandas-stubs/core/strings/accessor.pyi index fafdb9b33..bedd8ac7b 100644 --- a/pandas-stubs/core/strings/accessor.pyi +++ b/pandas-stubs/core/strings/accessor.pyi @@ -2,11 +2,12 @@ from builtins import slice as _slice from collections.abc import ( Callable, + Hashable, + Mapping, Sequence, ) import re from typing import ( - Any, Generic, Literal, TypeVar, @@ -27,6 +28,7 @@ from pandas.core.base import NoNewAttributesMixin from pandas._libs.tslibs.nattype import NaTType from pandas._typing import ( AlignJoin, + DtypeObj, Scalar, T, np_ndarray_bool, @@ -45,7 +47,7 @@ _T_BYTES = TypeVar("_T_BYTES", bound=Series[bytes] | Index[bytes]) # Used for the result of str.decode _T_STR = TypeVar("_T_STR", bound=Series[str] | Index[str]) # Used for the result of str.partition -_T_OBJECT = TypeVar("_T_OBJECT", bound=Series[type[object]] | Index[type[object]]) +_T_OBJECT = TypeVar("_T_OBJECT", bound=Series | Index) class StringMethods( NoNewAttributesMixin, @@ -57,19 +59,10 @@ class StringMethods( @overload def cat( self, - *, - sep: str, - na_rep: str | None = ..., - join: AlignJoin = ..., - ) -> str: ... 
- @overload - def cat( - self, - others: Literal[None] = ..., - *, - sep: str, - na_rep: str | None = ..., - join: AlignJoin = ..., + others: None = None, + sep: str | None = None, + na_rep: str | None = None, + join: AlignJoin = "left", ) -> str: ... @overload def cat( @@ -77,143 +70,152 @@ class StringMethods( others: ( Series[str] | Index[str] | pd.DataFrame | npt.NDArray[np.str_] | list[str] ), - sep: str = ..., - na_rep: str | None = ..., - join: AlignJoin = ..., + sep: str | None = None, + na_rep: str | None = None, + join: AlignJoin = "left", ) -> _T_STR: ... @overload def split( self, - pat: str | re.Pattern[str] = ..., + pat: str | re.Pattern[str] | None = None, *, - n: int = ..., + n: int = -1, expand: Literal[True], - regex: bool = ..., + regex: bool | None = None, ) -> _T_EXPANDING: ... @overload def split( self, - pat: str | re.Pattern[str] = ..., + pat: str | re.Pattern[str] | None = None, *, - n: int = ..., - expand: Literal[False] = ..., - regex: bool = ..., + n: int = -1, + expand: Literal[False] = False, + regex: bool | None = None, ) -> _T_LIST_STR: ... @overload def rsplit( - self, pat: str = ..., *, n: int = ..., expand: Literal[True] + self, pat: str | None = None, *, n: int = -1, expand: Literal[True] ) -> _T_EXPANDING: ... @overload def rsplit( - self, pat: str = ..., *, n: int = ..., expand: Literal[False] = ... + self, pat: str | None = None, *, n: int = -1, expand: Literal[False] = False ) -> _T_LIST_STR: ... - @overload - def partition(self, sep: str = ...) -> _T_EXPANDING: ... - @overload - def partition(self, *, expand: Literal[True]) -> _T_EXPANDING: ... - @overload - def partition(self, sep: str, expand: Literal[True]) -> _T_EXPANDING: ... - @overload + @overload # expand=True + def partition( + self, sep: str = " ", expand: Literal[True] = True + ) -> _T_EXPANDING: ... + @overload # expand=False (positional argument) def partition(self, sep: str, expand: Literal[False]) -> _T_OBJECT: ... 
- @overload - def partition(self, *, expand: Literal[False]) -> _T_OBJECT: ... - @overload - def rpartition(self, sep: str = ...) -> _T_EXPANDING: ... - @overload - def rpartition(self, *, expand: Literal[True]) -> _T_EXPANDING: ... - @overload - def rpartition(self, sep: str, expand: Literal[True]) -> _T_EXPANDING: ... - @overload + @overload # expand=False (keyword argument) + def partition(self, sep: str = " ", *, expand: Literal[False]) -> _T_OBJECT: ... + @overload # expand=True + def rpartition( + self, sep: str = " ", expand: Literal[True] = True + ) -> _T_EXPANDING: ... + @overload # expand=False (positional argument) def rpartition(self, sep: str, expand: Literal[False]) -> _T_OBJECT: ... - @overload - def rpartition(self, *, expand: Literal[False]) -> _T_OBJECT: ... - def get(self, i: int) -> _T_STR: ... + @overload # expand=False (keyword argument) + def rpartition(self, sep: str = " ", *, expand: Literal[False]) -> _T_OBJECT: ... + def get(self, i: int | Hashable) -> _T_STR: ... def join(self, sep: str) -> _T_STR: ... def contains( self, pat: str | re.Pattern[str], - case: bool = ..., - flags: int = ..., + case: bool = True, + flags: int = 0, na: Scalar | NaTType | None = ..., - regex: bool = ..., + regex: bool = True, ) -> _T_BOOL: ... def match( self, pat: str | re.Pattern[str], - case: bool = ..., - flags: int = ..., - na: Any = ..., + case: bool = True, + flags: int = 0, + na: Scalar | NaTType | None = ..., + ) -> _T_BOOL: ... + def fullmatch( + self, + pat: str | re.Pattern[str], + case: bool = True, + flags: int = 0, + na: Scalar | NaTType | None = ..., ) -> _T_BOOL: ... def replace( self, pat: str | re.Pattern[str], repl: str | Callable[[re.Match[str]], str], - n: int = ..., - case: bool | None = ..., - flags: int = ..., - regex: bool = ..., + n: int = -1, + case: bool | None = None, + flags: int = 0, + regex: bool = False, ) -> _T_STR: ... def repeat(self, repeats: int | Sequence[int]) -> _T_STR: ... 
def pad( self, width: int, - side: Literal["left", "right", "both"] = ..., - fillchar: str = ..., + side: Literal["left", "right", "both"] = "left", + fillchar: str = " ", ) -> _T_STR: ... - def center(self, width: int, fillchar: str = ...) -> _T_STR: ... - def ljust(self, width: int, fillchar: str = ...) -> _T_STR: ... - def rjust(self, width: int, fillchar: str = ...) -> _T_STR: ... + def center(self, width: int, fillchar: str = " ") -> _T_STR: ... + def ljust(self, width: int, fillchar: str = " ") -> _T_STR: ... + def rjust(self, width: int, fillchar: str = " ") -> _T_STR: ... def zfill(self, width: int) -> _T_STR: ... def slice( - self, start: int | None = ..., stop: int | None = ..., step: int | None = ... + self, start: int | None = None, stop: int | None = None, step: int | None = None ) -> T: ... def slice_replace( - self, start: int | None = ..., stop: int | None = ..., repl: str | None = ... + self, start: int | None = None, stop: int | None = None, repl: str | None = None ) -> _T_STR: ... - def decode(self, encoding: str, errors: str = ...) -> _T_STR: ... - def encode(self, encoding: str, errors: str = ...) -> _T_BYTES: ... - def strip(self, to_strip: str | None = ...) -> _T_STR: ... - def lstrip(self, to_strip: str | None = ...) -> _T_STR: ... - def rstrip(self, to_strip: str | None = ...) -> _T_STR: ... + def decode( + self, encoding: str, errors: str = "strict", dtype: str | DtypeObj | None = None + ) -> _T_STR: ... + def encode(self, encoding: str, errors: str = "strict") -> _T_BYTES: ... + def strip(self, to_strip: str | None = None) -> _T_STR: ... + def lstrip(self, to_strip: str | None = None) -> _T_STR: ... + def rstrip(self, to_strip: str | None = None) -> _T_STR: ... + def removeprefix(self, prefix: str) -> _T_STR: ... + def removesuffix(self, suffix: str) -> _T_STR: ... 
def wrap( self, width: int, - expand_tabs: bool | None = ..., - replace_whitespace: bool | None = ..., - drop_whitespace: bool | None = ..., - break_long_words: bool | None = ..., - break_on_hyphens: bool | None = ..., + *, + # kwargs passed to textwrap.TextWrapper + expand_tabs: bool = True, + replace_whitespace: bool = True, + drop_whitespace: bool = True, + break_long_words: bool = True, + break_on_hyphens: bool = True, ) -> _T_STR: ... - def get_dummies(self, sep: str = ...) -> _T_EXPANDING: ... - def translate(self, table: dict[int, int | str | None] | None) -> _T_STR: ... - def count(self, pat: str, flags: int = ...) -> _T_INT: ... - def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _T_BOOL: ... - def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _T_BOOL: ... - def findall(self, pat: str | re.Pattern[str], flags: int = ...) -> _T_LIST_STR: ... - @overload + def get_dummies(self, sep: str = "|") -> _T_EXPANDING: ... + def translate(self, table: Mapping[int, int | str | None] | None) -> _T_STR: ... + def count(self, pat: str, flags: int = 0) -> _T_INT: ... + def startswith( + self, pat: str | tuple[str, ...], na: Scalar | NaTType | None = ... + ) -> _T_BOOL: ... + def endswith( + self, pat: str | tuple[str, ...], na: Scalar | NaTType | None = ... + ) -> _T_BOOL: ... + def findall(self, pat: str | re.Pattern[str], flags: int = 0) -> _T_LIST_STR: ... + @overload # expand=True def extract( - self, - pat: str | re.Pattern[str], - flags: int = ..., - *, - expand: Literal[True] = ..., + self, pat: str | re.Pattern[str], flags: int = 0, expand: Literal[True] = True ) -> pd.DataFrame: ... - @overload + @overload # expand=False (positional argument) def extract( self, pat: str | re.Pattern[str], flags: int, expand: Literal[False] ) -> _T_OBJECT: ... 
- @overload + @overload # expand=False (keyword argument) def extract( - self, pat: str | re.Pattern[str], flags: int = ..., *, expand: Literal[False] + self, pat: str | re.Pattern[str], flags: int = 0, *, expand: Literal[False] ) -> _T_OBJECT: ... def extractall( - self, pat: str | re.Pattern[str], flags: int = ... + self, pat: str | re.Pattern[str], flags: int = 0 ) -> pd.DataFrame: ... - def find(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ... - def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ... + def find(self, sub: str, start: int = 0, end: int | None = None) -> _T_INT: ... + def rfind(self, sub: str, start: int = 0, end: int | None = None) -> _T_INT: ... def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> _T_STR: ... - def index(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ... - def rindex(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ... + def index(self, sub: str, start: int = 0, end: int | None = None) -> _T_INT: ... + def rindex(self, sub: str, start: int = 0, end: int | None = None) -> _T_INT: ... def len(self) -> _T_INT: ... def lower(self) -> _T_STR: ... def upper(self) -> _T_STR: ... @@ -230,12 +232,3 @@ class StringMethods( def istitle(self) -> _T_BOOL: ... def isnumeric(self) -> _T_BOOL: ... def isdecimal(self) -> _T_BOOL: ... - def fullmatch( - self, - pat: str | re.Pattern[str], - case: bool = ..., - flags: int = ..., - na: Any = ..., - ) -> _T_BOOL: ... - def removeprefix(self, prefix: str) -> _T_STR: ... - def removesuffix(self, suffix: str) -> _T_STR: ... 
# Test changes in tests/test_string_accessors.py (hunks follow):
#   * str.get is called with a string key on a Series / Index built from dicts.
#   * str.translate is called with int-valued tables, including a variable
#     annotated as dict[int, int].
#   * str.decode is called with dtype=pd.StringDtype().
#   * str.wrap(80, False) is added under TYPE_CHECKING_INVALID_USAGE and carries
#     mypy and pyright ignore comments.
#   * partition / rpartition / extract with expand=False now assert the bare
#     pd.Series / pd.Index instead of the "...[type[object]]" strings; Series
#     cases calling partition / rpartition with only expand=False are added.
diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py index ac3e58282..5796ee40d 100644 --- a/tests/test_string_accessors.py +++ b/tests/test_string_accessors.py @@ -145,6 +145,10 @@ def test_string_accessors_string_series(): check(assert_type(s.str.cat(sep="X"), str), str) _check(assert_type(s.str.center(10), "pd.Series[str]")) _check(assert_type(s.str.get(2), "pd.Series[str]")) + s_dict = pd.Series( # example from the doc of str.get + [{"name": "Hello", "value": "World"}, {"name": "Goodbye", "value": "Planet"}] + ) + _check(assert_type(s_dict.str.get("name"), "pd.Series[str]")) _check(assert_type(s.str.ljust(80), "pd.Series[str]")) _check(assert_type(s.str.lower(), "pd.Series[str]")) _check(assert_type(s.str.lstrip("a"), "pd.Series[str]")) @@ -166,14 +170,30 @@ def test_string_accessors_string_series(): _check( assert_type(s.str.translate({241: "n"}), "pd.Series[str]"), ) + _check( + assert_type(s.str.translate({241: 240}), "pd.Series[str]"), + ) + trans_table: dict[int, int] = {ord("a"): ord("b")} + _check( # tests covariance of table values (table is read-only) + assert_type(s.str.translate(trans_table), "pd.Series[str]"), + ) _check(assert_type(s.str.upper(), "pd.Series[str]")) _check(assert_type(s.str.wrap(80), "pd.Series[str]")) _check(assert_type(s.str.zfill(10), "pd.Series[str]")) s_bytes = pd.Series([b"a1", b"b2", b"c3"]) _check(assert_type(s_bytes.str.decode("utf-8"), "pd.Series[str]")) + _check( + assert_type( + s_bytes.str.decode("utf-8", dtype=pd.StringDtype()), "pd.Series[str]" + ) + ) s_list = pd.Series([["apple", "banana"], ["cherry", "date"], ["one", "eggplant"]]) _check(assert_type(s_list.str.join("-"), "pd.Series[str]")) + # wrap doesn't accept positional arguments other than width + if TYPE_CHECKING_INVALID_USAGE: + s.str.wrap(80, False) # type: ignore[misc] # pyright: ignore[reportCallIssue] + def test_string_accessors_string_index(): idx = pd.Index(DATA) @@ -183,6 +203,10 @@ def test_string_accessors_string_index(): 
check(assert_type(idx.str.cat(sep="X"), str), str) _check(assert_type(idx.str.center(10), "pd.Index[str]")) _check(assert_type(idx.str.get(2), "pd.Index[str]")) + idx_dict = pd.Index( + [{"name": "Hello", "value": "World"}, {"name": "Goodbye", "value": "Planet"}] + ) + _check(assert_type(idx_dict.str.get("name"), "pd.Index[str]")) _check(assert_type(idx.str.ljust(80), "pd.Index[str]")) _check(assert_type(idx.str.lower(), "pd.Index[str]")) _check(assert_type(idx.str.lstrip("a"), "pd.Index[str]")) @@ -204,14 +228,30 @@ def test_string_accessors_string_index(): _check( assert_type(idx.str.translate({241: "n"}), "pd.Index[str]"), ) + _check( + assert_type(idx.str.translate({241: 240}), "pd.Index[str]"), + ) + trans_table: dict[int, int] = {ord("a"): ord("b")} + _check( # tests covariance of table values (table is read-only) + assert_type(idx.str.translate(trans_table), "pd.Index[str]"), + ) _check(assert_type(idx.str.upper(), "pd.Index[str]")) _check(assert_type(idx.str.wrap(80), "pd.Index[str]")) _check(assert_type(idx.str.zfill(10), "pd.Index[str]")) idx_bytes = pd.Index([b"a1", b"b2", b"c3"]) _check(assert_type(idx_bytes.str.decode("utf-8"), "pd.Index[str]")) + _check( + assert_type( + idx_bytes.str.decode("utf-8", dtype=pd.StringDtype()), "pd.Index[str]" + ) + ) idx_list = pd.Index([["apple", "banana"], ["cherry", "date"], ["one", "eggplant"]]) _check(assert_type(idx_list.str.join("-"), "pd.Index[str]")) + # wrap doesn't accept positional arguments other than width + if TYPE_CHECKING_INVALID_USAGE: + idx.str.wrap(80, False) # type: ignore[misc] # pyright: ignore[reportCallIssue] + def test_string_accessors_bytes_series(): s = pd.Series(["a1", "b2", "c3"]) @@ -316,7 +356,12 @@ def test_series_overloads_partition(): assert_type(s.str.partition(sep=";", expand=True), pd.DataFrame), pd.DataFrame ) check( - assert_type(s.str.partition(sep=";", expand=False), "pd.Series[type[object]]"), + assert_type(s.str.partition(sep=";", expand=False), pd.Series), + pd.Series, + 
object, + ) + check( + assert_type(s.str.partition(expand=False), pd.Series), pd.Series, object, ) @@ -326,10 +371,11 @@ def test_series_overloads_partition(): assert_type(s.str.rpartition(sep=";", expand=True), pd.DataFrame), pd.DataFrame ) check( - assert_type(s.str.rpartition(sep=";", expand=False), "pd.Series[type[object]]"), + assert_type(s.str.rpartition(sep=";", expand=False), pd.Series), pd.Series, object, ) + check(assert_type(s.str.rpartition(expand=False), pd.Series), pd.Series, object) def test_index_overloads_partition(): @@ -350,7 +396,7 @@ def test_index_overloads_partition(): pd.MultiIndex, ) check( - assert_type(idx.str.partition(sep=";", expand=False), "pd.Index[type[object]]"), + assert_type(idx.str.partition(sep=";", expand=False), pd.Index), pd.Index, object, ) @@ -361,9 +407,7 @@ def test_index_overloads_partition(): pd.MultiIndex, ) check( - assert_type( - idx.str.rpartition(sep=";", expand=False), "pd.Index[type[object]]" - ), + assert_type(idx.str.rpartition(sep=";", expand=False), pd.Index), pd.Index, object, ) @@ -440,16 +484,12 @@ def test_series_overloads_extract(): assert_type(s.str.extract(r"[ab](\d)", expand=True), pd.DataFrame), pd.DataFrame ) check( - assert_type( - s.str.extract(r"[ab](\d)", expand=False), "pd.Series[type[object]]" - ), + assert_type(s.str.extract(r"[ab](\d)", expand=False), pd.Series), pd.Series, object, ) check( - assert_type( - s.str.extract(r"[ab](\d)", re.IGNORECASE, False), "pd.Series[type[object]]" - ), + assert_type(s.str.extract(r"[ab](\d)", re.IGNORECASE, False), pd.Series), pd.Series, object, ) @@ -463,16 +503,12 @@ def test_index_overloads_extract(): pd.DataFrame, ) check( - assert_type( - idx.str.extract(r"[ab](\d)", expand=False), "pd.Index[type[object]]" - ), + assert_type(idx.str.extract(r"[ab](\d)", expand=False), pd.Index), pd.Index, object, ) check( - assert_type( - idx.str.extract(r"[ab](\d)", re.IGNORECASE, False), "pd.Index[type[object]]" - ), + assert_type(idx.str.extract(r"[ab](\d)", 
re.IGNORECASE, False), pd.Index), pd.Index, object, )