Skip to content

Commit 7a87d18

Browse files
committed
fix: type asserts
1 parent 0657417 commit 7a87d18

File tree

6 files changed

+137
-25
lines changed

6 files changed

+137
-25
lines changed

attempt.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from typing import assert_type, reveal_type
2+
import pandas as pd
3+
from pandas.core.series import TimedeltaSeries # noqa: F401
4+
import numpy as np
5+
import datetime as dt
6+
7+
from tests import check
8+
9+
10+
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [5, 4, 3, 2, 1]})
11+
s1 = df.min(axis=1)
12+
s2 = df.max(axis=1)
13+
sa = s1 + s2
14+
ss = s1 - s2
15+
sm = s1 * s2
16+
sd = s1 / s2
17+
check(assert_type(sa, pd.Series), pd.Series)
18+
reveal_type(s1.__sub__(s2))
19+
reveal_type(s2.__rsub__(s1))
20+
check(assert_type(ss, pd.Series), pd.Series)
21+
check(assert_type(sm, pd.Series), pd.Series)
22+
check(assert_type(sd, pd.Series), pd.Series)
23+
24+
ts1 = pd.to_datetime(pd.Series(["2022-03-05", "2022-03-06"]))
25+
assert isinstance(ts1.iloc[0], pd.Timestamp)
26+
td1 = pd.to_timedelta([2, 3], "seconds")
27+
ts2 = pd.to_datetime(pd.Series(["2022-03-08", "2022-03-10"]))
28+
r1 = ts1 - ts2
29+
check(assert_type(r1, "TimedeltaSeries"), pd.Series, pd.Timedelta)
30+
r2 = r1 / td1
31+
check(assert_type(r2, "pd.Series[float]"), pd.Series, float)
32+
r3 = r1 - td1
33+
check(assert_type(r3, "TimedeltaSeries"), pd.Series, pd.Timedelta)
34+
r4 = pd.Timedelta(5, "days") / r1
35+
check(assert_type(r4, "pd.Series[float]"), pd.Series, float)
36+
sb = pd.Series([1, 2]) == pd.Series([1, 3])
37+
check(assert_type(sb, "pd.Series[bool]"), pd.Series, np.bool_)
38+
r5 = sb * r1
39+
check(assert_type(r5, "TimedeltaSeries"), pd.Series, pd.Timedelta)
40+
r6 = r1 * 4
41+
check(assert_type(r6, "TimedeltaSeries"), pd.Series, pd.Timedelta)
42+
43+
tsp1 = pd.Timestamp("2022-03-05")
44+
dt1 = dt.datetime(2022, 9, 1, 12, 5, 30)
45+
r7 = ts1 - tsp1
46+
check(assert_type(r7, "TimedeltaSeries"), pd.Series, pd.Timedelta)
47+
r8 = ts1 - dt1
48+
check(assert_type(r8, "TimedeltaSeries"), pd.Series, pd.Timedelta)

