Skip to content

Commit ad4dc9d

Browse files
committed
make ShapeT work for eq and ne
1 parent 1db78e8 commit ad4dc9d

File tree

2 files changed

+16
-9
lines changed

2 files changed

+16
-9
lines changed

pandas-stubs/_libs/tslibs/timestamps.pyi

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ from pandas._typing import (
4545
TimeUnit,
4646
np_1darray,
4747
np_ndarray,
48-
npt,
4948
)
5049

5150
_Ambiguous: TypeAlias = bool | Literal["raise", "NaT"]
@@ -260,7 +259,7 @@ class Timestamp(datetime, SupportsIndex):
260259
@overload
261260
def __eq__(self, other: Index) -> np_1darray[np.bool]: ... # type: ignore[overload-overlap]
262261
@overload # TODO: using shape-aware arrays similar to other methods doesn't work in mypy
263-
def __eq__(self, other: npt.NDArray[np.datetime64]) -> npt.NDArray[np.bool]: ... # type: ignore[overload-overlap]
262+
def __eq__(self, other: np_ndarray[ShapeT, np.datetime64]) -> np_ndarray[ShapeT, np.bool]: ... # type: ignore[overload-overlap]
264263
@overload
265264
def __eq__(self, other: object) -> Literal[False]: ...
266265
@overload
@@ -270,7 +269,7 @@ class Timestamp(datetime, SupportsIndex):
270269
@overload
271270
def __ne__(self, other: Index) -> np_1darray[np.bool]: ... # type: ignore[overload-overlap]
272271
@overload # TODO: using shape-aware arrays similar to other methods doesn't work in mypy
273-
def __ne__(self, other: npt.NDArray[np.datetime64]) -> npt.NDArray[np.bool]: ... # type: ignore[overload-overlap]
272+
def __ne__(self, other: np_ndarray[ShapeT, np.datetime64]) -> np_ndarray[ShapeT, np.bool]: ... # type: ignore[overload-overlap]
274273
@overload
275274
def __ne__(self, other: object) -> Literal[True]: ...
276275
def __hash__(self) -> int: ...

tests/test_scalars.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,20 +1394,28 @@ def test_timestamp_cmp() -> None:
13941394
# tests in this block fail with mypy on Python 3.10 in CI only
13951395
# I couldn't reproduce the failure locally so skip mypy on Python 3.10
13961396
eq1_arr = check(
1397-
assert_type(ts == c_np_ndarray_dt64, np_ndarray_bool), np.ndarray, np.bool_
1397+
assert_type(ts == c_np_ndarray_dt64, np_ndarray_bool),
1398+
np_1darray[np.bool],
13981399
)
13991400
ne1_arr = check(
14001401
assert_type(ts != c_np_ndarray_dt64, np_ndarray_bool), np.ndarray, np.bool_
14011402
)
14021403
assert (eq1_arr != ne1_arr).all()
14031404
# TODO: the following should be 2D-arrays but it doesn't work in mypy
1404-
eq1_arr = check(
1405-
assert_type(ts == c_np_2darray_dt64, np_ndarray_bool), np_ndarray_bool
1405+
1406+
eq2_arr = check(
1407+
assert_type(
1408+
ts == c_np_2darray_dt64, np.ndarray[tuple[int, int], np.dtype[np.bool]]
1409+
),
1410+
np_ndarray_bool,
14061411
)
1407-
ne1_arr = check(
1408-
assert_type(ts != c_np_2darray_dt64, np_ndarray_bool), np_ndarray_bool
1412+
ne2_arr = check(
1413+
assert_type(
1414+
ts != c_np_2darray_dt64, np.ndarray[tuple[int, int], np.dtype[np.bool]]
1415+
),
1416+
np_ndarray_bool,
14091417
)
1410-
assert (eq1_arr != ne1_arr).all()
1418+
assert (eq2_arr != ne2_arr).all()
14111419

14121420
eq_s = check(
14131421
assert_type(ts == c_series_timestamp, "pd.Series[bool]"), pd.Series, np.bool_

0 commit comments

Comments
 (0)