Skip to content

Commit 2bf306d

Browse files
GH1159 Disallow StringMethods for Series and Index when subtype is not str
1 parent b5a9735 commit 2bf306d

File tree

5 files changed

+73
-29
lines changed

5 files changed

+73
-29
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -256,9 +256,9 @@ class _LocIndexerFrame(_LocIndexer, Generic[_T]):
256256
]
257257
| None
258258
),
259-
) -> Series: ...
259+
) -> UnknownSeries: ...
260260
@overload
261-
def __getitem__(self, idx: tuple[Scalar, slice]) -> Series | _T: ...
261+
def __getitem__(self, idx: tuple[Scalar, slice]) -> UnknownSeries | _T: ...
262262
@overload
263263
def __setitem__(
264264
self,
@@ -288,13 +288,13 @@ class _LocIndexerFrame(_LocIndexer, Generic[_T]):
288288
if sys.version_info >= (3, 12):
289289
class _GetItemHack:
290290
@overload
291-
def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> Series: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
291+
def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> UnknownSeries: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
292292
@overload
293293
def __getitem__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
294294
self, key: Iterable[Hashable] | slice
295295
) -> Self: ...
296296
@overload
297-
def __getitem__(self, key: Hashable) -> Series: ...
297+
def __getitem__(self, key: Hashable) -> UnknownSeries: ...
298298

299299
else:
300300
class _GetItemHack:

pandas-stubs/core/series.pyi

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,54 @@ _ListLike: TypeAlias = (
245245
ArrayLike | dict[_str, np.ndarray] | Sequence[S1] | IndexOpsMixin[S1]
246246
)
247247

248+
class _StrMethods:
249+
@overload
250+
def __get__(self, instance: Series[str], owner: Any) -> StringMethods[
251+
Series[str],
252+
DataFrame,
253+
Series[bool],
254+
Series[list[str]],
255+
Series[int],
256+
Series[bytes],
257+
Series[str],
258+
Series[type[object]],
259+
]: ...
260+
@overload
261+
def __get__(self, instance: Series[bytes], owner: Any) -> StringMethods[
262+
Series[bytes],
263+
DataFrame,
264+
Series[bool],
265+
Series[list[str]],
266+
Series[int],
267+
Series[bytes],
268+
Series[str],
269+
Series[type[object]],
270+
]: ...
271+
@overload
272+
def __get__(self, instance: Series[list[str]], owner: Any) -> StringMethods[
273+
Series[list[str]],
274+
DataFrame,
275+
Series[bool],
276+
Series[list[str]],
277+
Series[int],
278+
Series[bytes],
279+
Series[str],
280+
Series[type[object]],
281+
]: ...
282+
@overload
283+
def __get__(self, instance: Series[S1], owner: Any) -> NoReturn: ...
284+
@overload
285+
def __get__(self, instance: UnknownSeries, owner: Any) -> StringMethods[
286+
Series,
287+
DataFrame,
288+
Series[bool],
289+
Series[list[str]],
290+
Series[int],
291+
Series[bytes],
292+
Series[str],
293+
Series[type[object]],
294+
]: ...
295+
248296
class Series(IndexOpsMixin[S1], NDFrame):
249297
__hash__: ClassVar[None]
250298

@@ -1170,19 +1218,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
11701218
copy: _bool = ...,
11711219
) -> Series[S1]: ...
11721220
def to_period(self, freq: _str | None = ..., copy: _bool = ...) -> DataFrame: ...
1173-
@property
1174-
def str(
1175-
self,
1176-
) -> StringMethods[
1177-
Self,
1178-
DataFrame,
1179-
Series[bool],
1180-
Series[list[str]],
1181-
Series[int],
1182-
Series[bytes],
1183-
Series[str],
1184-
Series[type[object]],
1185-
]: ...
1221+
str: _StrMethods
11861222
@property
11871223
def dt(self) -> CombinedDatetimelikeProperties: ...
11881224
@property

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ types-pytz = ">= 2022.1.1"
3535
numpy = ">= 1.23.5"
3636

3737
[tool.poetry.group.dev.dependencies]
38-
mypy = "1.15.0"
38+
mypy = "1.16.0"
3939
pandas = "2.2.3"
4040
pyarrow = ">=10.0.1"
4141
pytest = ">=7.1.2"

tests/test_frame.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3073,15 +3073,15 @@ def test_set_columns() -> None:
30733073
# https://github.com/python/mypy/issues/3004
30743074
# pyright accepts this, so we only type check for pyright,
30753075
# and also test the code with pytest
3076-
df.columns = ["c", "d"] # type: ignore[assignment]
3077-
df.columns = [1, 2] # type: ignore[assignment]
3078-
df.columns = [1, "a"] # type: ignore[assignment]
3079-
df.columns = np.array([1, 2]) # type: ignore[assignment]
3080-
df.columns = pd.Series([1, 2]) # type: ignore[assignment]
3081-
df.columns = np.array([1, "a"]) # type: ignore[assignment]
3082-
df.columns = pd.Series([1, "a"]) # type: ignore[assignment]
3083-
df.columns = (1, 2) # type: ignore[assignment]
3084-
df.columns = (1, "a") # type: ignore[assignment]
3076+
df.columns = ["c", "d"]
3077+
df.columns = [1, 2]
3078+
df.columns = [1, "a"]
3079+
df.columns = np.array([1, 2])
3080+
df.columns = pd.Series([1, 2])
3081+
df.columns = np.array([1, "a"])
3082+
df.columns = pd.Series([1, "a"])
3083+
df.columns = (1, 2)
3084+
df.columns = (1, "a")
30853085
if TYPE_CHECKING_INVALID_USAGE:
30863086
df.columns = "abc" # type: ignore[assignment] # pyright: ignore[reportAttributeAccessIssue]
30873087

@@ -4373,8 +4373,8 @@ def test_hashable_args() -> None:
43734373
# https://github.com/python/mypy/issues/3004
43744374
# pyright accepts this, so we only type check for pyright,
43754375
# and also test the code with pytest
4376-
df.columns = test # type: ignore[assignment]
4377-
df.columns = ["test"] # type: ignore[assignment]
4376+
df.columns = test
4377+
df.columns = ["test"]
43784378

43794379
testDict = {"test": 1}
43804380
with ensure_clean() as path:

tests/test_string_accessors.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing_extensions import assert_type
77

88
from tests import (
9+
TYPE_CHECKING_INVALID_USAGE,
910
check,
1011
np_ndarray_bool,
1112
)
@@ -411,3 +412,10 @@ def test_index_overloads_extract():
411412
pd.Index,
412413
object,
413414
)
415+
416+
417+
def test_series_unknown():
418+
if TYPE_CHECKING_INVALID_USAGE:
419+
s = pd.Series([1, 2, 3])
420+
s.str.startswith("a") # type:ignore[attr-defined]
421+
s.str.slice(2, 4) # type:ignore[attr-defined]

0 commit comments

Comments
 (0)