1
- from typing import Any , Protocol , Self , TypeAlias , final , type_check_only
1
+ from typing import Any , Generic , Protocol , Self , TypeAlias , final , type_check_only
2
2
from typing_extensions import TypeAliasType , TypeVar
3
3
4
4
from ._shape import (
5
5
Shape ,
6
6
Shape as Shape0ToN ,
7
7
Shape0 ,
8
8
Shape1 ,
9
- Shape1_ as Shape1ToN ,
9
+ Shape1N as Shape1ToN ,
10
10
Shape2 ,
11
- Shape2_ as Shape2ToN ,
11
+ Shape2N as Shape2ToN ,
12
12
Shape3 ,
13
- Shape3_ as Shape3ToN ,
13
+ Shape3N as Shape3ToN ,
14
14
Shape4 ,
15
- Shape4_ as Shape4ToN ,
15
+ Shape4N as Shape4ToN ,
16
16
)
17
17
18
18
__all__ = [
19
- "Broadcastable" ,
19
+ "Broadcasts" ,
20
+ "BroadcastsTo" ,
21
+ "Rank" ,
20
22
"Rank0" ,
21
- "Rank0ToN " ,
23
+ "Rank0N " ,
22
24
"Rank1" ,
23
- "Rank1ToN " ,
25
+ "Rank1N " ,
24
26
"Rank2" ,
25
- "Rank2ToN " ,
27
+ "Rank2N " ,
26
28
"Rank3" ,
27
- "Rank3ToN " ,
29
+ "Rank3N " ,
28
30
"Rank4" ,
29
- "Rank4ToN " ,
31
+ "Rank4N " ,
30
32
]
31
33
32
34
###
33
35
34
- _ToT = TypeVar ("_ToT" , bound = Shape )
35
- _ToT_contra = TypeVar ("_ToT_contra" , bound = Shape , contravariant = True )
36
- _FromT = TypeVar ("_FromT" , bound = Shape )
37
- _FromT_contra = TypeVar ("_FromT_contra" , bound = Shape , default = Any , contravariant = True )
38
- _RankT = TypeVar ("_RankT" , bound = _HasShape [Any ], default = Any )
39
- _RankT_co = TypeVar ("_RankT_co" , default = Any , covariant = True )
40
- _ShapeT_contra = TypeVar ("_ShapeT_contra" , contravariant = True )
41
- _ShapeT_co = TypeVar ("_ShapeT_co" , covariant = True , default = _ShapeT_contra )
42
-
43
- ###
44
-
45
36
_Shape0To0 : TypeAlias = Shape0
46
37
_Shape0To1 : TypeAlias = _Shape0To0 | Shape1
47
38
_Shape0To2 : TypeAlias = _Shape0To1 | Shape2
@@ -50,74 +41,134 @@ _Shape0To4: TypeAlias = _Shape0To3 | Shape4
50
41
51
42
###
52
43
44
+ _ToT = TypeVar ("_ToT" , bound = Shape )
45
+ _FromT = TypeVar ("_FromT" , bound = Shape )
46
+ _RankT = TypeVar ("_RankT" , bound = Shape , default = Any )
47
+
48
+ _BroadcastableShape = TypeAliasType (
49
+ "_BroadcastableShape" ,
50
+ _FromT | _CanBroadcastFrom [_FromT , _RankT ],
51
+ type_params = (_FromT , _RankT ),
52
+ )
53
+
54
+ BroadcastsTo = TypeAliasType (
55
+ "BroadcastsTo" ,
56
+ _HasRank [_CanBroadcastTo [_ToT , _RankT ]],
57
+ type_params = (_ToT , _RankT ),
58
+ )
59
+ Broadcasts = TypeAliasType (
60
+ "Broadcasts" ,
61
+ _HasRank [_BroadcastableShape [_FromT , _RankT ]],
62
+ type_params = (_FromT , _RankT ),
63
+ )
64
+
65
+ ###
66
+
67
+ _ShapeT_co = TypeVar (
68
+ "_ShapeT_co" ,
69
+ bound = Shape | _HasOwnShape | _CanBroadcastFrom | _CanBroadcastTo ,
70
+ covariant = True ,
71
+ )
72
+
73
+ @type_check_only
74
+ class _HasShape (Protocol [_ShapeT_co ]):
75
+ @property
76
+ def shape (self , / ) -> _ShapeT_co : ...
77
+
78
+ _ShapeT = TypeVar ("_ShapeT" , bound = Shape )
79
+
80
+ @final
53
81
@type_check_only
54
- class _CanBroadcast (Protocol [_ToT_contra , _FromT_contra , _RankT_co ]):
55
- def __broadcast__ (self , to : _ToT_contra , from_ : _FromT_contra , / ) -> _RankT_co : ...
82
+ class _HasRank (Protocol [_ShapeT_co ]):
83
+ @property
84
+ def shape (self : _HasShape [_ShapeT ], / ) -> _ShapeT : ...
85
+
86
+ _FromT_contra = TypeVar ("_FromT_contra" , default = Any , contravariant = True )
87
+ _ToT_contra = TypeVar ("_ToT_contra" , bound = Shape , default = Any , contravariant = True )
88
+ _RankT_co = TypeVar ("_RankT_co" , bound = Shape , default = Any , covariant = True )
89
+
90
+ @final
91
+ @type_check_only
92
+ class _CanBroadcastFrom (Protocol [_FromT_contra , _RankT_co ]):
93
+ def __broadcast_from__ (self , from_ : _FromT_contra , / ) -> _RankT_co : ...
94
+
95
+ @final
96
+ @type_check_only
97
+ class _CanBroadcastTo (Protocol [_ToT_contra , _RankT_co ]):
98
+ def __broadcast_to__ (self , to : _ToT_contra , / ) -> _RankT_co : ...
56
99
57
100
# This double shape-type parameter is a sneaky way to annotate a doubly-bound nominal type range,
58
- # e.g. `_HasShape [Shape2ToN, Shape0ToN]` accepts `Shape2ToN`, `Shape1ToN`, and `Shape0ToN`, but
101
+ # e.g. `_HasOwnShape [Shape2ToN, Shape0ToN]` accepts `Shape2ToN`, `Shape1ToN`, and `Shape0ToN`, but
59
102
# rejects `Shape3ToN` and `Shape1`. Besides brevity, it also works around several mypy bugs that
60
103
# are related to "unions vs joins".
104
+
105
+ _OwnShapeT_contra = TypeVar ("_OwnShapeT_contra" , bound = Shape , default = Any , contravariant = True )
106
+ _OwnShapeT_co = TypeVar ("_OwnShapeT_co" , bound = Shape , default = _OwnShapeT_contra , covariant = True )
107
+
108
+ @final
61
109
@type_check_only
62
- class _HasShape (Protocol [_ShapeT_contra , _ShapeT_co ]):
63
- def __shape__ (self , shape : _ShapeT_contra , / ) -> _ShapeT_co : ...
110
+ class _HasOwnShape (Protocol [_OwnShapeT_contra , _OwnShapeT_co ]):
111
+ def __own_shape__ (self , shape : _OwnShapeT_contra , / ) -> _OwnShapeT_co : ...
64
112
65
113
###
66
114
# TODO(jorenham): embed the array-like types, e.g. `Sequence[Sequence[T]]`
67
115
68
- @final
116
+ _OwnShapeT = TypeVar ("_OwnShapeT" , bound = tuple [Any , ...], default = Any )
117
+
69
118
@type_check_only
70
- class Rank0 (tuple [()], _HasShape [Shape0 , Shape0 ]):
71
- def __broadcast__ (self , to : Shape0ToN , from_ : _HasShape [Shape0ToN , Shape0ToN ], / ) -> Self : ...
119
+ class _BaseRank (Generic [_FromT_contra , _ToT_contra , _OwnShapeT ]):
120
+ def __broadcast_from__ (self , from_ : _FromT_contra , / ) -> Self : ...
121
+ def __broadcast_to__ (self , to : _ToT_contra , / ) -> Self : ...
122
+ def __own_shape__ (self , shape : _OwnShapeT , / ) -> _OwnShapeT : ...
72
123
73
124
@type_check_only
74
- class Rank1 (tuple [int ], _HasShape [Shape1 , Shape1 ]):
75
- def __broadcast__ (self , to : Shape1ToN , from_ : _Shape0To1 | _HasShape [Shape1ToN , Shape0ToN ], / ) -> Self : ...
125
+ class _BaseRankM (
126
+ _BaseRank [_FromT_contra | _HasOwnShape [_ToT_contra , Shape ], _ToT_contra , _OwnShapeT ],
127
+ Generic [_FromT_contra , _ToT_contra , _OwnShapeT ],
128
+ ): ...
76
129
77
130
@final
78
131
@type_check_only
79
- class Rank2 (tuple [int , int ], _HasShape [Shape2 , Shape2 ]):
80
- def __broadcast__ (self , to : Shape2ToN , from_ : _Shape0To2 | _HasShape [Shape2ToN , Shape0ToN ], / ) -> Self : ...
132
+ class Rank0 (_BaseRankM [_Shape0To0 , Shape0ToN , Shape0 ], tuple [()]): ...
81
133
82
134
@final
83
135
@type_check_only
84
- class Rank3 (tuple [int , int , int ], _HasShape [Shape3 , Shape3 ]):
85
- def __broadcast__ (self , to : Shape3ToN , from_ : _Shape0To3 | _HasShape [Shape3ToN , Shape0ToN ], / ) -> Self : ...
136
+ class Rank1 (_BaseRankM [_Shape0To1 , Shape1ToN , Shape1 ], tuple [int ]): ...
86
137
87
138
@final
88
139
@type_check_only
89
- class Rank4 (tuple [int , int , int , int ], _HasShape [Shape4 , Shape4 ]):
90
- def __broadcast__ (self , to : Shape4ToN , from_ : _Shape0To4 | _HasShape [Shape4ToN , Shape0ToN ], / ) -> Self : ...
140
+ class Rank2 (_BaseRankM [_Shape0To2 , Shape2ToN , Shape2 ], tuple [int , int ]): ...
91
141
92
- ###
93
- # These emulate `AnyOf`, rather than a `Union`.
142
+ @final
143
+ @type_check_only
144
+ class Rank3 (_BaseRankM [_Shape0To3 , Shape3ToN , Shape3 ], tuple [int , int , int ]): ...
94
145
95
146
@final
96
147
@type_check_only
97
- class Rank0ToN (tuple [int , ...], _HasShape [Shape0ToN , Shape0ToN ]):
98
- def __broadcast__ (self , to : Shape0ToN , from_ : Shape0ToN , / ) -> Self : ...
148
+ class Rank4 (_BaseRankM [_Shape0To4 , Shape4ToN , Shape4 ], tuple [int , int , int , int ]): ...
149
+
150
+ # this emulates `AnyOf`, rather than a `Union`.
151
+ @type_check_only
152
+ class _BaseRankMToN (_BaseRank [Shape0ToN , _OwnShapeT , _OwnShapeT ], Generic [_OwnShapeT ]): ...
99
153
100
154
@final
101
155
@type_check_only
102
- class Rank1ToN (tuple [int , * tuple [int , ...]], _HasShape [Shape1ToN , Shape1ToN ]):
103
- def __broadcast__ (self , to : Shape1ToN , from_ : Shape0ToN , / ) -> Self : ...
156
+ class Rank (_BaseRankMToN [Shape0ToN ], tuple [int , ...]): ...
104
157
105
158
@final
106
159
@type_check_only
107
- class Rank2ToN (tuple [int , int , * tuple [int , ...]], _HasShape [Shape2ToN , Shape2ToN ]):
108
- def __broadcast__ (self , to : Shape2ToN , from_ : Shape0ToN , / ) -> Self : ...
160
+ class Rank1N (_BaseRankMToN [Shape1ToN ], tuple [int , * tuple [int , ...]]): ...
109
161
110
162
@final
111
163
@type_check_only
112
- class Rank3ToN (tuple [int , int , int , * tuple [int , ...]], _HasShape [Shape3ToN , Shape3ToN ]):
113
- def __broadcast__ (self , to : Shape3ToN , from_ : Shape0ToN , / ) -> Self : ...
164
+ class Rank2N (_BaseRankMToN [Shape2ToN ], tuple [int , int , * tuple [int , ...]]): ...
114
165
115
166
@final
116
167
@type_check_only
117
- class Rank4ToN (tuple [int , int , int , int , * tuple [int , ...]], _HasShape [Shape4ToN , Shape4ToN ]):
118
- def __broadcast__ (self , to : Shape4ToN , from_ : Shape0ToN , / ) -> Self : ...
168
+ class Rank3N (_BaseRankMToN [Shape3ToN ], tuple [int , int , int , * tuple [int , ...]]): ...
119
169
120
- ###
170
+ @final
171
+ @type_check_only
172
+ class Rank4N (_BaseRankMToN [Shape4ToN ], tuple [int , int , int , int , * tuple [int , ...]]): ...
121
173
122
- Broadcastable = TypeAliasType ("Broadcastable" , _CanBroadcast [_ToT , Any , _RankT ], type_params = (_ToT , _RankT ))
123
- Broadcaster = TypeAliasType ("Broadcaster" , _FromT | _CanBroadcast [Any , _FromT , _RankT ], type_params = (_FromT , _RankT ))
174
+ Rank0N : TypeAlias = Rank
0 commit comments