Skip to content

Commit e5b7b1e

Browse files
committed
fix: #718 __add__
1 parent c99cedc commit e5b7b1e

File tree

5 files changed

+56
-99
lines changed

5 files changed

+56
-99
lines changed

pandas-stubs/core/series.pyi

Lines changed: 41 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -184,10 +184,13 @@ from pandas.plotting import PlotAccessor
184184

185185
_scalar_timedelta: TypeAlias = timedelta | np.timedelta64 | BaseOffset | Timedelta
186186
_vector_timedelta: TypeAlias = (
187-
Sequence[timedelta] | Sequence[Timedelta] | Series[Timedelta] | TimedeltaIndex
187+
Sequence[timedelta] | Sequence[Timedelta] | TimedeltaIndex
188188
)
189-
_all_timedelta: TypeAlias = _scalar_timedelta | _vector_timedelta
190-
_stamp_and_delta = TypeVar("_stamp_and_delta", bound=Timestamp | Timedelta)
189+
_nonseries_timedelta: TypeAlias = _scalar_timedelta | _vector_timedelta
190+
_all_int: TypeAlias = int | np_ndarray_anyint | Series[int] | Sequence[int]
191+
192+
_T_INT = TypeVar("_T_INT", bound=int)
193+
_T_STAMP_AND_DELTA = TypeVar("_T_STAMP_AND_DELTA", bound=Timestamp | Timedelta)
191194

