Skip to content

Commit 29d21f4

Browse files
committed
🎨 DRY cleanup of np.lib._user_array_impl
1 parent ffc0338 commit 29d21f4

File tree

1 file changed

+44
-52
lines changed

1 file changed

+44
-52
lines changed

‎src/numpy-stubs/lib/_user_array_impl.pyi

Lines changed: 44 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from _typeshed import Incomplete
22
from types import EllipsisType
3-
from typing import Any, Generic, Self, SupportsIndex, TypeAlias, TypeVar, overload
4-
from typing_extensions import deprecated, override
3+
from typing import Any, Generic, Self, SupportsIndex, TypeAlias, overload
4+
from typing_extensions import TypeAliasType, TypeVar, deprecated, override
55

66
import _numtype as _nt
77
import numpy as np
@@ -10,23 +10,20 @@ from numpy._typing import _ArrayLike, _ArrayLikeBool_co, _ArrayLikeInt_co, _DTyp
1010

1111
###
1212

13-
_ScalarT = TypeVar("_ScalarT", bound=np.generic)
14-
_ShapeT = TypeVar("_ShapeT", bound=_nt.Shape)
13+
_ScalarT = TypeVar("_ScalarT", bound=np.generic, default=Any)
14+
_ShapeT = TypeVar("_ShapeT", bound=_nt.Shape, default=Any)
1515
_ShapeT_co = TypeVar("_ShapeT_co", bound=_nt.Shape, default=Any, covariant=True)
16-
_DTypeT = TypeVar("_DTypeT", bound=np.dtype)
16+
_DTypeT = TypeVar("_DTypeT", bound=np.dtype, default=np.dtype)
1717
_DTypeT_co = TypeVar("_DTypeT_co", bound=np.dtype, default=np.dtype, covariant=True)
1818

19-
_BoolArrayT = TypeVar("_BoolArrayT", bound=container[Any, np.dtype[np.bool]])
20-
_IntegralArrayT = TypeVar("_IntegralArrayT", bound=container[Any, np.dtype[np.bool | np.integer | np.object_]])
21-
_RealContainerT = TypeVar(
22-
"_RealContainerT",
23-
bound=container[Any, np.dtype[np.bool | np.integer | np.floating | np.timedelta64 | np.object_]],
24-
)
25-
_NumericContainerT = TypeVar(
26-
"_NumericContainerT", bound=container[Any, np.dtype[np.number | np.timedelta64 | np.object_]]
27-
)
19+
_Container = TypeAliasType("_Container", container[_ShapeT, np.dtype[_ScalarT]], type_params=(_ScalarT, _ShapeT))
2820

29-
_ArrayInt_co: TypeAlias = npt.NDArray[np.integer | np.bool]
21+
_BoolArrayT = TypeVar("_BoolArrayT", bound=_Container[np.bool_])
22+
_IntegralArrayT = TypeVar("_IntegralArrayT", bound=_Container[_nt.co_integer | np.object_])
23+
_RealContainerT = TypeVar("_RealContainerT", bound=_Container[_nt.co_timedelta | np.floating | np.object_])
24+
_NumericContainerT = TypeVar("_NumericContainerT", bound=_Container[np.number | np.timedelta64 | np.object_])
25+
26+
_ArrayInt_co: TypeAlias = _nt.Array[np.integer | np.bool]
3027

3128
_ToIndexSlice: TypeAlias = slice | EllipsisType | _ArrayInt_co | None
3229
_ToIndexSlices: TypeAlias = _ToIndexSlice | tuple[_ToIndexSlice, ...]
@@ -48,22 +45,28 @@ class container(Generic[_ShapeT_co, _DTypeT_co]):
4845
) -> None: ...
4946
@overload
5047
def __init__(
51-
self: container[Any, np.dtype[_ScalarT]],
48+
self: _Container[_ScalarT],
5249
/,
5350
data: _ArrayLike[_ScalarT],
5451
dtype: None = None,
5552
copy: bool = True,
5653
) -> None: ...
5754
@overload
5855
def __init__(
59-
self: container[Any, np.dtype[_ScalarT]],
56+
self: _Container[_ScalarT],
6057
/,
6158
data: npt.ArrayLike,
6259
dtype: _DTypeLike[_ScalarT],
6360
copy: bool = True,
6461
) -> None: ...
6562
@overload
66-
def __init__(self, /, data: npt.ArrayLike, dtype: npt.DTypeLike | None = None, copy: bool = True) -> None: ...
63+
def __init__(
64+
self,
65+
/,
66+
data: npt.ArrayLike,
67+
dtype: npt.DTypeLike | None = None,
68+
copy: bool = True,
69+
) -> None: ...
6770

