Skip to content

Commit 8de6607

Browse files
committed
✅ fix failing ndarray.__add__ tests for builtin scalars
1 parent 6b5bd94 commit 8de6607

File tree

1 file changed

+38
-22
lines changed

1 file changed

+38
-22
lines changed

src/numpy-stubs/__init__.pyi

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,6 @@ _AnyShapeT = TypeVar(
623623
tuple[int, int, int, int],
624624
tuple[int, ...],
625625
)
626-
_AnyItemT = TypeVar("_AnyItemT", bool, int, float, complex, bytes, str)
627626

628627
###
629628
# Type parameters (for internal use only)
@@ -775,6 +774,7 @@ _JustFloating: TypeAlias = _nt.Just[floating]
775774
_JustComplexFloating: TypeAlias = _nt.Just[complexfloating]
776775
_JustInexact: TypeAlias = _nt.Just[inexact]
777776
_JustNumber: TypeAlias = _nt.Just[number]
777+
_JustBuiltinScalar: TypeAlias = int | _nt.JustFloat | _nt.JustComplex | _nt.JustBytes | _nt.JustStr
778778

779779
_AbstractInexact: TypeAlias = _JustInexact | _JustFloating | _JustComplexFloating
780780
_AbstractInteger: TypeAlias = _JustInteger | _JustSignedInteger | _JustUnsignedInteger
@@ -2141,17 +2141,21 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
21412141
@overload
21422142
def __add__(self: NDArray[_ScalarT], x: _nt.CastsWith[_ScalarT, _ScalarOutT], /) -> NDArray[_ScalarOutT]: ... # type: ignore[overload-overlap]
21432143
@overload
2144+
def __add__(self: _nt.CastsWithBuiltin[_T, _ScalarOutT], x: _nt.SequenceND[_T], /) -> NDArray[_ScalarOutT]: ...
2145+
@overload
2146+
def __add__(self: _nt.CastsWithInt[_ScalarOutT], x: _nt.SequenceND[_nt.JustInt], /) -> NDArray[_ScalarOutT]: ...
2147+
@overload
2148+
def __add__(self: _nt.CastsWithFloat[_ScalarOutT], x: _nt.SequenceND[_nt.JustFloat], /) -> NDArray[_ScalarOutT]: ...
2149+
@overload
21442150
def __add__(
2145-
self: NDArray[generic[_AnyItemT]],
2146-
x: _nt.SequenceND[_AnyItemT],
2147-
/,
2148-
) -> ndarray[tuple[int, ...], _DTypeT_co]: ...
2151+
self: _nt.CastsWithComplex[_ScalarOutT], x: _nt.SequenceND[_nt.JustComplex], /
2152+
) -> NDArray[_ScalarOutT]: ...
21492153
@overload
21502154
def __add__(self: NDArray[datetime64], x: _nt.CoTimeDelta_nd, /) -> NDArray[datetime64]: ...
21512155
@overload
21522156
def __add__(self: NDArray[_nt.co_timedelta], x: _nt.ToDateTime_nd, /) -> NDArray[datetime64]: ...
21532157
@overload
2154-
def __add__(self: NDArray[object_], x: object, /) -> NDArray[object_]: ...
2158+
def __add__(self: NDArray[object_], x: object, /) -> NDArray[object_]: ... # type: ignore[overload-cannot-match]
21552159
@overload
21562160
def __add__(self: NDArray[str_], x: _nt.ToString_nd[_T], /) -> _nt.StringArrayND[_T]: ...
21572161
@overload
@@ -2163,17 +2167,23 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
21632167
@overload
21642168
def __radd__(self: NDArray[_ScalarT], x: _nt.CastsWith[_ScalarT, _ScalarOutT], /) -> NDArray[_ScalarOutT]: ... # type: ignore[overload-overlap]
21652169
@overload
2170+
def __radd__(self: _nt.CastsWithBuiltin[_T, _ScalarOutT], x: _nt.SequenceND[_T], /) -> NDArray[_ScalarOutT]: ...
2171+
@overload
2172+
def __radd__(self: _nt.CastsWithInt[_ScalarOutT], x: _nt.SequenceND[_nt.JustInt], /) -> NDArray[_ScalarOutT]: ...
2173+
@overload
21662174
def __radd__(
2167-
self: NDArray[generic[_AnyItemT]],
2168-
x: _nt.SequenceND[_AnyItemT],
2169-
/,
2170-
) -> ndarray[tuple[int, ...], _DTypeT_co]: ...
2175+
self: _nt.CastsWithFloat[_ScalarOutT], x: _nt.SequenceND[_nt.JustFloat], /
2176+
) -> NDArray[_ScalarOutT]: ...
2177+
@overload
2178+
def __radd__(
2179+
self: _nt.CastsWithComplex[_ScalarOutT], x: _nt.SequenceND[_nt.JustComplex], /
2180+
) -> NDArray[_ScalarOutT]: ...
21712181
@overload
21722182
def __radd__(self: NDArray[datetime64], x: _nt.CoTimeDelta_nd, /) -> NDArray[datetime64]: ...
21732183
@overload
21742184
def __radd__(self: NDArray[_nt.co_timedelta], x: _nt.ToDateTime_nd, /) -> NDArray[datetime64]: ...
21752185
@overload
2176-
def __radd__(self: NDArray[object_], x: object, /) -> NDArray[object_]: ...
2186+
def __radd__(self: NDArray[object_], x: object, /) -> NDArray[object_]: ... # type: ignore[overload-cannot-match]
21772187
@overload
21782188
def __radd__(self: NDArray[str_], x: _nt.ToString_nd[_T], /) -> _nt.StringArrayND[_T]: ...
21792189
@overload
@@ -2183,11 +2193,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
21832193
@overload # type: ignore[misc]
21842194
def __iadd__(self: NDArray[_ScalarT], x: _nt.Casts[_ScalarT], /) -> ndarray[_ShapeT_co, _DTypeT_co]: ...
21852195
@overload
2186-
def __iadd__(
2187-
self: NDArray[generic[_AnyItemT]],
2188-
x: _nt.SequenceND[_AnyItemT],
2189-
/,
2190-
) -> ndarray[_ShapeT_co, _DTypeT_co]: ...
2196+
def __iadd__(self: _nt.CastsWithBuiltin[_T, Any], x: _nt.SequenceND[_T], /) -> ndarray[_ShapeT_co, _DTypeT_co]: ...
21912197
@overload
21922198
def __iadd__(self: NDArray[datetime64], x: _nt.ToTimeDelta_nd, /) -> ndarray[_ShapeT_co, _DTypeT_co]: ...
21932199
@overload
@@ -4156,7 +4162,7 @@ class bool_(generic[_BoolItemT_co], Generic[_BoolItemT_co]):
41564162
@type_check_only
41574163
def __nep50__(self, below: _nt.co_number | timedelta64, above: Never, /) -> bool_: ...
41584164
@type_check_only
4159-
def __nep50_bool__(self, /) -> bool_: ...
4165+
def __nep50_builtin__(self, /) -> tuple[py_bool, bool_]: ...
41604166
@type_check_only
41614167
def __nep50_int__(self, /) -> intp: ...
41624168
@type_check_only
@@ -4604,9 +4610,8 @@ class number(
46044610
generic[_NumberItemT_co],
46054611
Generic[_BitT, _NumberItemT_co],
46064612
):
4607-
@final
46084613
@type_check_only
4609-
def __nep50_bool__(self, /) -> Self: ...
4614+
def __nep50_builtin__(self, /) -> tuple[int, Self]: ...
46104615
@final
46114616
@type_check_only
46124617
def __nep50_int__(self, /) -> Self: ...
@@ -5463,6 +5468,8 @@ complex256 = clongdouble
54635468
class object_(_RealMixin, generic[Any]):
54645469
@type_check_only
54655470
def __nep50__(self, below: object_, above: _nt.co_number | character, /) -> object_: ...
5471+
@type_check_only
5472+
def __nep50_builtin__(self, /) -> tuple[_JustBuiltinScalar, object_]: ...
54665473