192195
class _iLocIndexerSeries(_iLocIndexer, Generic[S1]):
193196
# get item
@@ -1593,14 +1596,12 @@ class Series(IndexOpsMixin[S1], NDFrame):
15931596
# just failed to generate these so I couldn't match
15941597
# them up.
15951598
@overload
1596-
def __add__(
1597-
self: Series[int],
1598-
other: int | np_ndarray_anyint | Series[int] | Sequence[int],
1599-
) -> Series[int]: ...
1599+
def __add__(self: Series[_T_INT], other: _all_int) -> Series[_T_INT]: ...
16001600
@overload
16011601
def __add__(
1602-
self: Series[Timestamp], other: _all_timedelta
1603-
) -> Series[Timestamp]: ...
1602+
self: Series[_T_STAMP_AND_DELTA],
1603+
other: _nonseries_timedelta | Series[Timedelta],
1604+
) -> Series[_T_STAMP_AND_DELTA]: ...
16041605
@overload
16051606
def __add__(self: Series[Timedelta], other: Period) -> PeriodSeries: ...
16061607
@overload
@@ -1609,10 +1610,6 @@ class Series(IndexOpsMixin[S1], NDFrame):
16091610
other: datetime | Timestamp | Series[Timestamp] | DatetimeIndex,
16101611
) -> Series[Timestamp]: ...
16111612
@overload
1612-
def __add__(
1613-
self: Series[Timedelta], other: timedelta | Timedelta | np.timedelta64
1614-
) -> Series[Timedelta]: ...
1615-
@overload
16161613
def __add__(self, other: S1 | Self) -> Self: ...
16171614
@overload
16181615
def __add__(
@@ -1625,7 +1622,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
16251622
self, other: bool | list[int] | MaskType
16261623
) -> Series[bool]: ...
16271624
@overload
1628-
def __and__(self, other: int | np_ndarray_anyint | Series[int]) -> Series[int]: ...
1625+
def __and__(self, other: _all_int) -> Series[int]: ...
16291626
# def __array__(self, dtype: Optional[_bool] = ...) -> _np_ndarray
16301627
def __div__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
16311628
def __eq__(self, other: object) -> Series[_bool]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
@@ -1636,13 +1633,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
16361633
@overload
16371634
def __floordiv__(
16381635
self: Series[Timedelta],
1639-
other: (
1640-
timedelta
1641-
| Series[Timedelta]
1642-
| np.timedelta64
1643-
| TimedeltaIndex
1644-
| Sequence[timedelta]
1645-
),
1636+
other: _nonseries_timedelta | Series[Timedelta],
16461637
) -> Series[int]: ...
16471638
@overload
16481639
def __floordiv__(self, other: num | _ListLike | Series[S1]) -> Series[int]: ...
@@ -1659,21 +1650,25 @@ class Series(IndexOpsMixin[S1], NDFrame):
16591650
self, other: S1 | _ListLike | Series[S1] | datetime | timedelta | date
16601651
) -> Series[_bool]: ...
16611652
@overload
1653+
def __mul__(self: Series[_T_INT], other: _all_int) -> Series[_T_INT]: ...
1654+
@overload
16621655
def __mul__(
1663-
self: Series[Timestamp],
1664-
other: float | Series[int] | Series[float] | Sequence[float],
1665-
) -> Series[Timestamp]: ...
1656+
self: Series[_T_STAMP_AND_DELTA],
1657+
other: (
1658+
num | Sequence[num] | Series[int] | Series[float] | float | Sequence[float]
1659+
),
1660+
) -> Series[_T_STAMP_AND_DELTA]: ...
16661661
@overload
16671662
def __mul__(
1668-
self: Series[Timedelta],
1669-
other: num | Sequence[num] | Series[int] | Series[float],
1670-
) -> Series[Timedelta]: ...
1663+
self: Series[_T_STAMP_AND_DELTA],
1664+
other: _nonseries_timedelta | Series[Timedelta],
1665+
) -> Never: ...
16711666
@overload
16721667
def __mul__(
1673-
self, other: timedelta | Timedelta | Series[Timedelta] | np.timedelta64
1668+
self, other: _nonseries_timedelta | Series[Timedelta]
16741669
) -> Series[Timedelta]: ...
16751670
@overload
1676-
def __mul__(self: Series[S1], other: num | _ListLike | Series) -> Series[S1]: ...
1671+
def __mul__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
16771672
def __mod__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
16781673
def __ne__(self, other: object) -> Series[_bool]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
16791674
def __pow__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
@@ -1683,10 +1678,10 @@ class Series(IndexOpsMixin[S1], NDFrame):
16831678
self, other: bool | list[int] | MaskType
16841679
) -> Series[bool]: ...
16851680
@overload
1686-
def __or__(self, other: int | np_ndarray_anyint | Series[int]) -> Series[int]: ...
1681+
def __or__(self, other: _all_int) -> Series[int]: ...
16871682
@overload
16881683
def __radd__(
1689-
self: Series[Timestamp], other: _all_timedelta
1684+
self: Series[Timestamp], other: _nonseries_timedelta | Series[Timedelta]
16901685
) -> Series[Timestamp]: ...
16911686
@overload
16921687
def __radd__(
@@ -1702,12 +1697,12 @@ class Series(IndexOpsMixin[S1], NDFrame):
17021697
self, other: bool | MaskType | list[int]
17031698
) -> Series[bool]: ...
17041699
@overload
1705-
def __rand__(self, other: int | np_ndarray_anyint | Series[int]) -> Series[int]: ...
1700+
def __rand__(self, other: _all_int) -> Series[int]: ...
17061701
def __rdiv__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
17071702
def __rdivmod__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
17081703
@overload
17091704
def __rfloordiv__(
1710-
self: Series[Timedelta], other: _all_timedelta
1705+
self: Series[Timedelta], other: _nonseries_timedelta | Series[Timedelta]
17111706
) -> Series[int]: ...
17121707
@overload
17131708
def __rfloordiv__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
@@ -1726,11 +1721,11 @@ class Series(IndexOpsMixin[S1], NDFrame):
17261721
self, other: bool | MaskType | list[int]
17271722
) -> Series[bool]: ...
17281723
@overload
1729-
def __ror__(self, other: int | np_ndarray_anyint | Series[int]) -> Series[int]: ...
1724+
def __ror__(self, other: _all_int) -> Series[int]: ...
17301725
def __rsub__(self, other: num | _ListLike | Series[S1]) -> Series: ...
17311726
@overload
17321727
def __rtruediv__(
1733-
self: Series[Timedelta], other: _all_timedelta
1728+
self: Series[Timedelta], other: _nonseries_timedelta | Series[Timedelta]
17341729
) -> Series[float]: ...
17351730
@overload
17361731
def __rtruediv__(self, other: num | _ListLike | Series[S1] | Path) -> Series: ...
@@ -1740,22 +1735,16 @@ class Series(IndexOpsMixin[S1], NDFrame):
17401735
self, other: bool | MaskType | list[int]
17411736
) -> Series[bool]: ...
17421737
@overload
1743-
def __rxor__(self, other: int | np_ndarray_anyint | Series[int]) -> Series[int]: ...
1738+
def __rxor__(self, other: _all_int) -> Series[int]: ...
17441739
@overload
1745-
def __sub__(
1746-
self: Series[int],
1747-
other: int | np_ndarray_anyint | Series[int] | Sequence[int],
1748-
) -> Series[int]: ...
1740+
def __sub__(self: Series[_T_INT], other: _all_int) -> Series[_T_INT]: ...
17491741
@overload
1750-
def __sub__(
1751-
self: Series[Timestamp], other: _all_timedelta
1742+
def __sub__( # type: ignore[overload-overlap]
1743+
self: Series[Timestamp], other: _nonseries_timedelta | Series[Timedelta]
17521744
) -> Series[Timestamp]: ...
17531745
@overload
17541746
def __sub__(
1755-
self: Series[Timedelta],
1756-
other: (
1757-
timedelta | Timedelta | Series[Timedelta] | TimedeltaIndex | np.timedelta64
1758-
),
1747+
self, other: _nonseries_timedelta | Series[Timedelta]
17591748
) -> Series[Timedelta]: ...
17601749
@overload
17611750
def __sub__(
@@ -1765,26 +1754,22 @@ class Series(IndexOpsMixin[S1], NDFrame):
17651754
def __sub__(self, other: num | _ListLike | Series) -> Series: ...
17661755
@overload
17671756
def __truediv__(
1768-
self: Series[Timestamp],
1757+
self: Series[_T_STAMP_AND_DELTA],
17691758
other: float | Series[int] | Series[float] | Sequence[float],
1770-
) -> Series[Timestamp]: ...
1759+
) -> Series[_T_STAMP_AND_DELTA]: ...
17711760
@overload
17721761
def __truediv__(
1773-
self: Series[Timedelta], other: _all_timedelta
1762+
self: Series[Timedelta], other: _nonseries_timedelta | Series[Timedelta]
17741763
) -> Series[float]: ...
17751764
@overload
1776-
def __truediv__(
1777-
self: Series[Timedelta], other: float | Sequence[float]
1778-
) -> Series[Timedelta]: ...
1779-
@overload
17801765
def __truediv__(self, other: num | _ListLike | Series[S1] | Path) -> Series: ...
17811766
# ignore needed for mypy as we want different results based on the arguments
17821767
@overload # type: ignore[override]
17831768
def __xor__( # pyright: ignore[reportOverlappingOverload]
17841769
self, other: bool | MaskType | list[int]
17851770
) -> Series[bool]: ...
17861771
@overload
1787-
def __xor__(self, other: int | np_ndarray_anyint | Series[int]) -> Series[int]: ...
1772+
def __xor__(self, other: _all_int) -> Series[int]: ...
17881773
def __invert__(self) -> Series[bool]: ...
17891774
# properties
17901775
# @property
@@ -1965,13 +1950,13 @@ class Series(IndexOpsMixin[S1], NDFrame):
19651950
) -> S1: ...
19661951
@overload
19671952
def mean(
1968-
self: Series[_stamp_and_delta],
1953+
self: Series[_T_STAMP_AND_DELTA],
19691954
axis: AxisIndex | None = ...,
19701955
skipna: _bool = ...,
19711956
level: None = ...,
19721957
numeric_only: _bool = ...,
19731958
**kwargs: Any,
1974-
) -> _stamp_and_delta: ...
1959+
) -> _T_STAMP_AND_DELTA: ...
19751960
@overload
19761961
def mean(
19771962
self,

tests/test_frame.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2805,13 +2805,7 @@ def test_indexslice_getitem():
28052805
def test_compute_values():
28062806
df = pd.DataFrame({"x": [1, 2, 3, 4]})
28072807
s: pd.Series = pd.Series([10, 20, 30, 40])
2808-
check(
2809-
assert_type(
2810-
df["x"] + s.values, pd.Series # pyright: ignore[reportAssertTypeFailure]
2811-
),
2812-
pd.Series,
2813-
np.int64,
2814-
)
2808+
check(assert_type(df["x"] + s.values, pd.Series), pd.Series, np.int64)
28152809

28162810

28172811
# https://github.com/microsoft/python-type-stubs/issues/164
@@ -2822,20 +2816,9 @@ def test_sum_get_add() -> None:
28222816
summer = df.sum(axis=1)
28232817
check(assert_type(summer, pd.Series), pd.Series)
28242818

2825-
check(
2826-
assert_type(s + summer, pd.Series), # pyright: ignore[reportAssertTypeFailure]
2827-
pd.Series,
2828-
)
2829-
check(
2830-
assert_type(s + df["y"], pd.Series), # pyright: ignore[reportAssertTypeFailure]
2831-
pd.Series,
2832-
)
2833-
check(
2834-
assert_type(
2835-
summer + summer, pd.Series # pyright: ignore[reportAssertTypeFailure]
2836-
),
2837-
pd.Series,
2838-
)
2819+
check(assert_type(s + summer, pd.Series), pd.Series)
2820+
check(assert_type(s + df["y"], pd.Series), pd.Series)
2821+
check(assert_type(summer + summer, pd.Series), pd.Series)
28392822

28402823

28412824
def test_getset_untyped() -> None:

tests/test_scalars.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,7 @@ def test_timedelta_add_sub() -> None:
577577
pd.Timedelta,
578578
)
579579
check(
580-
assert_type( # type: ignore [assert-type]
580+
assert_type( # type: ignore[assert-type]
581581
as_timedelta64 + td, # pyright: ignore[reportAssertTypeFailure]
582582
pd.Timedelta,
583583
),
@@ -646,7 +646,7 @@ def test_timedelta_add_sub() -> None:
646646
pd.Timedelta,
647647
)
648648
check(
649-
assert_type( # type: ignore [assert-type]
649+
assert_type( # type: ignore[assert-type]
650650
as_timedelta64 - td, # pyright: ignore[reportAssertTypeFailure]
651651
pd.Timedelta,
652652
),

tests/test_series.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,22 +1589,10 @@ def test_series_min_max_sub_axis() -> None:
15891589
ss = s1 - s2
15901590
sm = s1 * s2
15911591
sd = s1 / s2
1592-
check(
1593-
assert_type(sa, pd.Series), # pyright: ignore[reportAssertTypeFailure]
1594-
pd.Series,
1595-
)
1596-
check(
1597-
assert_type(ss, pd.Series), # pyright: ignore[reportAssertTypeFailure]
1598-
pd.Series,
1599-
)
1600-
check(
1601-
assert_type(sm, pd.Series), # pyright: ignore[reportAssertTypeFailure]
1602-
pd.Series,
1603-
)
1604-
check(
1605-
assert_type(sd, pd.Series), # pyright: ignore[reportAssertTypeFailure]
1606-
pd.Series,
1607-
)
1592+
check(assert_type(sa, pd.Series), pd.Series)
1593+
check(assert_type(ss, pd.Series), pd.Series)
1594+
check(assert_type(sm, pd.Series), pd.Series) # type: ignore[assert-type]
1595+
check(assert_type(sd, pd.Series), pd.Series)
16081596

16091597

16101598
def test_series_index_isin() -> None:
@@ -3897,7 +3885,7 @@ def foo(sf: pd.Series) -> None:
38973885
pass
38983886

38993887
foo(s)
3900-
check(assert_type(s + pd.Series([1]), pd.Series), pd.Series) # type: ignore [assert-type] # pyright: ignore[reportAssertTypeFailure]
3888+
check(assert_type(s + pd.Series([1]), pd.Series), pd.Series)
39013889

39023890

39033891
def test_series_items() -> None:

tests/test_timefuncs.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,9 +1474,10 @@ def test_timedelta64_and_arithmatic_operator() -> None:
14741474
check(assert_type((s3 + td), "pd.Series[pd.Timedelta]"), pd.Series, pd.Timedelta)
14751475
check(assert_type((s3 / td), "pd.Series[float]"), pd.Series, float)
14761476
if TYPE_CHECKING_INVALID_USAGE:
1477-
r1 = s1 * td
1478-
r2 = s1 / td # type: ignore[operator] # pyright: ignore[reportOperatorIssue]
1479-
r3 = s3 * td
1477+
s1.__mul__(td)
1478+
r1 = s1 * td # pyright: ignore[reportOperatorIssue]
1479+
r2 = s1 / td # pyright: ignore[reportOperatorIssue]
1480+
r3 = s3 * td # pyright: ignore[reportOperatorIssue]
14801481

14811482

14821483
def test_timedeltaseries_add_timestampseries() -> None:

0 commit comments

Comments
 (0)