@@ -4,8 +4,8 @@ from typing_extensions import TypeAliasType, TypeVar
4
4
from ._shape import Shape , Shape0 , Shape0N , Shape1 , Shape1N , Shape2 , Shape2N , Shape3 , Shape3N , Shape4 , Shape4N
5
5
6
6
__all__ = [
7
- "Broadcasts " ,
8
- "BroadcastsTo " ,
7
+ "HasRankGE " ,
8
+ "HasRankLE " ,
9
9
"Rank" ,
10
10
"Rank0" ,
11
11
"Rank0N" ,
@@ -29,22 +29,24 @@ _Shape04: TypeAlias = _Shape03 | Shape4
29
29
30
30
###
31
31
32
- _ToT = TypeVar ("_ToT " , bound = Shape )
33
- _FromT = TypeVar ("_FromT " , bound = Shape )
32
+ _UpperT = TypeVar ("_UpperT " , bound = Shape )
33
+ _LowerT = TypeVar ("_LowerT " , bound = Shape )
34
34
_RankT = TypeVar ("_RankT" , bound = Shape , default = Any )
35
35
36
- _BroadcastableShape = TypeAliasType (
37
- "_BroadcastableShape" ,
38
- _FromT | _CanBroadcastFrom [_FromT , _RankT ],
39
- type_params = (_FromT , _RankT ),
36
+ HasRankLE = TypeAliasType (
37
+ "HasRankLE" ,
38
+ _HasShape [_HasOwnShape [_UpperT ] | _CanBroadcast [Any , _UpperT , _RankT ]],
39
+ type_params = (_UpperT , _RankT ),
40
+ )
41
+ HasRankGE = TypeAliasType (
42
+ "HasRankGE" ,
43
+ _HasShape [_LowerT | _CanBroadcast [_LowerT , Any , _RankT ]],
44
+ type_params = (_LowerT , _RankT ),
40
45
)
41
-
42
- BroadcastsTo = TypeAliasType ("BroadcastsTo" , _HasShape [_CanBroadcastTo [_ToT , _RankT ]], type_params = (_ToT , _RankT ))
43
- Broadcasts = TypeAliasType ("Broadcasts" , _HasShape [_BroadcastableShape [_FromT , _RankT ]], type_params = (_FromT , _RankT ))
44
46
45
47
###
46
48
47
- _ShapeT_co = TypeVar ("_ShapeT_co" , bound = Shape | _HasOwnShape | _CanBroadcastFrom | _CanBroadcastTo , covariant = True )
49
+ _ShapeT_co = TypeVar ("_ShapeT_co" , bound = Shape | _HasOwnShape | _CanBroadcast , covariant = True )
48
50
49
51
@type_check_only
50
52
class _HasShape (Protocol [_ShapeT_co ]):
@@ -53,17 +55,12 @@ class _HasShape(Protocol[_ShapeT_co]):
53
55
54
56
_FromT_contra = TypeVar ("_FromT_contra" , default = Any , contravariant = True )
55
57
_ToT_contra = TypeVar ("_ToT_contra" , bound = Shape , default = Any , contravariant = True )
56
- _RankT_co = TypeVar ("_RankT_co" , bound = Shape , default = Any , covariant = True )
57
-
58
- @final
59
- @type_check_only
60
- class _CanBroadcastFrom (Protocol [_FromT_contra , _RankT_co ]):
61
- def __broadcast_from__ (self , from_ : _FromT_contra , / ) -> _RankT_co : ...
58
+ _EquivT_co = TypeVar ("_EquivT_co" , bound = Shape , default = Any , covariant = True )
62
59
63
60
@final
64
61
@type_check_only
65
- class _CanBroadcastTo (Protocol [_ToT_contra , _RankT_co ]):
66
- def __broadcast_to__ (self , to : _ToT_contra , / ) -> _RankT_co : ...
62
+ class _CanBroadcast (Protocol [_FromT_contra , _ToT_contra , _EquivT_co ]):
63
+ def __broadcast__ (self , from_ : _FromT_contra , to : _ToT_contra , / ) -> _EquivT_co : ...
67
64
68
65
# This double shape-type parameter is a sneaky way to annotate a doubly-bound nominal type range,
69
66
# e.g. `_HasOwnShape[Shape2N, Shape0N]` accepts `Shape2N`, `Shape1N`, and `Shape0N`, but
@@ -72,6 +69,7 @@ class _CanBroadcastTo(Protocol[_ToT_contra, _RankT_co]):
72
69
73
70
_OwnShapeT_contra = TypeVar ("_OwnShapeT_contra" , bound = Shape , default = Any , contravariant = True )
74
71
_OwnShapeT_co = TypeVar ("_OwnShapeT_co" , bound = Shape , default = _OwnShapeT_contra , covariant = True )
72
+ _OwnShapeT = TypeVar ("_OwnShapeT" , bound = tuple [Any , ...], default = Any )
75
73
76
74
@final
77
75
@type_check_only
@@ -81,39 +79,36 @@ class _HasOwnShape(Protocol[_OwnShapeT_contra, _OwnShapeT_co]):
81
79
###
82
80
# TODO(jorenham): embed the array-like types, e.g. `Sequence[Sequence[T]]`
83
81
84
- _OwnShapeT = TypeVar ("_OwnShapeT" , bound = tuple [Any , ...], default = Any )
85
-
86
82
@type_check_only
87
- class _BaseRank (Generic [_FromT_contra , _ToT_contra , _OwnShapeT ]):
88
- def __broadcast_from__ (self , from_ : _FromT_contra , / ) -> Self : ...
89
- def __broadcast_to__ (self , to : _ToT_contra , / ) -> Self : ...
83
+ class _BaseRank (Generic [_FromT_contra , _OwnShapeT , _ToT_contra ]):
84
+ def __broadcast__ (self , from_ : _FromT_contra , to : _ToT_contra , / ) -> Self : ...
90
85
def __own_shape__ (self , shape : _OwnShapeT , / ) -> _OwnShapeT : ...
91
86
92
87
@type_check_only
93
88
class _BaseRankM (
94
- _BaseRank [_FromT_contra | _HasOwnShape [_ToT_contra , Shape ], _ToT_contra , _OwnShapeT ],
95
- Generic [_FromT_contra , _ToT_contra , _OwnShapeT ],
89
+ _BaseRank [_FromT_contra | _HasOwnShape [_ToT_contra , Shape ], _OwnShapeT , _ToT_contra ],
90
+ Generic [_FromT_contra , _OwnShapeT , _ToT_contra ],
96
91
): ...
97
92
98
93
@final
99
94
@type_check_only
100
- class Rank0 (_BaseRankM [_Shape00 , Shape0N , Shape0 ], tuple [()]): ...
95
+ class Rank0 (_BaseRankM [_Shape00 , Shape0 , Shape0N ], tuple [()]): ...
101
96
102
97
@final
103
98
@type_check_only
104
- class Rank1 (_BaseRankM [_Shape01 , Shape1N , Shape1 ], tuple [int ]): ...
99
+ class Rank1 (_BaseRankM [_Shape01 , Shape1 , Shape1N ], tuple [int ]): ...
105
100
106
101
@final
107
102
@type_check_only
108
- class Rank2 (_BaseRankM [_Shape02 , Shape2N , Shape2 ], tuple [int , int ]): ...
103
+ class Rank2 (_BaseRankM [_Shape02 , Shape2 , Shape2N ], tuple [int , int ]): ...
109
104
110
105
@final
111
106
@type_check_only
112
- class Rank3 (_BaseRankM [_Shape03 , Shape3N , Shape3 ], tuple [int , int , int ]): ...
107
+ class Rank3 (_BaseRankM [_Shape03 , Shape3 , Shape3N ], tuple [int , int , int ]): ...
113
108
114
109
@final
115
110
@type_check_only
116
- class Rank4 (_BaseRankM [_Shape04 , Shape4N , Shape4 ], tuple [int , int , int , int ]): ...
111
+ class Rank4 (_BaseRankM [_Shape04 , Shape4 , Shape4N ], tuple [int , int , int , int ]): ...
117
112
118
113
# this emulates `AnyOf`, rather than a `Union`.
119
114
@type_check_only
0 commit comments