Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pandas-stubs/_typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,8 @@ S1 = TypeVar(
| Period
| Interval
| CategoricalDtype
| BaseOffset,
| BaseOffset
| list[str],
)

S2 = TypeVar(
Expand All @@ -566,7 +567,8 @@ S2 = TypeVar(
| Period
| Interval
| CategoricalDtype
| BaseOffset,
| BaseOffset
| list[str],
)

IndexingInt: TypeAlias = (
Expand Down
4 changes: 3 additions & 1 deletion pandas-stubs/core/indexes/base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,9 @@ class Index(IndexOpsMixin[S1]):
**kwargs,
) -> Self: ...
@property
def str(self) -> StringMethods[Self, MultiIndex, np_ndarray_bool]: ...
def str(
self,
) -> StringMethods[Self, MultiIndex, np_ndarray_bool, Index[list[str]]]: ...
def is_(self, other) -> bool: ...
def __len__(self) -> int: ...
def __array__(self, dtype=...) -> np.ndarray: ...
Expand Down
24 changes: 23 additions & 1 deletion pandas-stubs/core/series.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,26 @@ class Series(IndexOpsMixin[S1], NDFrame):
copy: bool = ...,
) -> Series[Any]: ...
# Overload for data given as a sequence of string lists: the constructed
# object is typed as Series[list[str]]. Every argument after `index` is
# keyword-only.
@overload
def __new__(
cls,
data: Sequence[list[str]],
index: Axes | None = ...,
*,
dtype: Dtype = ...,
name: Hashable = ...,
copy: bool = ...,
) -> Series[list[str]]: ...
# Overload for data given as a sequence of plain strings: the constructed
# object is typed as Series[str]. Every argument after `index` is
# keyword-only.
@overload
def __new__(
cls,
data: Sequence[str],
index: Axes | None = ...,
*,
dtype: Dtype = ...,
name: Hashable = ...,
copy: bool = ...,
) -> Series[str]: ...
@overload
def __new__(
cls,
data: (
Expand Down Expand Up @@ -1199,7 +1219,9 @@ class Series(IndexOpsMixin[S1], NDFrame):
) -> Series[S1]: ...
# pandas documents Series.to_period as returning a Series whose index has been
# converted to a PeriodIndex, not a DataFrame; the values keep their element
# type, hence Series[S1].
def to_period(self, freq: _str | None = ..., copy: _bool = ...) -> Series[S1]: ...
@property
def str(self) -> StringMethods[Series, DataFrame, Series[bool]]: ...
def str(
self,
) -> StringMethods[Series, DataFrame, Series[bool], Series[list[str]]]: ...
@property
def dt(self) -> CombinedDatetimelikeProperties: ...
@property
Expand Down
14 changes: 11 additions & 3 deletions pandas-stubs/core/strings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import numpy as np
import pandas as pd
from pandas import (
DataFrame,
Index,
MultiIndex,
Series,
)
Expand All @@ -28,10 +29,12 @@ from pandas._typing import (

# _TS: result type of str.split / str.rsplit with expand=True
# (DataFrame for Series.str, MultiIndex for Index.str).
_TS = TypeVar("_TS", DataFrame, MultiIndex)
# _TS2: result type of str.split with expand=False
# (Series[list[str]] for Series.str, Index[list[str]] for Index.str).
_TS2 = TypeVar("_TS2", Series[list[str]], Index[list[str]])
# _TM: result type of str.match
# (Series[bool] for Series.str, np_ndarray_bool for Index.str).
_TM = TypeVar("_TM", Series[bool], np_ndarray_bool)

class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM]):
class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]):
def __init__(self, data: T) -> None: ...
def __getitem__(self, key: slice | int) -> T: ...
def __iter__(self) -> T: ...
Expand Down Expand Up @@ -66,8 +69,13 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM]):
) -> _TS: ...
@overload
def split(
self, pat: str = ..., *, n: int = ..., expand: bool = ..., regex: bool = ...
) -> T: ...
self,
pat: str = ...,
*,
n: int = ...,
expand: Literal[False] = ...,
regex: bool = ...,
) -> _TS2: ...
@overload
def rsplit(self, pat: str = ..., *, n: int = ..., expand: Literal[True]) -> _TS: ...
@overload
Expand Down
7 changes: 6 additions & 1 deletion tests/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,13 @@ def test_difference_none() -> None:
def test_str_split() -> None:
# GH 194
ind = pd.Index(["a-b", "c-d"])
check(assert_type(ind.str.split("-"), "pd.Index[str]"), pd.Index)
check(assert_type(ind.str.split("-"), "pd.Index[list[str]]"), pd.Index)
check(assert_type(ind.str.split("-", expand=True), pd.MultiIndex), pd.MultiIndex)
check(
assert_type(ind.str.split("-", expand=False), "pd.Index[list[str]]"),
pd.Index,
list,
)


def test_str_match() -> None:
Expand Down
7 changes: 6 additions & 1 deletion tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1553,9 +1553,14 @@ def test_string_accessors():
check(assert_type(s.str.rstrip(), pd.Series), pd.Series)
check(assert_type(s.str.slice(0, 4, 2), pd.Series), pd.Series)
check(assert_type(s.str.slice_replace(0, 2, "XX"), pd.Series), pd.Series)
check(assert_type(s.str.split("a"), pd.Series), pd.Series)
check(assert_type(s.str.split("a"), "pd.Series[list[str]]"), pd.Series, list)
# GH 194
check(assert_type(s.str.split("a", expand=True), pd.DataFrame), pd.DataFrame)
check(
assert_type(s.str.split("a", expand=False), "pd.Series[list[str]]"),
pd.Series,
list,
)
check(assert_type(s.str.startswith("a"), "pd.Series[bool]"), pd.Series, np.bool_)
check(
assert_type(s.str.startswith(("a", "b")), "pd.Series[bool]"),
Expand Down
Loading