Skip to content

Commit 2971f01

Browse files
GH1089 Add more tests
1 parent 41f019b commit 2971f01

File tree

2 files changed

+95
-16
lines changed

2 files changed

+95
-16
lines changed

pandas-stubs/core/series.pyi

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1822,15 +1822,23 @@ class Series(IndexOpsMixin[S1], NDFrame):
18221822
# Met @overload
18231823
@overload
18241824
def add(
1825+
self,
1826+
other: S1 | Self,
1827+
level: Level | None = ...,
1828+
fill_value: float | None = ...,
1829+
axis: int = ...,
1830+
) -> Self: ...
1831+
@overload
1832+
def add( # pyright: ignore[reportOverlappingOverload]
18251833
self: Series[int],
1826-
other: int,
1834+
other: Series[int] | int,
18271835
level: Level | None = ...,
18281836
fill_value: float | None = ...,
18291837
axis: int = ...,
18301838
) -> Series[int]: ...
18311839
@overload
1832-
def add( # pyright: ignore[reportOverlappingOverload]
1833-
self,
1840+
def add(
1841+
self: Series[int] | Series[float],
18341842
other: float | Series[float],
18351843
level: Level | None = ...,
18361844
fill_value: float | None = ...,
@@ -1846,7 +1854,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
18461854
) -> Series[float]: ...
18471855
@overload
18481856
def add(
1849-
self,
1857+
self: Series[complex],
18501858
other: complex,
18511859
level: Level | None = ...,
18521860
fill_value: float | None = ...,
@@ -1860,14 +1868,6 @@ class Series(IndexOpsMixin[S1], NDFrame):
18601868
fill_value: float | None = ...,
18611869
axis: int = ...,
18621870
) -> Series[S1]: ...
1863-
@overload
1864-
def add(
1865-
self,
1866-
other: S1 | Self,
1867-
level: Level | None = ...,
1868-
fill_value: float | None = ...,
1869-
axis: int = ...,
1870-
) -> Series: ...
18711871
def all(
18721872
self,
18731873
axis: AxisIndex = ...,
@@ -2279,6 +2279,14 @@ class Series(IndexOpsMixin[S1], NDFrame):
22792279
**kwargs,
22802280
) -> float: ...
22812281
@overload
2282+
def sub(
2283+
self,
2284+
other: S1 | Self,
2285+
level: Level | None = ...,
2286+
fill_value: float | None = ...,
2287+
axis: int = ...,
2288+
) -> Self: ...
2289+
@overload
22822290
def sub( # pyright: ignore[reportOverlappingOverload]
22832291
self: Series[int],
22842292
other: int,

tests/test_series_arithmetic.py

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
"""Test module for arithmetic operations on Series."""
22

3-
from typing import cast
4-
53
import numpy as np
64
import pandas as pd
75
from typing_extensions import assert_type
@@ -193,11 +191,12 @@ def test_element_wise_float_float() -> None:
193191

194192

195193
def test_element_wise_int_unknown() -> None:
196-
s = cast(pd.Series, pd.Series([7, -5, 10]))
194+
df = pd.DataFrame({"a": [7, -5, 10]})
195+
s = df["a"]
197196
s2 = pd.Series([0, 1, -105])
198197

199198
check(assert_type(s + s2, pd.Series), pd.Series)
200-
check(assert_type(s.add(s2, fill_value=0), "pd.Series[float]"), pd.Series)
199+
check(assert_type(s.add(s2, fill_value=0), pd.Series), pd.Series)
201200

202201
check(assert_type(s - s2, pd.Series), pd.Series)
203202
check(assert_type(s.sub(s2, fill_value=0), pd.Series), pd.Series)
@@ -208,3 +207,75 @@ def test_element_wise_int_unknown() -> None:
208207
# GH1089 should be the following
209208
check(assert_type(s / s2, "pd.Series[float]"), pd.Series)
210209
check(assert_type(s.div(s2, fill_value=0), "pd.Series[float]"), pd.Series)
210+
211+
212+
def test_element_wise_unknown_int() -> None:
213+
df = pd.DataFrame({"a": [7, -5, 10]})
214+
s = pd.Series([0, 1, -105])
215+
s2 = df["a"]
216+
217+
check(assert_type(s + s2, pd.Series), pd.Series)
218+
check(assert_type(s.add(s2, fill_value=0), pd.Series), pd.Series)
219+
220+
check(assert_type(s - s2, pd.Series), pd.Series)
221+
check(assert_type(s.sub(s2, fill_value=0), pd.Series), pd.Series)
222+
223+
check(assert_type(s * s2, pd.Series), pd.Series)
224+
check(assert_type(s.mul(s2, fill_value=0), pd.Series), pd.Series)
225+
226+
check(assert_type(s / s2, "pd.Series[float]"), pd.Series)
227+
check(assert_type(s.div(s2, fill_value=0), "pd.Series[float]"), pd.Series)
228+
229+
230+
def test_element_wise_unknown_unknown() -> None:
231+
df = pd.DataFrame({"a": [7, -5, 10]})
232+
s = df["a"]
233+
s2 = df["a"]
234+
235+
check(assert_type(s + s2, pd.Series), pd.Series)
236+
check(assert_type(s.add(s2, fill_value=0), pd.Series), pd.Series)
237+
238+
check(assert_type(s - s2, pd.Series), pd.Series)
239+
check(assert_type(s.sub(s2, fill_value=0), pd.Series), pd.Series)
240+
241+
check(assert_type(s * s2, pd.Series), pd.Series)
242+
check(assert_type(s.mul(s2, fill_value=0), pd.Series), pd.Series)
243+
244+
check(assert_type(s / s2, "pd.Series[float]"), pd.Series)
245+
check(assert_type(s.div(s2, fill_value=0), "pd.Series[float]"), pd.Series)
246+
247+
248+
def test_element_wise_float_unknown() -> None:
249+
df = pd.DataFrame({"a": [7, -5, 10]})
250+
s = pd.Series([1.3, 2.5, 4.5])
251+
s2 = df["a"]
252+
253+
check(assert_type(s + s2, pd.Series), pd.Series)
254+
check(assert_type(s.add(s2, fill_value=0), pd.Series), pd.Series)
255+
256+
check(assert_type(s - s2, pd.Series), pd.Series)
257+
check(assert_type(s.sub(s2, fill_value=0), pd.Series), pd.Series)
258+
259+
check(assert_type(s * s2, pd.Series), pd.Series)
260+
check(assert_type(s.mul(s2, fill_value=0), pd.Series), pd.Series)
261+
262+
check(assert_type(s / s2, "pd.Series[float]"), pd.Series)
263+
check(assert_type(s.div(s2, fill_value=0), "pd.Series[float]"), pd.Series)
264+
265+
266+
def test_element_wise_unknown_float() -> None:
267+
df = pd.DataFrame({"a": [7, -5, 10]})
268+
s = df["a"]
269+
s2 = pd.Series([1.3, 2.5, 4.5])
270+
271+
check(assert_type(s + s2, pd.Series), pd.Series)
272+
check(assert_type(s.add(s2, fill_value=0), pd.Series), pd.Series)
273+
274+
check(assert_type(s - s2, pd.Series), pd.Series)
275+
check(assert_type(s.sub(s2, fill_value=0), pd.Series), pd.Series)
276+
277+
check(assert_type(s * s2, pd.Series), pd.Series)
278+
check(assert_type(s.mul(s2, fill_value=0), pd.Series), pd.Series)
279+
280+
check(assert_type(s / s2, "pd.Series[float]"), pd.Series)
281+
check(assert_type(s.div(s2, fill_value=0), "pd.Series[float]"), pd.Series)

0 commit comments

Comments
 (0)