6871
#
6972
def __complex__(self, /) -> complex: ...
@@ -74,15 +77,15 @@ class container(Generic[_ShapeT_co, _DTypeT_co]):
7477

7578
#
7679
@override
77-
def __eq__(self, other: object, /) -> container[_ShapeT_co, np.dtype[np.bool]]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
80+
def __eq__(self, other: object, /) -> _Container[np.bool, _ShapeT_co]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
7881
@override
79-
def __ne__(self, other: object, /) -> container[_ShapeT_co, np.dtype[np.bool]]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
82+
def __ne__(self, other: object, /) -> _Container[np.bool, _ShapeT_co]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
8083

8184
#
82-
def __lt__(self, other: npt.ArrayLike, /) -> container[_ShapeT_co, np.dtype[np.bool]]: ...
83-
def __le__(self, other: npt.ArrayLike, /) -> container[_ShapeT_co, np.dtype[np.bool]]: ...
84-
def __gt__(self, other: npt.ArrayLike, /) -> container[_ShapeT_co, np.dtype[np.bool]]: ...
85-
def __ge__(self, other: npt.ArrayLike, /) -> container[_ShapeT_co, np.dtype[np.bool]]: ...
85+
def __lt__(self, other: npt.ArrayLike, /) -> _Container[np.bool, _ShapeT_co]: ...
86+
def __le__(self, other: npt.ArrayLike, /) -> _Container[np.bool, _ShapeT_co]: ...
87+
def __gt__(self, other: npt.ArrayLike, /) -> _Container[np.bool, _ShapeT_co]: ...
88+
def __ge__(self, other: npt.ArrayLike, /) -> _Container[np.bool, _ShapeT_co]: ...
8689

8790
#
8891
def __len__(self, /) -> int: ...
@@ -95,27 +98,23 @@ class container(Generic[_ShapeT_co, _DTypeT_co]):
9598
@overload
9699
def __getitem__(self, key: _ToIndices, /) -> Any: ...
97100
@overload
98-
def __getitem__(
99-
self: container[Any, np.dtype[np.void]], key: list[str], /
100-
) -> container[_ShapeT_co, np.dtype[np.void]]: ...
101+
def __getitem__(self: _Container[np.void], key: list[str], /) -> _Container[np.void, _ShapeT_co]: ...
101102
@overload
102-
def __getitem__(self: container[Any, np.dtype[np.void]], key: str, /) -> container[_ShapeT_co, np.dtype]: ...
103+
def __getitem__(self: _Container[np.void], key: str, /) -> _Container[Any, _ShapeT_co]: ...
103104

104105
# keep in sync with np.ndarray
105106
@overload
106107
def __setitem__(self, index: _ToIndices, value: object, /) -> None: ...
107108
@overload
108-
def __setitem__(self: container[Any, np.dtype[np.void]], key: str | list[str], value: object, /) -> None: ...
109+
def __setitem__(self: _Container[np.void], key: str | list[str], value: object, /) -> None: ...
109110

110111
# keep in sync with np.ndarray
111112
@overload
112-
def __abs__(self: container[_ShapeT, np.dtype[np.complex64]], /) -> container[_ShapeT, np.dtype[np.float32]]: ...
113+
def __abs__(self: _Container[np.complex64, _ShapeT], /) -> _Container[np.float32, _ShapeT]: ...
113114
@overload
114-
def __abs__(self: container[_ShapeT, np.dtype[np.complex128]], /) -> container[_ShapeT, np.dtype[np.float64]]: ...
115+
def __abs__(self: _Container[np.complex128, _ShapeT], /) -> _Container[np.float64, _ShapeT]: ...
115116
@overload
116-
def __abs__(
117-
self: container[_ShapeT, np.dtype[np.clongdouble]], /
118-
) -> container[_ShapeT, np.dtype[np.longdouble]]: ...
117+
def __abs__(self: _Container[np.clongdouble, _ShapeT], /) -> _Container[np.longdouble, _ShapeT]: ...
119118
@overload
120119
def __abs__(self: _RealContainerT, /) -> _RealContainerT: ...
121120

@@ -125,7 +124,6 @@ class container(Generic[_ShapeT_co, _DTypeT_co]):
125124
def __invert__(self: _IntegralArrayT, /) -> _IntegralArrayT: ... # noqa: PYI019
126125

127126
# TODO(jorenham): complete these binary ops
128-
129127
#
130128
def __add__(self, other: npt.ArrayLike, /) -> Incomplete: ...
131129
def __radd__(self, other: npt.ArrayLike, /) -> Incomplete: ...
@@ -161,22 +159,20 @@ class container(Generic[_ShapeT_co, _DTypeT_co]):
161159
def __ipow__(self, other: npt.ArrayLike, /) -> Self: ...
162160

