Skip to content

Commit d0ede0d

Browse files
committed
✨ improve __eq__ and __abs__ of numpy.generic
1 parent 1560df9 commit d0ede0d

File tree

1 file changed

+40
-11
lines changed

1 file changed

+40
-11
lines changed

src/numpy-stubs/__init__.pyi

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,6 @@ _AnyShapeT = TypeVar(
552552
tuple[int, int, int, int, int, int, int, int], # 8-d
553553
tuple[int, ...], # N-d
554554
)
555-
_AnyNBitInexact = TypeVar("_AnyNBitInexact", _16Bit, _32Bit, _64Bit, _NBitLongDouble)
556555
_AnyTD64Item = TypeVar("_AnyTD64Item", dt.timedelta, int, None, dt.timedelta | int | None)
557556
_AnyDT64Arg = TypeVar("_AnyDT64Arg", dt.datetime, dt.date, None)
558557
_AnyDate = TypeVar("_AnyDate", dt.date, dt.datetime)
@@ -2136,12 +2135,19 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
21362135
@overload # ?-d
21372136
def __iter__(self, /) -> Iterator[Any]: ...
21382137

2139-
# The last overload is for catching recursive objects whose
2140-
# nesting is too deep.
2141-
# The first overload is for catching `bytes` (as they are a subtype of
2142-
# `Sequence[int]`) and `str`. As `str` is a recursive sequence of
2143-
# strings, it will pass through the final overload otherwise
2138+
#
2139+
@overload # type: ignore[override]
2140+
def __eq__(self, other: _ScalarLike_co | ndarray[_ShapeT_co, dtype[Any]], /) -> ndarray[_ShapeT_co, dtype[bool_]]: ...
2141+
@overload
2142+
def __eq__(self, other: object, /) -> NDArray[bool_]: ... # pyright: ignore[reportIncompatibleMethodOverride]
2143+
2144+
#
2145+
@overload # type: ignore[override]
2146+
def __ne__(self, other: _ScalarLike_co | ndarray[_ShapeT_co, dtype[Any]], /) -> ndarray[_ShapeT_co, dtype[bool_]]: ...
2147+
@overload
2148+
def __ne__(self, other: object, /) -> NDArray[bool_]: ... # pyright: ignore[reportIncompatibleMethodOverride]
21442149

2150+
#
21452151
@overload
21462152
def __lt__(self: _ArrayComplex_co, other: _ArrayLikeNumber_co, /) -> NDArray[bool_]: ...
21472153
@overload
@@ -2191,13 +2197,16 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
21912197

21922198
# Unary ops
21932199

2194-
@overload
2195-
def __abs__(
2196-
self: ndarray[_ShapeT, dtype[complexfloating[_AnyNBitInexact]]],
2197-
/,
2198-
) -> ndarray[_ShapeT, dtype[floating[_AnyNBitInexact]]]: ...
21992200
@overload
22002201
def __abs__(self: _RealArrayT, /) -> _RealArrayT: ...
2202+
@overload
2203+
def __abs__(self: ndarray[_ShapeT, dtype[complex64]], /) -> ndarray[_ShapeT, dtype[float32]]: ... # type: ignore[overload-overlap]
2204+
@overload
2205+
def __abs__(self: ndarray[_ShapeT, dtype[complex128]], /) -> ndarray[_ShapeT, dtype[float64]]: ...
2206+
@overload
2207+
def __abs__(self: ndarray[_ShapeT, dtype[clongdouble]], /) -> ndarray[_ShapeT, dtype[longdouble]]: ...
2208+
@overload
2209+
def __abs__(self: ndarray[_ShapeT, dtype[inexact]], /) -> ndarray[_ShapeT, dtype[floating]]: ...
22012210

22022211
#
22032212
def __invert__(self: _IntegralArrayT, /) -> _IntegralArrayT: ... # noqa: PYI019
@@ -3861,6 +3870,26 @@ class generic(_ArrayOrScalarCommon, Generic[_ItemT_co]):
38613870
@abc.abstractmethod
38623871
def __init__(self, /, *args: Any, **kwargs: Any) -> None: ...
38633872

3873+
#
3874+
@overload
3875+
def __eq__(self, other: _ScalarLike_co, /) -> bool_: ...
3876+
@overload
3877+
def __eq__(self, other: ndarray[_ShapeT, dtype[Any]], /) -> ndarray[_ShapeT, dtype[bool_]]: ...
3878+
@overload
3879+
def __eq__(self, other: _NestedSequence[ArrayLike], /) -> NDArray[bool_]: ...
3880+
@overload
3881+
def __eq__(self, other: object, /) -> Any: ...
3882+
3883+
#
3884+
@overload
3885+
def __ne__(self, other: _ScalarLike_co, /) -> bool_: ...
3886+
@overload
3887+
def __ne__(self, other: ndarray[_ShapeT, dtype[Any]], /) -> ndarray[_ShapeT, dtype[bool_]]: ...
3888+
@overload
3889+
def __ne__(self, other: _NestedSequence[ArrayLike], /) -> NDArray[bool_]: ...
3890+
@overload
3891+
def __ne__(self, other: object, /) -> Any: ...
3892+
38643893
#
38653894
@overload
38663895
def __array__(self, dtype: None = None, /) -> ndarray[tuple[()], dtype[Self]]: ...

0 commit comments

Comments
 (0)