1
1
from typing import Any , Generic , Protocol , Self , TypeAlias , final , type_check_only
2
- from typing_extensions import TypeAliasType , TypeVar
2
+ from typing_extensions import TypeAliasType , TypeVar , TypeVarTuple , override
3
3
4
4
from ._shape import Shape , Shape0 , Shape0N , Shape1 , Shape1N , Shape2 , Shape2N , Shape3 , Shape3N , Shape4 , Shape4N
5
5
6
6
__all__ = [
7
+ "HasInnerShape" ,
7
8
"HasRankGE" ,
8
9
"HasRankLE" ,
9
10
"Rank" ,
@@ -21,8 +22,7 @@ __all__ = [
21
22
22
23
###
23
24
24
- _Shape00 : TypeAlias = Shape0
25
- _Shape01 : TypeAlias = _Shape00 | Shape1
25
+ _Shape01 : TypeAlias = Shape0 | Shape1
26
26
_Shape02 : TypeAlias = _Shape01 | Shape2
27
27
_Shape03 : TypeAlias = _Shape02 | Shape3
28
28
_Shape04 : TypeAlias = _Shape03 | Shape4
@@ -33,44 +33,57 @@ _UpperT = TypeVar("_UpperT", bound=Shape)
33
33
_LowerT = TypeVar ("_LowerT" , bound = Shape )
34
34
_RankT = TypeVar ("_RankT" , bound = Shape , default = Any )
35
35
36
+ _RankLE : TypeAlias = _CanBroadcast [Any , _UpperT , _RankT ]
37
+ _RankGE : TypeAlias = _CanBroadcast [_LowerT , Any , _RankT ]
38
+
36
39
HasRankLE = TypeAliasType (
37
40
"HasRankLE" ,
38
- _HasShape [Shape0 | _HasOwnShape [ _UpperT ] | _CanBroadcast [ Any , _UpperT , _RankT ]],
41
+ _HasInnerShape [Shape0 | _RankLE [ _UpperT , _RankT ]],
39
42
type_params = (_UpperT , _RankT ),
40
43
)
41
44
HasRankGE = TypeAliasType (
42
45
"HasRankGE" ,
43
- _HasShape [_LowerT | _CanBroadcast [_LowerT , Any , _RankT ]],
46
+ _HasInnerShape [_LowerT | _RankGE [_LowerT , _RankT ]],
44
47
type_params = (_LowerT , _RankT ),
45
48
)
46
49
47
- ###
50
+ _ShapeT = TypeVar ( "_ShapeT" , bound = Shape )
48
51
49
- _ShapeT_co = TypeVar ("_ShapeT_co" , bound = Shape | _HasOwnShape | _CanBroadcast , covariant = True )
52
+ # for unwrapping potential rank types as shape tuples
53
+ HasInnerShape = TypeAliasType (
54
+ "HasInnerShape" ,
55
+ _HasInnerShape [_HasOwnShape [Any , _ShapeT ]],
56
+ type_params = (_ShapeT ,),
57
+ )
50
58
51
- @type_check_only
52
- class _HasShape (Protocol [_ShapeT_co ]):
53
- @property
54
- def shape (self , / ) -> _ShapeT_co : ...
59
+ ###
60
+
61
+ _ShapeLikeT_co = TypeVar ("_ShapeLikeT_co" , bound = Shape | _HasOwnShape | _CanBroadcast [Any , Any ], covariant = True )
55
62
56
- _FromT_contra = TypeVar ("_FromT_contra" , default = Any , contravariant = True )
57
- _ToT_contra = TypeVar ("_ToT_contra" , bound = Shape , default = Any , contravariant = True )
63
+ _FromT_contra = TypeVar ("_FromT_contra" , contravariant = True )
64
+ _ToT_contra = TypeVar ("_ToT_contra" , bound = tuple [ Any , ...] , contravariant = True )
58
65
_EquivT_co = TypeVar ("_EquivT_co" , bound = Shape , default = Any , covariant = True )
59
66
67
+ # __broadcast__ is the type-check-only interface order of ranks
60
68
@final
61
69
@type_check_only
62
70
class _CanBroadcast (Protocol [_FromT_contra , _ToT_contra , _EquivT_co ]):
63
71
def __broadcast__ (self , from_ : _FromT_contra , to : _ToT_contra , / ) -> _EquivT_co : ...
64
72
73
+ # __inner_shape__ is similar to `shape`, but directly exposes the `Rank` type.
74
+ @final
75
+ @type_check_only
76
+ class _HasInnerShape (Protocol [_ShapeLikeT_co ]):
77
+ @property
78
+ def __inner_shape__ (self , / ) -> _ShapeLikeT_co : ...
79
+
80
+ _OwnShapeT_contra = TypeVar ("_OwnShapeT_contra" , bound = tuple [Any , ...], default = Any , contravariant = True )
81
+ _OwnShapeT_co = TypeVar ("_OwnShapeT_co" , bound = Shape , default = _OwnShapeT_contra , covariant = True )
82
+
65
83
# This double shape-type parameter is a sneaky way to annotate a doubly-bound nominal type range,
66
84
# e.g. `_HasOwnShape[Shape2N, Shape0N]` accepts `Shape2N`, `Shape1N`, and `Shape0N`, but
67
85
# rejects `Shape3N` and `Shape1`. Besides brevity, it also works around several mypy bugs that
68
86
# are related to "unions vs joins".
69
-
70
- _OwnShapeT_contra = TypeVar ("_OwnShapeT_contra" , bound = Shape , default = Any , contravariant = True )
71
- _OwnShapeT_co = TypeVar ("_OwnShapeT_co" , bound = Shape , default = _OwnShapeT_contra , covariant = True )
72
- _OwnShapeT = TypeVar ("_OwnShapeT" , bound = tuple [Any , ...], default = Any )
73
-
74
87
@final
75
88
@type_check_only
76
89
class _HasOwnShape (Protocol [_OwnShapeT_contra , _OwnShapeT_co ]):
@@ -79,59 +92,74 @@ class _HasOwnShape(Protocol[_OwnShapeT_contra, _OwnShapeT_co]):
79
92
###
80
93
# TODO(jorenham): embed the array-like types, e.g. `Sequence[Sequence[T]]`
81
94
82
- @type_check_only
83
- class _BaseRank (Generic [_FromT_contra , _OwnShapeT , _ToT_contra ]):
84
- def __broadcast__ (self , from_ : _FromT_contra , to : _ToT_contra , / ) -> Self : ...
85
- def __own_shape__ (self , shape : _OwnShapeT , / ) -> _OwnShapeT : ...
95
+ _Ts = TypeVarTuple ("_Ts" ) # should only contain `int`s
86
96
97
+ # https://github.com/python/mypy/issues/19093
87
98
@type_check_only
88
- class _BaseRankM (
89
- _BaseRank [_FromT_contra | _HasOwnShape [_ToT_contra , Shape ], _OwnShapeT , _ToT_contra ],
90
- Generic [_FromT_contra , _OwnShapeT , _ToT_contra ],
91
- ): ...
99
+ class BaseRank (tuple [* _Ts ], Generic [* _Ts ]):
100
+ def __broadcast__ (self , from_ : tuple [* _Ts ], to : tuple [* _Ts ], / ) -> Self : ...
101
+ def __own_shape__ (self , shape : tuple [* _Ts ], / ) -> tuple [* _Ts ]: ...
92
102
93
103
@final
94
104
@type_check_only
95
- class Rank0 (_BaseRankM [_Shape00 , Shape0 , Shape0N ], tuple [()]): ...
105
+ class Rank0 (BaseRank [* tuple [()]]):
106
+ @override
107
+ def __broadcast__ (self , from_ : Shape0 | _HasOwnShape [Shape , Any ], to : Shape , / ) -> Self : ...
96
108
97
109
@final
98
110
@type_check_only
99
- class Rank1 (_BaseRankM [_Shape01 , Shape1 , Shape1N ], tuple [int ]): ...
111
+ class Rank1 (BaseRank [int ]):
112
+ @override
113
+ def __broadcast__ (self , from_ : _Shape01 | _HasOwnShape [Shape1N , Any ], to : Shape1N , / ) -> Self : ...
100
114
101
115
@final
102
116
@type_check_only
103
- class Rank2 (_BaseRankM [_Shape02 , Shape2 , Shape2N ], tuple [int , int ]): ...
117
+ class Rank2 (BaseRank [int , int ]):
118
+ @override
119
+ def __broadcast__ (self , from_ : _Shape02 | _HasOwnShape [Shape2N , Any ], to : Shape2N , / ) -> Self : ...
104
120
105
121
@final
106
122
@type_check_only
107
- class Rank3 (_BaseRankM [_Shape03 , Shape3 , Shape3N ], tuple [int , int , int ]): ...
123
+ class Rank3 (BaseRank [int , int , int ]):
124
+ @override
125
+ def __broadcast__ (self , from_ : _Shape03 | _HasOwnShape [Shape3N , Any ], to : Shape3N , / ) -> Self : ...
108
126
109
127
@final
110
128
@type_check_only
111
- class Rank4 (_BaseRankM [_Shape04 , Shape4 , Shape4N ], tuple [int , int , int , int ]): ...
129
+ class Rank4 (BaseRank [int , int , int , int ]):
130
+ @override
131
+ def __broadcast__ (self , from_ : _Shape04 | _HasOwnShape [Shape4N , Any ], to : Shape4N , / ) -> Self : ...
112
132
113
- # this emulates `AnyOf`, rather than a `Union`.
114
- @type_check_only
115
- class _BaseRankMToN (_BaseRank [Shape0N , _OwnShapeT , _OwnShapeT ], Generic [_OwnShapeT ]): ...
133
+ # these emulates `AnyOf` (gradual union), rather than a `Union`.
116
134
117
135
@final
118
136
@type_check_only
119
- class Rank (_BaseRankMToN [Shape0N ], tuple [int , ...]): ...
137
+ class Rank (BaseRank [* tuple [int , ...]]):
138
+ @override
139
+ def __broadcast__ (self , from_ : Shape0N , to : tuple [* _Ts ], / ) -> Self : ...
120
140
121
141
@final
122
142
@type_check_only
123
- class Rank1N (_BaseRankMToN [Shape1N ], tuple [int , * tuple [int , ...]]): ...
143
+ class Rank1N (BaseRank [int , * tuple [int , ...]]):
144
+ @override
145
+ def __broadcast__ (self , from_ : Shape0N , to : Shape1N , / ) -> Self : ...
124
146
125
147
@final
126
148
@type_check_only
127
- class Rank2N (_BaseRankMToN [Shape2N ], tuple [int , int , * tuple [int , ...]]): ...
149
+ class Rank2N (BaseRank [int , int , * tuple [int , ...]]):
150
+ @override
151
+ def __broadcast__ (self , from_ : Shape0N , to : Shape2N , / ) -> Self : ...
128
152
129
153
@final
130
154
@type_check_only
131
- class Rank3N (_BaseRankMToN [Shape3N ], tuple [int , int , int , * tuple [int , ...]]): ...
155
+ class Rank3N (BaseRank [int , int , int , * tuple [int , ...]]):
156
+ @override
157
+ def __broadcast__ (self , from_ : Shape0N , to : Shape3N , / ) -> Self : ...
132
158
133
159
@final
134
160
@type_check_only
135
- class Rank4N (_BaseRankMToN [Shape4N ], tuple [int , int , int , int , * tuple [int , ...]]): ...
161
+ class Rank4N (BaseRank [int , int , int , int , * tuple [int , ...]]):
162
+ @override
163
+ def __broadcast__ (self , from_ : Shape0N , to : Shape4N , / ) -> Self : ...
136
164
137
165
Rank0N : TypeAlias = Rank
0 commit comments