Skip to content

Commit a0a4078

Browse files
Add @ operator type hints for Series
1 parent 779aab6 commit a0a4078

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

pandas-stubs/core/series.pyi

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -798,8 +798,18 @@ class Series(IndexOpsMixin[S1], NDFrame):
798798
def dot(
799799
self, other: ArrayLike | dict[_str, np.ndarray] | Sequence[S1] | Index[S1]
800800
) -> np.ndarray: ...
801-
def __matmul__(self, other): ...
802-
def __rmatmul__(self, other): ...
801+
@overload
802+
def __matmul__(self, other: Series) -> Scalar: ...
803+
@overload
804+
def __matmul__(self, other: DataFrame) -> Series: ...
805+
@overload
806+
def __matmul__(self, other: np.ndarray) -> np.ndarray: ...
807+
@overload
808+
def __rmatmul__(self, other: Series) -> Scalar: ...
809+
@overload
810+
def __rmatmul__(self, other: DataFrame) -> Series: ...
811+
@overload
812+
def __rmatmul__(self, other: np.ndarray) -> np.ndarray: ...
803813
@overload
804814
def searchsorted(
805815
self,

tests/test_series.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,16 +1236,17 @@ def test_types_as_type() -> None:
12361236

12371237

12381238
def test_types_dot() -> None:
1239+
"""Test typing of multiplication methods (dot and @) for Series."""
12391240
s1 = pd.Series([0, 1, 2, 3])
12401241
s2 = pd.Series([-1, 2, -3, 4])
12411242
df1 = pd.DataFrame([[0, 1], [-2, 3], [4, -5], [6, 7]])
12421243
n1 = np.array([[0, 1], [1, 2], [-1, -1], [2, 0]])
12431244
sc1: Scalar = s1.dot(s2)
12441245
sc2: Scalar = s1 @ s2
1245-
s3: pd.Series = s1.dot(df1)
1246-
s4: pd.Series = s1 @ df1
1247-
n2: np.ndarray = s1.dot(n1)
1248-
n3: np.ndarray = s1 @ n1
1246+
check(assert_type(s1.dot(df1), "pd.Series[int]"),pd.Series, np.int64)
1247+
check(assert_type(s1 @ df1, pd.Series),pd.Series )
1248+
check(assert_type(s1.dot(n1), np.ndarray),np.ndarray )
1249+
check(assert_type(s1 @ n1, np.ndarray),np.ndarray )
12491250

12501251

12511252
def test_series_loc_setitem() -> None:

0 commit comments

Comments
 (0)