Skip to content

Commit d7978f6

Browse files
authored
Merge pull request #273 from numpy/__eq__
✨ improved `numpy.generic.__eq__`
2 parents 1560df9 + ce453a5 commit d7978f6

File tree

3 files changed

+353
-15
lines changed

3 files changed

+353
-15
lines changed

src/numpy-stubs/__init__.pyi

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,6 @@ _AnyShapeT = TypeVar(
552552
tuple[int, int, int, int, int, int, int, int], # 8-d
553553
tuple[int, ...], # N-d
554554
)
555-
_AnyNBitInexact = TypeVar("_AnyNBitInexact", _16Bit, _32Bit, _64Bit, _NBitLongDouble)
556555
_AnyTD64Item = TypeVar("_AnyTD64Item", dt.timedelta, int, None, dt.timedelta | int | None)
557556
_AnyDT64Arg = TypeVar("_AnyDT64Arg", dt.datetime, dt.date, None)
558557
_AnyDate = TypeVar("_AnyDate", dt.date, dt.datetime)
@@ -2136,12 +2135,19 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
21362135
@overload # ?-d
21372136
def __iter__(self, /) -> Iterator[Any]: ...
21382137

2139-
# The last overload is for catching recursive objects whose
2140-
# nesting is too deep.
2141-
# The first overload is for catching `bytes` (as they are a subtype of
2142-
# `Sequence[int]`) and `str`. As `str` is a recursive sequence of
2143-
# strings, it will pass through the final overload otherwise
2138+
#
2139+
@overload # type: ignore[override]
2140+
def __eq__(self, other: _ScalarLike_co | ndarray[_ShapeT_co, dtype[Any]], /) -> ndarray[_ShapeT_co, dtype[bool_]]: ...
2141+
@overload
2142+
def __eq__(self, other: object, /) -> NDArray[bool_]: ... # pyright: ignore[reportIncompatibleMethodOverride]
2143+
2144+
#
2145+
@overload # type: ignore[override]
2146+
def __ne__(self, other: _ScalarLike_co | ndarray[_ShapeT_co, dtype[Any]], /) -> ndarray[_ShapeT_co, dtype[bool_]]: ...
2147+
@overload
2148+
def __ne__(self, other: object, /) -> NDArray[bool_]: ... # pyright: ignore[reportIncompatibleMethodOverride]
21442149

2150+
#
21452151
@overload
21462152
def __lt__(self: _ArrayComplex_co, other: _ArrayLikeNumber_co, /) -> NDArray[bool_]: ...
21472153
@overload
@@ -2192,10 +2198,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
21922198
# Unary ops
21932199

21942200
@overload
2195-
def __abs__(
2196-
self: ndarray[_ShapeT, dtype[complexfloating[_AnyNBitInexact]]],
2197-
/,
2198-
) -> ndarray[_ShapeT, dtype[floating[_AnyNBitInexact]]]: ...
2201+
def __abs__(self: ndarray[_ShapeT, dtype[complexfloating[_NBitT]]], /) -> ndarray[_ShapeT, dtype[floating[_NBitT]]]: ...
21992202
@overload
22002203
def __abs__(self: _RealArrayT, /) -> _RealArrayT: ...
22012204

@@ -3861,6 +3864,26 @@ class generic(_ArrayOrScalarCommon, Generic[_ItemT_co]):
38613864
@abc.abstractmethod
38623865
def __init__(self, /, *args: Any, **kwargs: Any) -> None: ...
38633866

3867+
#
3868+
@overload
3869+
def __eq__(self, other: _ScalarLike_co, /) -> bool_: ...
3870+
@overload
3871+
def __eq__(self, other: ndarray[_ShapeT, dtype[Any]], /) -> ndarray[_ShapeT, dtype[bool_]]: ...
3872+
@overload
3873+
def __eq__(self, other: _NestedSequence[ArrayLike], /) -> NDArray[bool_]: ...
3874+
@overload
3875+
def __eq__(self, other: object, /) -> Any: ...
3876+
3877+
#
3878+
@overload
3879+
def __ne__(self, other: _ScalarLike_co, /) -> bool_: ...
3880+
@overload
3881+
def __ne__(self, other: ndarray[_ShapeT, dtype[Any]], /) -> ndarray[_ShapeT, dtype[bool_]]: ...
3882+
@overload
3883+
def __ne__(self, other: _NestedSequence[ArrayLike], /) -> NDArray[bool_]: ...
3884+
@overload
3885+
def __ne__(self, other: object, /) -> Any: ...
3886+
38643887
#
38653888
@overload
38663889
def __array__(self, dtype: None = None, /) -> ndarray[tuple[()], dtype[Self]]: ...

test/generate_scalar_binops.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@
3333
"<=": op.__le__,
3434
">=": op.__ge__,
3535
">": op.__gt__,
36-
# TODO(jorenham): these currently all return `Any`; fix this
37-
# "==": op.__eq__,
36+
"==": op.__eq__,
3837
}
3938
NAMES = {
4039
# builtins (key length > 1)
@@ -132,8 +131,12 @@ def _assert_stmt(op: str, lhs: str, rhs: str, /) -> str | None:
132131
"# pyright: ignore[reportOperatorIssue]",
133132
))
134133

135-
expr_type = _sctype_expr(val_out.dtype)
136-
return f"assert_type({expr_eval}, {expr_type})"
134+
expr_type = (
135+
_sctype_expr(val_out.dtype)
136+
if isinstance(val_out, np.generic)
137+
else type(val_out).__qualname__
138+
)
139+
return f"assert_type({expr_eval}, {expr_type})" if expr_type != "bool" else None
137140

138141

139142
def _gen_imports() -> Generator[str]:

0 commit comments

Comments
 (0)