54675474
#
54685475
@overload
@@ -5500,7 +5507,9 @@ class character(flexible[_CharacterItemT_co], Generic[_CharacterItemT_co]): # t
55005507

55015508
class bytes_(character[bytes], bytes): # type: ignore[misc]
55025509
@type_check_only
5503-
def __nep50__(self, below: bytes_, above: bytes_, /) -> bytes_: ...
5510+
def __nep50__(self, below: bytes_ | object_, above: Never, /) -> bytes_: ...
5511+
@type_check_only
5512+
def __nep50_builtin__(self, /) -> tuple[_nt.JustBytes, bytes_]: ...
55045513

55055514
#
55065515
@overload
@@ -5514,7 +5523,9 @@ class bytes_(character[bytes], bytes): # type: ignore[misc]
55145523

55155524
class str_(character[str], str): # type: ignore[misc]
55165525
@type_check_only
5517-
def __nep50__(self, below: str | str_, from_: str_, /) -> str_: ...
5526+
def __nep50__(self, below: str_ | object_, above: Never, /) -> str_: ...
5527+
@type_check_only
5528+
def __nep50_builtin__(self, /) -> tuple[_nt.JustStr, str_]: ...
55185529

55195530
#
55205531
@overload
@@ -5544,7 +5555,10 @@ class void(flexible[bytes | tuple[Any, ...]]): # type: ignore[misc] # pyright:
55445555
def setfield(self, val: ArrayLike, dtype: DTypeLike, offset: int = ...) -> None: ...
55455556

55465557
class datetime64(
5547-
_RealMixin, _CmpOpMixin[datetime64, _ArrayLikeDT64_co], generic[_DT64ItemT_co], Generic[_DT64ItemT_co]
5558+
_RealMixin,
5559+
_CmpOpMixin[datetime64, _ArrayLikeDT64_co],
5560+
generic[_DT64ItemT_co],
5561+
Generic[_DT64ItemT_co],
55485562
):
55495563
@property
55505564
@override
@@ -5681,6 +5695,8 @@ class timedelta64(
56815695
):
56825696
@type_check_only
56835697
def __nep50__(self, below: timedelta64, above: _nt.co_integer, /) -> timedelta64: ...
5698+
@type_check_only
5699+
def __nep50_builtin__(self, /) -> tuple[int, timedelta64]: ...
56845700

56855701
#
56865702
@property

0 commit comments

Comments
 (0)