pandas-stubs/_libs/tslibs/timestamps.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,9 +246,9 @@ class Timestamp(datetime, SupportsIndex):
246246
@overload
247247
def __sub__(self, other: TimedeltaSeries) -> Series[Timestamp]: ...
248248
@overload
249-
def __sub__(self, other: Series[Never]) -> Series: ... # type: ignore[overload-overlap]
249+
def __sub__(self, other: Series[Never]) -> Series: ...
250250
@overload
251-
def __sub__(self, other: Series[Timestamp]) -> TimedeltaSeries: ...
251+
def __sub__(self, other: Series[Timestamp]) -> Series[Timedelta]: ...
252252
@overload
253253
def __sub__(
254254
self, other: np_ndarray[ShapeT, np.timedelta64]

pandas-stubs/core/series.pyi

Lines changed: 81 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1624,9 +1624,9 @@ class Series(IndexOpsMixin[S1], NDFrame):
16241624
# just failed to generate these so I couldn't match
16251625
# them up.
16261626
@overload
1627-
def __add__(self: Series[Never], other: Scalar | _ListLike | Series) -> Series: ...
1627+
def __add__(self: Series[Never], other: Scalar | _ListLike | Series) -> Series: ... # type: ignore[overload-overlap]
16281628
@overload
1629-
def __add__(self, other: Series[Never]) -> Series: ...
1629+
def __add__(self, other: Series[Never]) -> Series: ... # type: ignore[overload-overlap]
16301630
@overload
16311631
def __add__(
16321632
self: Series[bool],
@@ -1709,6 +1709,15 @@ class Series(IndexOpsMixin[S1], NDFrame):
17091709
other: datetime | np.datetime64 | np_ndarray_dt | Series[Timestamp],
17101710
) -> Series[Timestamp]: ...
17111711
@overload
1712+
def __add__(
1713+
self: Series[Timedelta],
1714+
other: timedelta | np.timedelta64 | np_ndarray_td | TimedeltaSeries,
1715+
) -> TimedeltaSeries: ...
1716+
@overload
1717+
def __add__(
1718+
self: Series[Timedelta], other: Series[Timedelta]
1719+
) -> Series[Timedelta]: ...
1720+
@overload
17121721
def __add__(self: Series[Timedelta], other: Period) -> PeriodSeries: ...
17131722
@overload
17141723
def add(
@@ -2235,16 +2244,38 @@ class Series(IndexOpsMixin[S1], NDFrame):
22352244
other: timedelta | np.timedelta64 | np_ndarray_td | TimedeltaSeries,
22362245
) -> TimedeltaSeries: ...
22372246
@overload
2247+
def __mul__(self: Series[bool], other: Series[Timedelta]) -> Series[Timedelta]: ... # type: ignore[overload-overlap]
2248+
@overload
22382249
def __mul__(
22392250
self: Series[int],
22402251
other: timedelta | np.timedelta64 | np_ndarray_td | TimedeltaSeries,
22412252
) -> TimedeltaSeries: ...
22422253
@overload
2254+
def __mul__(self: Series[int], other: Series[Timedelta]) -> Series[Timedelta]: ...
2255+
@overload
22432256
def __mul__(
22442257
self: Series[float],
22452258
other: timedelta | np.timedelta64 | np_ndarray_td | TimedeltaSeries,
22462259
) -> TimedeltaSeries: ...
22472260
@overload
2261+
def __mul__(self: Series[float], other: Series[Timedelta]) -> Series[Timedelta]: ...
2262+
@overload
2263+
def __mul__(
2264+
self: Series[Timedelta],
2265+
other: (
2266+
float
2267+
| Sequence[float]
2268+
| np_ndarray_bool
2269+
| np_ndarray_anyint
2270+
| np_ndarray_float
2271+
),
2272+
) -> TimedeltaSeries: ...
2273+
@overload
2274+
def __mul__(
2275+
self: Series[Timedelta],
2276+
other: Series[bool] | Series[int] | Series[float],
2277+
) -> Series[Timedelta]: ...
2278+
@overload
22482279
def mul(
22492280
self: Series[Never],
22502281
other: complex | _ListLike | Series,
@@ -2436,19 +2467,49 @@ class Series(IndexOpsMixin[S1], NDFrame):
24362467
self: Series[_T_COMPLEX], other: np_ndarray_complex
24372468
) -> Series[complex]: ...
24382469
@overload
2439-
def __rmul__(
2470+
def __rmul__( # type: ignore[misc]
24402471
self: Series[bool],
2441-
other: timedelta | np.timedelta64 | np_ndarray_td | TimedeltaSeries,
2472+
other: (
2473+
timedelta
2474+
| np.timedelta64
2475+
| np_ndarray_td
2476+
| Series[Timedelta]
2477+
| TimedeltaSeries
2478+
),
24422479
) -> TimedeltaSeries: ...
24432480
@overload
2444-
def __rmul__(
2481+
def __rmul__( # type: ignore[misc]
24452482
self: Series[int],
2446-
other: timedelta | np.timedelta64 | np_ndarray_td | TimedeltaSeries,
2483+
other: (
2484+
timedelta
2485+
| np.timedelta64
2486+
| np_ndarray_td
2487+
| Series[Timedelta]
2488+
| TimedeltaSeries
2489+
),
24472490
) -> TimedeltaSeries: ...
24482491
@overload
2449-
def __rmul__(
2492+
def __rmul__( # type: ignore[misc]
24502493
self: Series[float],
2451-
other: timedelta | np.timedelta64 | np_ndarray_td | TimedeltaSeries,
2494+
other: (
2495+
timedelta
2496+
| np.timedelta64
2497+
| np_ndarray_td
2498+
| Series[Timedelta]
2499+
| TimedeltaSeries
2500+
),
2501+
) -> TimedeltaSeries: ...
2502+
@overload
2503+
def __rmul__(
2504+
self: Series[Timedelta],
2505+
other: (
2506+
float
2507+
| Sequence[float]
2508+
| np_ndarray_bool
2509+
| np_ndarray_anyint
2510+
| np_ndarray_float
2511+
| Series[_T_INT]
2512+
),
24522513
) -> TimedeltaSeries: ...
24532514
@overload
24542515
def rmul(
@@ -2627,16 +2688,17 @@ class Series(IndexOpsMixin[S1], NDFrame):
26272688
@overload
26282689
def __rxor__(self, other: int | np_ndarray_anyint | Series[int]) -> Series[int]: ...
26292690
@overload
2630-
def __sub__(self, other: Series[Never]) -> Series: ... # type: ignore[overload-overlap]
2691+
def __sub__(self: Series[Never], other: Series[Never]) -> Series: ... # type: ignore[overload-overlap]
26312692
@overload
26322693
def __sub__(
2633-
self: Series[Never],
2634-
other: datetime | np.datetime64 | np_ndarray_dt | Series[Timestamp],
2694+
self: Series[Never], other: datetime | np.datetime64 | np_ndarray_dt
26352695
) -> TimedeltaSeries: ...
26362696
@overload
2637-
def __sub__( # type: ignore[overload-overlap]
2638-
self: Series[Never], other: complex | _ListLike | Series
2639-
) -> Series: ...
2697+
def __sub__(self: Series[Never], other: Series[Timestamp]) -> Series[Timedelta]: ...
2698+
@overload
2699+
def __sub__(self: Series[Never], other: complex | _ListLike | Series) -> Series: ...
2700+
@overload
2701+
def __sub__(self, other: Series[Never]) -> Series: ... # type: ignore[overload-overlap]
26402702
@overload
26412703
def __sub__(
26422704
self: Series[bool],
@@ -2702,18 +2764,20 @@ class Series(IndexOpsMixin[S1], NDFrame):
27022764
) -> Series[complex]: ...
27032765
@overload
27042766
def __sub__(
2705-
self: Series[Timestamp],
2706-
other: datetime | np.datetime64 | np_ndarray_dt | Series[Timestamp],
2767+
self: Series[Timestamp], other: datetime | np.datetime64 | np_ndarray_dt
27072768
) -> TimedeltaSeries: ...
27082769
@overload
2770+
def __sub__(
2771+
self: Series[Timestamp], other: Series[Timestamp]
2772+
) -> Series[Timedelta]: ...
2773+
@overload
27092774
def __sub__(
27102775
self: Series[Timestamp],
27112776
other: (
27122777
timedelta
27132778
| np.timedelta64
27142779
| np_ndarray_td
27152780
| TimedeltaIndex
2716-
| Series[Timedelta]
27172781
| TimedeltaSeries
27182782
| BaseOffset
27192783
),
@@ -2726,7 +2790,6 @@ class Series(IndexOpsMixin[S1], NDFrame):
27262790
| np.timedelta64
27272791
| np_ndarray_td
27282792
| TimedeltaIndex
2729-
| Series[Timedelta]
27302793
| TimedeltaSeries
27312794
),
27322795
) -> TimedeltaSeries: ...

tests/series/arithmetic/test_sub.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def test_sub_pd_datetime() -> None:
180180
a = pd.Series([s + pd.Timedelta(minutes=m) for m in range(3)])
181181

182182
check(assert_type(left_ts - s, "TimedeltaSeries"), pd.Series, pd.Timedelta)
183-
check(assert_type(left_ts - a, "TimedeltaSeries"), pd.Series, pd.Timedelta)
183+
check(assert_type(left_ts - a, "pd.Series[pd.Timedelta]"), pd.Series, pd.Timedelta)
184184

185185
check(assert_type(s - left_ts, pd.Series), pd.Series, pd.Timedelta)
186186
check(assert_type(a - left_ts, pd.Series), pd.Series, pd.Timedelta)

tests/series/test_series.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1598,6 +1598,7 @@ def test_series_min_max_sub_axis() -> None:
15981598
ss = s1 - s2
15991599
sm = s1 * s2
16001600
sd = s1 / s2
1601+
s1.__sub__(s2)
16011602
check(assert_type(sa, pd.Series), pd.Series)
16021603
check(assert_type(ss, pd.Series), pd.Series)
16031604
check(assert_type(sm, pd.Series), pd.Series)

tests/test_timefuncs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def test_timestamp_timedelta_series_arithmetic() -> None:
191191
td1 = pd.to_timedelta([2, 3], "seconds")
192192
ts2 = pd.to_datetime(pd.Series(["2022-03-08", "2022-03-10"]))
193193
r1 = ts1 - ts2
194-
check(assert_type(r1, "TimedeltaSeries"), pd.Series, pd.Timedelta)
194+
check(assert_type(r1, "pd.Series[pd.Timedelta]"), pd.Series, pd.Timedelta)
195195
r2 = r1 / td1
196196
check(assert_type(r2, "pd.Series[float]"), pd.Series, float)
197197
r3 = r1 - td1
@@ -201,7 +201,7 @@ def test_timestamp_timedelta_series_arithmetic() -> None:
201201
sb = pd.Series([1, 2]) == pd.Series([1, 3])
202202
check(assert_type(sb, "pd.Series[bool]"), pd.Series, np.bool_)
203203
r5 = sb * r1
204-
check(assert_type(r5, "TimedeltaSeries"), pd.Series, pd.Timedelta)
204+
check(assert_type(r5, "pd.Series[pd.Timedelta]"), pd.Series, pd.Timedelta)
205205
r6 = r1 * 4
206206
check(assert_type(r6, "TimedeltaSeries"), pd.Series, pd.Timedelta)
207207

@@ -1655,7 +1655,7 @@ def test_timedelta64_and_arithmatic_operator() -> None:
16551655
s1 = pd.Series(data=pd.date_range("1/1/2020", "2/1/2020"))
16561656
s2 = pd.Series(data=pd.date_range("1/1/2021", "2/1/2021"))
16571657
s3 = s2 - s1
1658-
check(assert_type(s3, "TimedeltaSeries"), pd.Series, pd.Timedelta)
1658+
check(assert_type(s3, "pd.Series[pd.Timedelta]"), pd.Series, pd.Timedelta)
16591659
td1 = pd.Timedelta(1, "D")
16601660
check(assert_type(s2 - td1, "pd.Series[pd.Timestamp]"), pd.Series, pd.Timestamp)
16611661
# GH 758
@@ -1808,7 +1808,7 @@ def test_timestamp_sub_series() -> None:
18081808
ts1 = pd.to_datetime(pd.Series(["2022-03-05", "2022-03-06"]))
18091809
one_ts = ts1.iloc[0]
18101810
check(assert_type(ts1.iloc[0], pd.Timestamp), pd.Timestamp)
1811-
check(assert_type(one_ts - ts1, "TimedeltaSeries"), pd.Series, pd.Timedelta)
1811+
check(assert_type(one_ts - ts1, "pd.Series[pd.Timedelta]"), pd.Series, pd.Timedelta)
18121812

18131813

18141814
def test_creating_date_range() -> None:

0 commit comments

Comments
 (0)