Skip to content

Commit 01130e1

Browse files
GH1089 Partial typehinting
1 parent 65aa064 commit 01130e1

File tree

2 files changed

+65
-6
lines changed

2 files changed

+65
-6
lines changed

pandas-stubs/core/series.pyi

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1609,6 +1609,11 @@ class Series(IndexOpsMixin[S1], NDFrame):
16091609
@overload
16101610
def __add__(self, other: S1 | Self) -> Self: ...
16111611
@overload
1612+
def __add__(
1613+
self,
1614+
other: complex,
1615+
) -> Series[complex]: ...
1616+
@overload
16121617
def __add__(
16131618
self,
16141619
other: num | _str | timedelta | Timedelta | _ListLike | Series | np.timedelta64,
@@ -1716,6 +1721,16 @@ class Series(IndexOpsMixin[S1], NDFrame):
17161721
self, other: Timestamp | datetime | TimestampSeries
17171722
) -> TimedeltaSeries: ...
17181723
@overload
1724+
def __sub__(
1725+
self: Series[int],
1726+
other: int,
1727+
) -> Series[int]: ...
1728+
@overload
1729+
def __sub__(
1730+
self,
1731+
other: complex,
1732+
) -> Series[complex]: ...
1733+
@overload
17191734
def __sub__(self, other: num | _ListLike | Series) -> Series: ...
17201735
def __truediv__(self, other: num | _ListLike | Series[S1] | Path) -> Series: ...
17211736
# ignore needed for mypy as we want different results based on the arguments
@@ -1742,6 +1757,23 @@ class Series(IndexOpsMixin[S1], NDFrame):
17421757
@property
17431758
def loc(self) -> _LocIndexerSeries[S1]: ...
17441759
# Methods
1760+
@overload
1761+
def add(
1762+
self: Series[int],
1763+
other: int,
1764+
level: Level | None = ...,
1765+
fill_value: float | None = ...,
1766+
axis: int = ...,
1767+
) -> Series[int]: ...
1768+
@overload
1769+
def add(
1770+
self,
1771+
other: complex,
1772+
level: Level | None = ...,
1773+
fill_value: float | None = ...,
1774+
axis: int = ...,
1775+
) -> Series[complex]: ...
1776+
@overload
17451777
def add(
17461778
self,
17471779
other: Series[S1] | Scalar,
@@ -2085,6 +2117,31 @@ class Series(IndexOpsMixin[S1], NDFrame):
20852117
numeric_only: _bool = ...,
20862118
**kwargs,
20872119
) -> float: ...
2120+
@overload
2121+
def sub(
2122+
self: Series[int],
2123+
other: int,
2124+
level: Level | None = ...,
2125+
fill_value: float | None = ...,
2126+
axis: AxisIndex | None = ...,
2127+
) -> Series[int]: ...
2128+
@overload
2129+
def sub(
2130+
self: Series[int],
2131+
other: float,
2132+
level: Level | None = ...,
2133+
fill_value: float | None = ...,
2134+
axis: AxisIndex | None = ...,
2135+
) -> Series[int]: ...
2136+
@overload
2137+
def sub(
2138+
self,
2139+
other: complex,
2140+
level: Level | None = ...,
2141+
fill_value: float | None = ...,
2142+
axis: AxisIndex | None = ...,
2143+
) -> Series[complex]: ...
2144+
@overload
20882145
def sub(
20892146
self,
20902147
other: num | _ListLike | Series[S1],

tests/test_series.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ def test_types_mean() -> None:
450450
check(assert_type(s.mean(), float), float)
451451
check(
452452
assert_type(
453-
s.groupby(level=0).mean(), # pyright: ignore[reportAssertTypeFailure]
453+
s.groupby(level=0).mean(),
454454
"pd.Series[float]",
455455
),
456456
pd.Series,
@@ -465,7 +465,7 @@ def test_types_median() -> None:
465465
check(assert_type(s.median(), float), float)
466466
check(
467467
assert_type(
468-
s.groupby(level=0).median(), # pyright: ignore[reportAssertTypeFailure]
468+
s.groupby(level=0).median(),
469469
"pd.Series[float]",
470470
),
471471
pd.Series,
@@ -690,7 +690,7 @@ def test_types_scalar_arithmetic() -> None:
690690
check(assert_type(s + 1, "pd.Series[int]"), pd.Series, np.integer)
691691
check(assert_type(s.add(1, fill_value=0), "pd.Series[int]"), pd.Series, np.integer)
692692

693-
check(assert_type(s - 1, pd.Series), pd.Series, np.integer)
693+
check(assert_type(s - 1, "pd.Series[int]"), pd.Series, np.integer)
694694
check(assert_type(s.sub(1, fill_value=0), "pd.Series[int]"), pd.Series, np.integer)
695695

696696
check(assert_type(s * 2, pd.Series), pd.Series, np.integer)
@@ -721,8 +721,10 @@ def test_types_scalar_arithmetic() -> None:
721721
def test_types_complex_arithmetic() -> None:
722722
c = 1 + 1j
723723
s = pd.Series([1.0, 2.0, 3.0])
724-
check(assert_type(s + c, pd.Series), pd.Series)
725-
check(assert_type(s - c, pd.Series), pd.Series)
724+
check(assert_type(s + c, "pd.Series[complex]"), pd.Series, complex)
725+
check(assert_type(s.add(c), "pd.Series[complex]"), pd.Series, complex)
726+
check(assert_type(s - c, "pd.Series[complex]"), pd.Series, complex)
727+
check(assert_type(s.sub(c), "pd.Series[complex]"), pd.Series, complex)
726728

727729

728730
def test_types_groupby() -> None:
@@ -1368,7 +1370,7 @@ def test_series_mul() -> None:
13681370
sm = s * 4
13691371
check(assert_type(sm, pd.Series), pd.Series)
13701372
ss = s - 4
1371-
check(assert_type(ss, pd.Series), pd.Series)
1373+
check(assert_type(ss, "pd.Series[int]"), pd.Series, np.integer)
13721374
sm2 = s * s
13731375
check(assert_type(sm2, pd.Series), pd.Series)
13741376
sp = s + 4

0 commit comments

Comments
 (0)