Skip to content

Commit 67cafa9

Browse files
authored
Merge pull request #542 from numpy/cleaner-rank-types
2 parents fb3f2ce + ad4e0cb commit 67cafa9

File tree

10 files changed

+586
-536
lines changed

10 files changed

+586
-536
lines changed

src/_numtype/@test/test_rank.pyi

Lines changed: 422 additions & 446 deletions
Large diffs are not rendered by default.

src/_numtype/__init__.pyi

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,22 @@ from ._nep50 import (
6161
CastsWithInt as CastsWithInt,
6262
CastsWithScalar as CastsWithScalar,
6363
)
64+
from ._rank import (
65+
Broadcasts as Broadcasts,
66+
BroadcastsTo as BroadcastsTo,
67+
Rank as Rank,
68+
Rank0 as Rank0,
69+
Rank0N as Rank0N,
70+
Rank1 as Rank1,
71+
Rank1N as Rank1N,
72+
Rank2 as Rank2,
73+
Rank2N as Rank2N,
74+
Rank3 as Rank3,
75+
Rank3N as Rank3N,
76+
Rank4 as Rank4,
77+
Rank4N as Rank4N,
78+
_BroadcastableShape as _BroadcastableShape,
79+
)
6480
from ._scalar import (
6581
inexact32 as inexact32,
6682
inexact64 as inexact64,
@@ -104,14 +120,15 @@ from ._scalar_co import (
104120
from ._shape import (
105121
Shape as Shape,
106122
Shape0 as Shape0,
123+
Shape0N as Shape0N,
107124
Shape1 as Shape1,
108-
Shape1_ as Shape1_,
125+
Shape1N as Shape1N,
109126
Shape2 as Shape2,
110-
Shape2_ as Shape2_,
127+
Shape2N as Shape2N,
111128
Shape3 as Shape3,
112-
Shape3_ as Shape3_,
129+
Shape3N as Shape3N,
113130
Shape4 as Shape4,
114-
Shape4_ as Shape4_,
131+
Shape4N as Shape4N,
115132
)
116133

117134
###
@@ -261,10 +278,10 @@ _ToArray2_3ds: TypeAlias = CanArray3D[_ScalarT] | Sequence[_ToArray2_2ds[_Scalar
261278
# requires a lower bound on dimensionality, e.g. `_2nd` denotes `ndin` within `[2, n]`
262279
_ToArray_1nd: TypeAlias = CanLenArrayND[_ScalarT] | Sequence1ND[CanArrayND[_ScalarT]]
263280
_ToArray2_1nd: TypeAlias = CanLenArrayND[_ScalarT] | Sequence1ND[_ToT | CanArrayND[_ScalarT]]
264-
_ToArray_2nd: TypeAlias = CanLenArray[_ScalarT, Shape2_] | Sequence[_ToArray_1nd[_ScalarT]]
265-
_ToArray2_2nd: TypeAlias = CanLenArray[_ScalarT, Shape2_] | Sequence[_ToArray2_1nd[_ScalarT, _ToT]]
266-
_ToArray_3nd: TypeAlias = CanLenArray[_ScalarT, Shape3_] | Sequence[_ToArray_2nd[_ScalarT]]
267-
_ToArray2_3nd: TypeAlias = CanLenArray[_ScalarT, Shape3_] | Sequence[_ToArray2_2nd[_ScalarT, _ToT]]
281+
_ToArray_2nd: TypeAlias = CanLenArray[_ScalarT, Shape2N] | Sequence[_ToArray_1nd[_ScalarT]]
282+
_ToArray2_2nd: TypeAlias = CanLenArray[_ScalarT, Shape2N] | Sequence[_ToArray2_1nd[_ScalarT, _ToT]]
283+
_ToArray_3nd: TypeAlias = CanLenArray[_ScalarT, Shape3N] | Sequence[_ToArray_2nd[_ScalarT]]
284+
_ToArray2_3nd: TypeAlias = CanLenArray[_ScalarT, Shape3N] | Sequence[_ToArray2_2nd[_ScalarT, _ToT]]
268285

269286
###
270287
# Non-overlapping scalar- and array-like aliases for all scalar types.
@@ -696,9 +713,9 @@ ToString_3ds = TypeAliasType(
696713
_CanStringArray[Shape3, _NaT0] | Sequence[ToString_2ds[_NaT0]],
697714
type_params=(_NaT0,),
698715
)
699-
ToString_1nd = TypeAliasType("ToString_1nd", _CanStringArray[Shape1_, _NaT0], type_params=(_NaT0,))
700-
ToString_2nd = TypeAliasType("ToString_2nd", _CanStringArray[Shape2_, _NaT0], type_params=(_NaT0,))
701-
ToString_3nd = TypeAliasType("ToString_3nd", _CanStringArray[Shape3_, _NaT0], type_params=(_NaT0,))
716+
ToString_1nd = TypeAliasType("ToString_1nd", _CanStringArray[Shape1N, _NaT0], type_params=(_NaT0,))
717+
ToString_2nd = TypeAliasType("ToString_2nd", _CanStringArray[Shape2N, _NaT0], type_params=(_NaT0,))
718+
ToString_3nd = TypeAliasType("ToString_3nd", _CanStringArray[Shape3N, _NaT0], type_params=(_NaT0,))
702719

703720
# any scalar
704721
ToGeneric_nd = TypeAliasType("ToGeneric_nd", _ToArray2_nd[np.generic, _PyScalar])

src/_numtype/_rank.pyi

Lines changed: 104 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,38 @@
1-
from typing import Any, Protocol, Self, TypeAlias, final, type_check_only
1+
from typing import Any, Generic, Protocol, Self, TypeAlias, final, type_check_only
22
from typing_extensions import TypeAliasType, TypeVar
33

44
from ._shape import (
55
Shape,
66
Shape as Shape0ToN,
77
Shape0,
88
Shape1,
9-
Shape1_ as Shape1ToN,
9+
Shape1N as Shape1ToN,
1010
Shape2,
11-
Shape2_ as Shape2ToN,
11+
Shape2N as Shape2ToN,
1212
Shape3,
13-
Shape3_ as Shape3ToN,
13+
Shape3N as Shape3ToN,
1414
Shape4,
15-
Shape4_ as Shape4ToN,
15+
Shape4N as Shape4ToN,
1616
)
1717

1818
__all__ = [
19-
"Broadcastable",
19+
"Broadcasts",
20+
"BroadcastsTo",
21+
"Rank",
2022
"Rank0",
21-
"Rank0ToN",
23+
"Rank0N",
2224
"Rank1",
23-
"Rank1ToN",
25+
"Rank1N",
2426
"Rank2",
25-
"Rank2ToN",
27+
"Rank2N",
2628
"Rank3",
27-
"Rank3ToN",
29+
"Rank3N",
2830
"Rank4",
29-
"Rank4ToN",
31+
"Rank4N",
3032
]
3133

3234
###
3335

34-
_ToT = TypeVar("_ToT", bound=Shape)
35-
_ToT_contra = TypeVar("_ToT_contra", bound=Shape, contravariant=True)
36-
_FromT = TypeVar("_FromT", bound=Shape)
37-
_FromT_contra = TypeVar("_FromT_contra", bound=Shape, default=Any, contravariant=True)
38-
_RankT = TypeVar("_RankT", bound=_HasShape[Any], default=Any)
39-
_RankT_co = TypeVar("_RankT_co", default=Any, covariant=True)
40-
_ShapeT_contra = TypeVar("_ShapeT_contra", contravariant=True)
41-
_ShapeT_co = TypeVar("_ShapeT_co", covariant=True, default=_ShapeT_contra)
42-
43-
###
44-
4536
_Shape0To0: TypeAlias = Shape0
4637
_Shape0To1: TypeAlias = _Shape0To0 | Shape1
4738
_Shape0To2: TypeAlias = _Shape0To1 | Shape2
@@ -50,74 +41,134 @@ _Shape0To4: TypeAlias = _Shape0To3 | Shape4
5041

5142
###
5243

44+
_ToT = TypeVar("_ToT", bound=Shape)
45+
_FromT = TypeVar("_FromT", bound=Shape)
46+
_RankT = TypeVar("_RankT", bound=Shape, default=Any)
47+
48+
_BroadcastableShape = TypeAliasType(
49+
"_BroadcastableShape",
50+
_FromT | _CanBroadcastFrom[_FromT, _RankT],
51+
type_params=(_FromT, _RankT),
52+
)
53+
54+
BroadcastsTo = TypeAliasType(
55+
"BroadcastsTo",
56+
_HasRank[_CanBroadcastTo[_ToT, _RankT]],
57+
type_params=(_ToT, _RankT),
58+
)
59+
Broadcasts = TypeAliasType(
60+
"Broadcasts",
61+
_HasRank[_BroadcastableShape[_FromT, _RankT]],
62+
type_params=(_FromT, _RankT),
63+
)
64+
65+
###
66+
67+
_ShapeT_co = TypeVar(
68+
"_ShapeT_co",
69+
bound=Shape | _HasOwnShape | _CanBroadcastFrom | _CanBroadcastTo,
70+
covariant=True,
71+
)
72+
73+
@type_check_only
74+
class _HasShape(Protocol[_ShapeT_co]):
75+
@property
76+
def shape(self, /) -> _ShapeT_co: ...
77+
78+
_ShapeT = TypeVar("_ShapeT", bound=Shape)
79+
80+
@final
5381
@type_check_only
54-
class _CanBroadcast(Protocol[_ToT_contra, _FromT_contra, _RankT_co]):
55-
def __broadcast__(self, to: _ToT_contra, from_: _FromT_contra, /) -> _RankT_co: ...
82+
class _HasRank(Protocol[_ShapeT_co]):
83+
@property
84+
def shape(self: _HasShape[_ShapeT], /) -> _ShapeT: ...
85+
86+
_FromT_contra = TypeVar("_FromT_contra", default=Any, contravariant=True)
87+
_ToT_contra = TypeVar("_ToT_contra", bound=Shape, default=Any, contravariant=True)
88+
_RankT_co = TypeVar("_RankT_co", bound=Shape, default=Any, covariant=True)
89+
90+
@final
91+
@type_check_only
92+
class _CanBroadcastFrom(Protocol[_FromT_contra, _RankT_co]):
93+
def __broadcast_from__(self, from_: _FromT_contra, /) -> _RankT_co: ...
94+
95+
@final
96+
@type_check_only
97+
class _CanBroadcastTo(Protocol[_ToT_contra, _RankT_co]):
98+
def __broadcast_to__(self, to: _ToT_contra, /) -> _RankT_co: ...
5699

57100
# This double shape-type parameter is a sneaky way to annotate a doubly-bound nominal type range,
58-
# e.g. `_HasShape[Shape2ToN, Shape0ToN]` accepts `Shape2ToN`, `Shape1ToN`, and `Shape0ToN`, but
101+
# e.g. `_HasOwnShape[Shape2ToN, Shape0ToN]` accepts `Shape2ToN`, `Shape1ToN`, and `Shape0ToN`, but
59102
# rejects `Shape3ToN` and `Shape1`. Besides brevity, it also works around several mypy bugs that
60103
# are related to "unions vs joins".
104+
105+
_OwnShapeT_contra = TypeVar("_OwnShapeT_contra", bound=Shape, default=Any, contravariant=True)
106+
_OwnShapeT_co = TypeVar("_OwnShapeT_co", bound=Shape, default=_OwnShapeT_contra, covariant=True)
107+
108+
@final
61109
@type_check_only
62-
class _HasShape(Protocol[_ShapeT_contra, _ShapeT_co]):
63-
def __shape__(self, shape: _ShapeT_contra, /) -> _ShapeT_co: ...
110+
class _HasOwnShape(Protocol[_OwnShapeT_contra, _OwnShapeT_co]):
111+
def __own_shape__(self, shape: _OwnShapeT_contra, /) -> _OwnShapeT_co: ...
64112

65113
###
66114
# TODO(jorenham): embed the array-like types, e.g. `Sequence[Sequence[T]]`
67115

68-
@final
116+
_OwnShapeT = TypeVar("_OwnShapeT", bound=tuple[Any, ...], default=Any)
117+
69118
@type_check_only
70-
class Rank0(tuple[()], _HasShape[Shape0, Shape0]):
71-
def __broadcast__(self, to: Shape0ToN, from_: _HasShape[Shape0ToN, Shape0ToN], /) -> Self: ...
119+
class _BaseRank(Generic[_FromT_contra, _ToT_contra, _OwnShapeT]):
120+
def __broadcast_from__(self, from_: _FromT_contra, /) -> Self: ...
121+
def __broadcast_to__(self, to: _ToT_contra, /) -> Self: ...
122+
def __own_shape__(self, shape: _OwnShapeT, /) -> _OwnShapeT: ...
72123

73124
@type_check_only
74-
class Rank1(tuple[int], _HasShape[Shape1, Shape1]):
75-
def __broadcast__(self, to: Shape1ToN, from_: _Shape0To1 | _HasShape[Shape1ToN, Shape0ToN], /) -> Self: ...
125+
class _BaseRankM(
126+
_BaseRank[_FromT_contra | _HasOwnShape[_ToT_contra, Shape], _ToT_contra, _OwnShapeT],
127+
Generic[_FromT_contra, _ToT_contra, _OwnShapeT],
128+
): ...
76129

77130
@final
78131
@type_check_only
79-
class Rank2(tuple[int, int], _HasShape[Shape2, Shape2]):
80-
def __broadcast__(self, to: Shape2ToN, from_: _Shape0To2 | _HasShape[Shape2ToN, Shape0ToN], /) -> Self: ...
132+
class Rank0(_BaseRankM[_Shape0To0, Shape0ToN, Shape0], tuple[()]): ...
81133

82134
@final
83135
@type_check_only
84-
class Rank3(tuple[int, int, int], _HasShape[Shape3, Shape3]):
85-
def __broadcast__(self, to: Shape3ToN, from_: _Shape0To3 | _HasShape[Shape3ToN, Shape0ToN], /) -> Self: ...
136+
class Rank1(_BaseRankM[_Shape0To1, Shape1ToN, Shape1], tuple[int]): ...
86137

87138
@final
88139
@type_check_only
89-
class Rank4(tuple[int, int, int, int], _HasShape[Shape4, Shape4]):
90-
def __broadcast__(self, to: Shape4ToN, from_: _Shape0To4 | _HasShape[Shape4ToN, Shape0ToN], /) -> Self: ...
140+
class Rank2(_BaseRankM[_Shape0To2, Shape2ToN, Shape2], tuple[int, int]): ...
91141

92-
###
93-
# These emulate `AnyOf`, rather than a `Union`.
142+
@final
143+
@type_check_only
144+
class Rank3(_BaseRankM[_Shape0To3, Shape3ToN, Shape3], tuple[int, int, int]): ...
94145

95146
@final
96147
@type_check_only
97-
class Rank0ToN(tuple[int, ...], _HasShape[Shape0ToN, Shape0ToN]):
98-
def __broadcast__(self, to: Shape0ToN, from_: Shape0ToN, /) -> Self: ...
148+
class Rank4(_BaseRankM[_Shape0To4, Shape4ToN, Shape4], tuple[int, int, int, int]): ...
149+
150+
# this emulates `AnyOf`, rather than a `Union`.
151+
@type_check_only
152+
class _BaseRankMToN(_BaseRank[Shape0ToN, _OwnShapeT, _OwnShapeT], Generic[_OwnShapeT]): ...
99153

100154
@final
101155
@type_check_only
102-
class Rank1ToN(tuple[int, *tuple[int, ...]], _HasShape[Shape1ToN, Shape1ToN]):
103-
def __broadcast__(self, to: Shape1ToN, from_: Shape0ToN, /) -> Self: ...
156+
class Rank(_BaseRankMToN[Shape0ToN], tuple[int, ...]): ...
104157

105158
@final
106159
@type_check_only
107-
class Rank2ToN(tuple[int, int, *tuple[int, ...]], _HasShape[Shape2ToN, Shape2ToN]):
108-
def __broadcast__(self, to: Shape2ToN, from_: Shape0ToN, /) -> Self: ...
160+
class Rank1N(_BaseRankMToN[Shape1ToN], tuple[int, *tuple[int, ...]]): ...
109161

110162
@final
111163
@type_check_only
112-
class Rank3ToN(tuple[int, int, int, *tuple[int, ...]], _HasShape[Shape3ToN, Shape3ToN]):
113-
def __broadcast__(self, to: Shape3ToN, from_: Shape0ToN, /) -> Self: ...
164+
class Rank2N(_BaseRankMToN[Shape2ToN], tuple[int, int, *tuple[int, ...]]): ...
114165

115166
@final
116167
@type_check_only
117-
class Rank4ToN(tuple[int, int, int, int, *tuple[int, ...]], _HasShape[Shape4ToN, Shape4ToN]):
118-
def __broadcast__(self, to: Shape4ToN, from_: Shape0ToN, /) -> Self: ...
168+
class Rank3N(_BaseRankMToN[Shape3ToN], tuple[int, int, int, *tuple[int, ...]]): ...
119169

120-
###
170+
@final
171+
@type_check_only
172+
class Rank4N(_BaseRankMToN[Shape4ToN], tuple[int, int, int, int, *tuple[int, ...]]): ...
121173

122-
Broadcastable = TypeAliasType("Broadcastable", _CanBroadcast[_ToT, Any, _RankT], type_params=(_ToT, _RankT))
123-
Broadcaster = TypeAliasType("Broadcaster", _FromT | _CanBroadcast[Any, _FromT, _RankT], type_params=(_FromT, _RankT))
174+
Rank0N: TypeAlias = Rank

src/_numtype/_shape.pyi

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,32 @@
1+
from typing import TypeAlias
12
from typing_extensions import TypeAliasType
23

34
__all__ = [
45
"Shape",
56
"Shape0",
7+
"Shape0N",
68
"Shape1",
7-
"Shape1_",
9+
"Shape1N",
810
"Shape2",
9-
"Shape2_",
11+
"Shape2N",
1012
"Shape3",
11-
"Shape3_",
13+
"Shape3N",
1214
"Shape4",
13-
"Shape4_",
15+
"Shape4N",
16+
"ShapeN",
1417
]
1518

1619
Shape = TypeAliasType("Shape", tuple[int, ...])
20+
21+
ShapeN: TypeAlias = Shape
1722
Shape0 = TypeAliasType("Shape0", tuple[()])
1823
Shape1 = TypeAliasType("Shape1", tuple[int])
1924
Shape2 = TypeAliasType("Shape2", tuple[int, int])
2025
Shape3 = TypeAliasType("Shape3", tuple[int, int, int])
2126
Shape4 = TypeAliasType("Shape4", tuple[int, int, int, int])
2227

23-
Shape1_ = TypeAliasType("Shape1_", tuple[int, *tuple[int, ...]])
24-
Shape2_ = TypeAliasType("Shape2_", tuple[int, int, *tuple[int, ...]])
25-
Shape3_ = TypeAliasType("Shape3_", tuple[int, int, int, *tuple[int, ...]])
26-
Shape4_ = TypeAliasType("Shape4_", tuple[int, int, int, int, *tuple[int, ...]])
28+
Shape0N: TypeAlias = Shape
29+
Shape1N = TypeAliasType("Shape1N", tuple[int, *tuple[int, ...]])
30+
Shape2N = TypeAliasType("Shape2N", tuple[int, int, *tuple[int, ...]])
31+
Shape3N = TypeAliasType("Shape3N", tuple[int, int, int, *tuple[int, ...]])
32+
Shape4N = TypeAliasType("Shape4N", tuple[int, int, int, int, *tuple[int, ...]])

src/numpy-stubs/@test/static/accept/linalg.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import numpy.typing as npt
77
from numpy.linalg._linalg import EigResult, EighResult, QRResult, SVDResult, SlogdetResult
88

99
_ScalarT = TypeVar("_ScalarT", bound=np.generic)
10-
_Array2ND: TypeAlias = _nt.Array[_ScalarT, _nt.Shape2_]
10+
_Array2ND: TypeAlias = _nt.Array[_ScalarT, _nt.Shape2N]
1111

1212
###
1313

src/numpy-stubs/@test/static/accept/scalars.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ assert_type(u8.reshape(1, 1), np.ndarray[_nt.Shape2, np.dtype[np.uint64]])
115115
assert_type(f8.reshape(1, -1), np.ndarray[_nt.Shape2, np.dtype[np.float64]])
116116
assert_type(c16.reshape(1, 1, 1), np.ndarray[_nt.Shape3, np.dtype[np.complex128]])
117117
assert_type(U.reshape(1, 1, 1, 1), np.ndarray[_nt.Shape4, np.dtype[np.str_]])
118-
assert_type(S.reshape(1, 1, 1, 1, 1), np.ndarray[_nt.Shape4_, np.dtype[np.bytes_]])
118+
assert_type(S.reshape(1, 1, 1, 1, 1), np.ndarray[_nt.Shape4N, np.dtype[np.bytes_]])
119119

120120
assert_type(i8.astype(float), Any)
121121
assert_type(i8.astype(np.float64), np.float64)

src/numpy-stubs/__init__.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ _NumericArrayT = TypeVar("_NumericArrayT", bound=NDArray[number | timedelta64 |
629629

630630
_ShapeT = TypeVar("_ShapeT", bound=_nt.Shape)
631631
_ShapeT_co = TypeVar("_ShapeT_co", bound=_nt.Shape, covariant=True)
632-
_ShapeT_1nd = TypeVar("_ShapeT_1nd", bound=_nt.Shape1_)
632+
_ShapeT_1nd = TypeVar("_ShapeT_1nd", bound=_nt.Shape1N)
633633
_1NShapeT = TypeVar("_1NShapeT", bound=tuple[L[1], *tuple[L[1], ...]]) # TODO(jorenham): remove
634634

635635
_ScalarT = TypeVar("_ScalarT", bound=generic)
@@ -2055,7 +2055,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
20552055
@overload # == 1-d
20562056
def __iter__(self: _nt.Array1D[_ScalarT], /) -> Iterator[_ScalarT]: ...
20572057
@overload # >= 2-d
2058-
def __iter__(self: ndarray[_nt.Shape2_, dtype[_ScalarT]], /) -> Iterator[NDArray[_ScalarT]]: ...
2058+
def __iter__(self: ndarray[_nt.Shape2N, dtype[_ScalarT]], /) -> Iterator[NDArray[_ScalarT]]: ...
20592059
@overload # ?-d
20602060
def __iter__(self, /) -> Iterator[Any]: ...
20612061

@@ -3790,7 +3790,7 @@ class generic(_ArrayOrScalarCommon, Generic[_ItemT_co]):
37903790
*sizes5_: CanIndex,
37913791
order: _OrderACF = "C",
37923792
copy: py_bool | None = None,
3793-
) -> _nt.Array[Self, _nt.Shape4_]: ...
3793+
) -> _nt.Array[Self, _nt.Shape4N]: ...
37943794

37953795
#
37963796
@overload

src/numpy-stubs/_core/_multiarray_umath.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ _SafeScalarT = TypeVar("_SafeScalarT", bound=_nt.co_complex | np.timedelta64 | n
7171

7272
_ArrayT = TypeVar("_ArrayT", bound=_nt.Array)
7373
_ArrayT_co = TypeVar("_ArrayT_co", bound=_nt.Array, default=_nt.Array, covariant=True)
74-
_Array1T = TypeVar("_Array1T", bound=_nt.Array[Any, _nt.Shape1_])
75-
_Array2T = TypeVar("_Array2T", bound=_nt.Array[Any, _nt.Shape2_])
74+
_Array1T = TypeVar("_Array1T", bound=_nt.Array[Any, _nt.Shape1N])
75+
_Array2T = TypeVar("_Array2T", bound=_nt.Array[Any, _nt.Shape2N])
7676

7777
###
7878

0 commit comments

Comments
 (0)