Skip to content

Commit 3e0c874

Browse files
committed
♻️ replace the ComparisonOp* protocol hack with type-check-only mixins
1 parent 88e38a2 commit 3e0c874

File tree

2 files changed

+79
-31
lines changed

2 files changed

+79
-31
lines changed

src/numpy-stubs/__init__.pyi

Lines changed: 77 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,6 @@ from ._typing import (
263263
_TD64Like_co,
264264
_VoidDTypeLike,
265265
)
266-
from ._typing._callable import _ComparisonOpGE, _ComparisonOpGT, _ComparisonOpLE, _ComparisonOpLT
267266
from ._typing._char_codes import (
268267
_BoolCodes,
269268
_BytesCodes,
@@ -924,6 +923,22 @@ class _HasDateAttributes(Protocol):
924923
@property
925924
def year(self) -> int: ...
926925

926+
@type_check_only
927+
class _CanLT(Protocol):
928+
def __lt__(self, x: Any, /) -> Any: ...
929+
930+
@type_check_only
931+
class _CanLE(Protocol):
932+
def __le__(self, x: Any, /) -> Any: ...
933+
934+
@type_check_only
935+
class _CanGT(Protocol):
936+
def __gt__(self, x: Any, /) -> Any: ...
937+
938+
@type_check_only
939+
class _CanGE(Protocol):
940+
def __ge__(self, x: Any, /) -> Any: ...
941+
927942
###
928943
# Mixins (for internal use only)
929944

@@ -949,6 +964,42 @@ class _IntegralMixin(_RealMixin):
949964
def denominator(self) -> L[1]: ...
950965
def is_integer(self, /) -> L[True]: ...
951966

967+
_ScalarLikeT_contra = TypeVar("_ScalarLikeT_contra", contravariant=True)
968+
_ArrayLikeT_contra = TypeVar("_ArrayLikeT_contra", contravariant=True)
969+
970+
@type_check_only
971+
class _NumericComparisonMixin(Generic[_ScalarLikeT_contra, _ArrayLikeT_contra]):
972+
@overload
973+
def __lt__(self, x: _ScalarLikeT_contra, /) -> bool_: ...
974+
@overload
975+
def __lt__(self, x: _ArrayLikeT_contra | _NestedSequence[_CanGT], /) -> NDArray[bool_]: ...
976+
@overload
977+
def __lt__(self, x: _CanGT, /) -> bool_: ...
978+
979+
#
980+
@overload
981+
def __le__(self, x: _ScalarLikeT_contra, /) -> bool_: ...
982+
@overload
983+
def __le__(self, x: _ArrayLikeT_contra | _NestedSequence[_CanGE], /) -> NDArray[bool_]: ...
984+
@overload
985+
def __le__(self, x: _CanGE, /) -> bool_: ...
986+
987+
#
988+
@overload
989+
def __gt__(self, x: _ScalarLikeT_contra, /) -> bool_: ...
990+
@overload
991+
def __gt__(self, x: _ArrayLikeT_contra | _NestedSequence[_CanLT], /) -> NDArray[bool_]: ...
992+
@overload
993+
def __gt__(self, x: _CanLT, /) -> bool_: ...
994+
995+
#
996+
@overload
997+
def __ge__(self, x: _ScalarLikeT_contra, /) -> bool_: ...
998+
@overload
999+
def __ge__(self, x: _ArrayLikeT_contra | _NestedSequence[_CanLE], /) -> NDArray[bool_]: ...
1000+
@overload
1001+
def __ge__(self, x: _CanLE, /) -> bool_: ...
1002+
9521003
###
9531004
# NumType only: Does not exist at runtime!
9541005

@@ -4133,7 +4184,11 @@ class generic(_ArrayOrScalarCommon, Generic[_ItemT_co]):
41334184
@property
41344185
def dtype(self) -> dtype[Self]: ...
41354186

4136-
class number(generic[_NumberItemT_co], Generic[_NBitT, _NumberItemT_co]):
4187+
class number(
4188+
_NumericComparisonMixin[_NumberLike_co, _ArrayLikeNumber_co],
4189+
generic[_NumberItemT_co],
4190+
Generic[_NBitT, _NumberItemT_co],
4191+
):
41374192
@abc.abstractmethod
41384193
def __init__(self, value: _NumberItemT_co, /) -> None: ...
41394194

@@ -4159,12 +4214,11 @@ class number(generic[_NumberItemT_co], Generic[_NBitT, _NumberItemT_co]):
41594214
def __floordiv__(self: number[Any, float], x: Any, /) -> floating | integer: ...
41604215
def __rfloordiv__(self: number[Any, float], x: Any, /) -> floating | integer: ...
41614216

4162-
__lt__: _ComparisonOpLT[_NumberLike_co, _ArrayLikeNumber_co]
4163-
__le__: _ComparisonOpLE[_NumberLike_co, _ArrayLikeNumber_co]
4164-
__gt__: _ComparisonOpGT[_NumberLike_co, _ArrayLikeNumber_co]
4165-
__ge__: _ComparisonOpGE[_NumberLike_co, _ArrayLikeNumber_co]
4166-
4167-
class bool(generic[_BoolItemT_co], Generic[_BoolItemT_co]):
4217+
class bool(
4218+
_NumericComparisonMixin[_NumberLike_co, _ArrayLikeNumber_co],
4219+
generic[_BoolItemT_co],
4220+
Generic[_BoolItemT_co],
4221+
):
41684222
@property
41694223
def itemsize(self) -> L[1]: ...
41704224
@property
@@ -4185,13 +4239,13 @@ class bool(generic[_BoolItemT_co], Generic[_BoolItemT_co]):
41854239
#
41864240
def __hash__(self, /) -> int: ...
41874241

4188-
#
4189-
def __bool__(self, /) -> _BoolItemT_co: ...
4190-
41914242
#
41924243
@deprecated("In future, it will be an error for 'np.bool' scalars to be interpreted as an index")
41934244
def __index__(self, /) -> L[0, 1]: ...
41944245

4246+
#
4247+
def __bool__(self, /) -> _BoolItemT_co: ...
4248+
41954249
#
41964250
@overload
41974251
def __int__(self: bool_[L[False]], /) -> L[0]: ...
@@ -4495,12 +4549,6 @@ class bool(generic[_BoolItemT_co], Generic[_BoolItemT_co]):
44954549
@overload
44964550
def __ror__(self, x: int, /) -> intp | bool_: ...
44974551

4498-
#
4499-
__lt__: _ComparisonOpLT[_NumberLike_co, _ArrayLikeNumber_co]
4500-
__le__: _ComparisonOpLE[_NumberLike_co, _ArrayLikeNumber_co]
4501-
__gt__: _ComparisonOpGT[_NumberLike_co, _ArrayLikeNumber_co]
4502-
__ge__: _ComparisonOpGE[_NumberLike_co, _ArrayLikeNumber_co]
4503-
45044552
# NOTE: The `object_` constructor returns the passed object, so instances with type
45054553
# `object_` cannot exists (at runtime).
45064554
# NOTE: Because mypy has some long-standing bugs related to `__new__`, `object_` can't
@@ -6402,7 +6450,12 @@ class complex128(complexfloating[_64Bit], complex):
64026450
@override
64036451
def conjugate(self) -> Self: ...
64046452

6405-
class timedelta64(_IntegralMixin, generic[_TD64ItemT_co], Generic[_TD64ItemT_co]):
6453+
class timedelta64(
6454+
_NumericComparisonMixin[_TD64Like_co, _ArrayLikeDT64_co],
6455+
_IntegralMixin,
6456+
generic[_TD64ItemT_co],
6457+
Generic[_TD64ItemT_co],
6458+
):
64066459
@property
64076460
def itemsize(self) -> L[8]: ...
64086461
@property
@@ -6582,12 +6635,12 @@ class timedelta64(_IntegralMixin, generic[_TD64ItemT_co], Generic[_TD64ItemT_co]
65826635
@overload
65836636
def __rdivmod__(self: timedelta64[dt.timedelta], x: dt.timedelta, /) -> tuple[int, dt.timedelta]: ...
65846637

6585-
__lt__: _ComparisonOpLT[_TD64Like_co, _ArrayLikeTD64_co]
6586-
__le__: _ComparisonOpLE[_TD64Like_co, _ArrayLikeTD64_co]
6587-
__gt__: _ComparisonOpGT[_TD64Like_co, _ArrayLikeTD64_co]
6588-
__ge__: _ComparisonOpGE[_TD64Like_co, _ArrayLikeTD64_co]
6589-
6590-
class datetime64(_RealMixin, generic[_DT64ItemT_co], Generic[_DT64ItemT_co]):
6638+
class datetime64(
6639+
_RealMixin,
6640+
_NumericComparisonMixin[datetime64, _ArrayLikeDT64_co],
6641+
generic[_DT64ItemT_co],
6642+
Generic[_DT64ItemT_co],
6643+
):
65916644
@property
65926645
def itemsize(self) -> L[8]: ...
65936646
@property
@@ -6673,11 +6726,6 @@ class datetime64(_RealMixin, generic[_DT64ItemT_co], Generic[_DT64ItemT_co]):
66736726
#
66746727
def __rsub__(self: datetime64[_AnyDate], x: _AnyDate, /) -> dt.timedelta: ...
66756728

6676-
__lt__: _ComparisonOpLT[datetime64, _ArrayLikeDT64_co]
6677-
__le__: _ComparisonOpLE[datetime64, _ArrayLikeDT64_co]
6678-
__gt__: _ComparisonOpGT[datetime64, _ArrayLikeDT64_co]
6679-
__ge__: _ComparisonOpGE[datetime64, _ArrayLikeDT64_co]
6680-
66816729
@final
66826730
class flexible(_RealMixin, generic[_FlexItemT_co], Generic[_FlexItemT_co]): # type: ignore[misc]
66836731
@abc.abstractmethod

src/numpy-stubs/_typing/_callable.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,6 @@ class _ComparisonOpGE(Protocol[_T1_contra, _T2_contra]):
7070
@overload
7171
def __call__(self, x: _T2_contra, /) -> NDArray[np.bool]: ...
7272
@overload
73-
def __call__(self, x: _NestedSequence[_CanGT], /) -> NDArray[np.bool]: ...
73+
def __call__(self, x: _NestedSequence[_CanLE], /) -> NDArray[np.bool]: ...
7474
@overload
75-
def __call__(self, x: _CanGT, /) -> np.bool: ...
75+
def __call__(self, x: _CanLE, /) -> np.bool: ...

0 commit comments

Comments
 (0)