163161
#
164-
def __lshift__(self, other: _ArrayLikeInt_co, /) -> container[Any, np.dtype[np.integer]]: ...
165-
def __rlshift__(self, other: _ArrayLikeInt_co, /) -> container[Any, np.dtype[np.integer]]: ...
162+
def __lshift__(self, other: _ArrayLikeInt_co, /) -> _Container[np.integer]: ...
163+
def __rlshift__(self, other: _ArrayLikeInt_co, /) -> _Container[np.integer]: ...
166164
def __ilshift__(self, other: _ArrayLikeInt_co, /) -> Self: ...
167165

168166
#
169-
def __rshift__(self, other: _ArrayLikeInt_co, /) -> container[Any, np.dtype[np.integer]]: ...
170-
def __rrshift__(self, other: _ArrayLikeInt_co, /) -> container[Any, np.dtype[np.integer]]: ...
167+
def __rshift__(self, other: _ArrayLikeInt_co, /) -> _Container[np.integer]: ...
168+
def __rrshift__(self, other: _ArrayLikeInt_co, /) -> _Container[np.integer]: ...
171169
def __irshift__(self, other: _ArrayLikeInt_co, /) -> Self: ...
172170

173171
#
174172
@overload
175-
def __and__(
176-
self: container[Any, np.dtype[np.bool]], other: _ArrayLikeBool_co, /
177-
) -> container[Any, np.dtype[np.bool]]: ...
173+
def __and__(self: _Container[np.bool_], other: _ArrayLikeBool_co, /) -> _Container[np.bool_]: ...
178174
@overload
179-
def __and__(self, other: _ArrayLikeInt_co, /) -> container[Any, np.dtype[np.bool | np.integer]]: ...
175+
def __and__(self, other: _ArrayLikeInt_co, /) -> _Container[_nt.co_integer]: ...
180176
__rand__ = __and__
181177
@overload
182178
def __iand__(self: _BoolArrayT, other: _ArrayLikeBool_co, /) -> _BoolArrayT: ...
@@ -185,11 +181,9 @@ class container(Generic[_ShapeT_co, _DTypeT_co]):
185181

186182
#
187183
@overload
188-
def __xor__(
189-
self: container[Any, np.dtype[np.bool]], other: _ArrayLikeBool_co, /
190-
) -> container[Any, np.dtype[np.bool]]: ...
184+
def __xor__(self: _Container[np.bool_], other: _ArrayLikeBool_co, /) -> _Container[np.bool_]: ...
191185
@overload
192-
def __xor__(self, other: _ArrayLikeInt_co, /) -> container[Any, np.dtype[np.bool | np.integer]]: ...
186+
def __xor__(self, other: _ArrayLikeInt_co, /) -> _Container[_nt.co_integer]: ...
193187
__rxor__ = __xor__
194188
@overload
195189
def __ixor__(self: _BoolArrayT, other: _ArrayLikeBool_co, /) -> _BoolArrayT: ...
@@ -198,11 +192,9 @@ class container(Generic[_ShapeT_co, _DTypeT_co]):
198192

199193
#
200194
@overload
201-
def __or__(
202-
self: container[Any, np.dtype[np.bool]], other: _ArrayLikeBool_co, /
203-
) -> container[Any, np.dtype[np.bool]]: ...
195+
def __or__(self: _Container[np.bool_], other: _ArrayLikeBool_co, /) -> _Container[np.bool_]: ...
204196
@overload
205-
def __or__(self, other: _ArrayLikeInt_co, /) -> container[Any, np.dtype[np.bool | np.integer]]: ...
197+
def __or__(self, other: _ArrayLikeInt_co, /) -> _Container[_nt.co_integer]: ...
206198
__ror__ = __or__
207199
@overload
208200
def __ior__(self: _BoolArrayT, other: _ArrayLikeBool_co, /) -> _BoolArrayT: ...
@@ -229,4 +221,4 @@ class container(Generic[_ShapeT_co, _DTypeT_co]):
229221
def tostring(self, /) -> bytes: ...
230222
def tobytes(self, /) -> bytes: ...
231223
def byteswap(self, /) -> Self: ...
232-
def astype(self, /, typecode: _DTypeLike[_ScalarT]) -> container[_ShapeT_co, np.dtype[_ScalarT]]: ...
224+
def astype(self, /, typecode: _DTypeLike[_ScalarT]) -> _Container[_ScalarT, _ShapeT_co]: ...

0 commit comments

Comments
 (0)