Skip to content

Commit 3e56625

Browse files
GH1089 Partial typehinting
1 parent ec20f77 commit 3e56625

File tree

2 files changed

+36
-33
lines changed

2 files changed

+36
-33
lines changed

pandas-stubs/core/series.pyi

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1628,12 +1628,6 @@ class Series(IndexOpsMixin[S1], NDFrame):
16281628
self, other: int | np_ndarray_anyint | Series[int]
16291629
) -> Series[int]: ...
16301630
# def __array__(self, dtype: Optional[_bool] = ...) -> _np_ndarray
1631-
@overload
1632-
def __div__(self: Series[int], other: Series[int]) -> Series[float]: ...
1633-
@overload
1634-
def __div__(self: Series[int], other: int) -> Series[float]: ...
1635-
@overload
1636-
def __div__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
16371631
def __eq__(self, other: object) -> Series[_bool]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
16381632
def __floordiv__(self, other: num | _ListLike | Series[S1]) -> Series[int]: ...
16391633
def __ge__( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
@@ -1649,15 +1643,19 @@ class Series(IndexOpsMixin[S1], NDFrame):
16491643
self, other: S1 | _ListLike | Series[S1] | datetime | timedelta | date
16501644
) -> Series[_bool]: ...
16511645
@overload
1646+
def __mul__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
1647+
self, other: S1 | Self
1648+
) -> Self: ...
1649+
@overload
16521650
def __mul__(
16531651
self, other: timedelta | Timedelta | TimedeltaSeries | np.timedelta64
16541652
) -> TimedeltaSeries: ...
16551653
@overload
1656-
def __mul__(self: Series[int], other: int) -> Series[int]: ...
1657-
@overload
1658-
def __mul__(self: Series[int], other: Series[int]) -> Series[int]: ...
1654+
def __mul__( # pyright: ignore[reportOverlappingOverload]
1655+
self: Series[int], other: Series[int] | int
1656+
) -> Series[int]: ...
16591657
@overload
1660-
def __mul__(self: Series[int], other: Series[float]) -> Series[float]: ...
1658+
def __mul__(self: Series[int], other: Series[float] | float) -> Series[float]: ...
16611659
@overload
16621660
def __mul__(self: Series[Any], other: Series[Any]) -> Series: ...
16631661
@overload
@@ -1687,12 +1685,6 @@ class Series(IndexOpsMixin[S1], NDFrame):
16871685
def __rand__( # pyright: ignore[reportIncompatibleMethodOverride]
16881686
self, other: int | np_ndarray_anyint | Series[int]
16891687
) -> Series[int]: ...
1690-
@overload
1691-
def __rdiv__(self: Series[int], other: int) -> Series[float]: ...
1692-
@overload
1693-
def __rdiv__(self: Series[int], other: Series[int]) -> Series[float]: ...
1694-
@overload
1695-
def __rdiv__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
16961688
def __rdivmod__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
16971689
def __rfloordiv__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
16981690
def __rmod__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
@@ -1741,16 +1733,23 @@ class Series(IndexOpsMixin[S1], NDFrame):
17411733
@overload
17421734
def __sub__(
17431735
self: Series[int],
1744-
other: int,
1736+
other: int | Series[int],
17451737
) -> Series[int]: ...
17461738
@overload
17471739
def __sub__(
17481740
self,
17491741
other: complex,
17501742
) -> Series[complex]: ...
17511743
@overload
1744+
def __sub__(self, other: S1 | Self) -> Self: ...
1745+
@overload
17521746
def __sub__(self, other: num | _ListLike | Series) -> Series: ...
1753-
def __truediv__(self, other: num | _ListLike | Series[S1] | Path) -> Series: ...
1747+
@overload
1748+
def __truediv__(self: Series[int], other: Series[int] | int) -> Series[float]: ...
1749+
@overload
1750+
def __truediv__(
1751+
self, other: num | _ListLike | Series[S1] | Path
1752+
) -> Series | Self: ...
17541753
# ignore needed for mypy as we want different results based on the arguments
17551754
@overload # type: ignore[override]
17561755
def __xor__( # pyright: ignore[reportOverlappingOverload]
@@ -1956,9 +1955,9 @@ class Series(IndexOpsMixin[S1], NDFrame):
19561955
@overload
19571956
def mul(
19581957
self: Series[int],
1959-
other: Series[int],
1958+
other: Series[int] | int,
19601959
level: Level | None = ...,
1961-
fill_value: float | None = ...,
1960+
fill_value: int | None = ...,
19621961
axis: AxisIndex | None = ...,
19631962
) -> Series[int]: ...
19641963
@overload
@@ -2152,7 +2151,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
21522151
**kwargs,
21532152
) -> float: ...
21542153
@overload
2155-
def sub(
2154+
def sub( # pyright: ignore[reportOverlappingOverload]
21562155
self: Series[int],
21572156
other: int,
21582157
level: Level | None = ...,
@@ -2166,7 +2165,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
21662165
level: Level | None = ...,
21672166
fill_value: float | None = ...,
21682167
axis: AxisIndex | None = ...,
2169-
) -> Series[int]: ...
2168+
) -> Series[float]: ...
21702169
@overload
21712170
def sub(
21722171
self,

tests/test_series.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -653,15 +653,14 @@ def test_types_element_wise_arithmetic() -> None:
653653
check(assert_type(s + s2, "pd.Series[int]"), pd.Series, np.integer)
654654
check(assert_type(s.add(s2, fill_value=0), "pd.Series[int]"), pd.Series, np.integer)
655655

656-
check(assert_type(s - s2, pd.Series), pd.Series, np.integer)
656+
check(assert_type(s - s2, "pd.Series[int]"), pd.Series, np.integer)
657657
check(assert_type(s.sub(s2, fill_value=0), "pd.Series[int]"), pd.Series, np.integer)
658658

659659
check(assert_type(s * s2, "pd.Series[int]"), pd.Series, np.integer)
660660
check(assert_type(s.mul(s2, fill_value=0), "pd.Series[int]"), pd.Series, np.integer)
661661

662662
# GH1089 should be the following
663-
# check(assert_type(s / s2, "pd.Series[float]"), pd.Series, np.float64)
664-
check(assert_type(s / s2, "pd.Series"), pd.Series, np.float64)
663+
check(assert_type(s / s2, "pd.Series[float]"), pd.Series, np.float64)
665664
check(
666665
assert_type(s.div(s2, fill_value=0), "pd.Series[float]"), pd.Series, np.float64
667666
)
@@ -696,11 +695,9 @@ def test_types_scalar_arithmetic() -> None:
696695
check(assert_type(s.sub(1, fill_value=0), "pd.Series[int]"), pd.Series, np.integer)
697696

698697
check(assert_type(s * 2, "pd.Series[int]"), pd.Series, np.integer)
699-
check(assert_type(s.mul(2, fill_value=0), pd.Series), pd.Series, np.integer)
698+
check(assert_type(s.mul(2, fill_value=0), "pd.Series[int]"), pd.Series, np.integer)
700699

701-
# GH1089 should be
702-
# check(assert_type(s / 2, "pd.Series[float]"), pd.Series, np.float64)
703-
check(assert_type(s / 2, pd.Series), pd.Series, np.float64)
700+
check(assert_type(s / 2, "pd.Series[float]"), pd.Series, np.float64)
704701
check(
705702
assert_type(s.div(2, fill_value=0), "pd.Series[float]"), pd.Series, np.float64
706703
)
@@ -1312,10 +1309,12 @@ def test_types_dot() -> None:
13121309
s1 = pd.Series([0, 1, 2, 3])
13131310
s2 = pd.Series([-1, 2, -3, 4])
13141311
df1 = pd.DataFrame([[0, 1], [-2, 3], [4, -5], [6, 7]])
1312+
df2 = pd.DataFrame([[0.0, 1.0], [-2.0, 3.0], [4.0, -5.0], [6.0, 7.0]])
13151313
n1 = np.array([[0, 1], [1, 2], [-1, -1], [2, 0]])
13161314
check(assert_type(s1.dot(s2), Scalar), np.integer)
13171315
check(assert_type(s1 @ s2, Scalar), np.integer)
13181316
check(assert_type(s1.dot(df1), pd.Series), pd.Series, np.integer)
1317+
check(assert_type(s1.dot(df2), pd.Series), pd.Series, np.float64)
13191318
check(assert_type(s1 @ df1, pd.Series), pd.Series)
13201319
check(assert_type(s1.dot(n1), np.ndarray), np.ndarray)
13211320
check(assert_type(s1 @ n1, np.ndarray), np.ndarray)
@@ -1336,10 +1335,15 @@ def test_series_min_max_sub_axis() -> None:
13361335
sm = s1 * s2
13371336
sd = s1 / s2
13381337
check(assert_type(sa, pd.Series), pd.Series)
1339-
check(assert_type(ss, pd.Series), pd.Series)
1340-
# TODO GH1089 This should not match to Series[int]
1341-
check(assert_type(sm, pd.Series), pd.Series, np.integer) # pyright: ignore[reportAssertTypeFailure]
1342-
check(assert_type(sd, pd.Series), pd.Series)
1338+
check(
1339+
assert_type(ss, pd.Series), # pyright: ignore[reportAssertTypeFailure]
1340+
pd.Series,
1341+
)
1342+
check(assert_type(sm, pd.Series), pd.Series)
1343+
check(
1344+
assert_type(sd, pd.Series), # pyright: ignore[reportAssertTypeFailure]
1345+
pd.Series,
1346+
)
13431347

13441348

13451349
def test_series_index_isin() -> None:

0 commit comments

Comments
 (0)