|
1 | 1 | from typing import Any, Generic, Protocol, Self, TypeAlias, final, type_check_only
|
2 | 2 | from typing_extensions import TypeAliasType, TypeVar, TypeVarTuple, override
|
3 | 3 |
|
4 |
| -from ._shape import Shape, Shape0, Shape0N, Shape1, Shape1N, Shape2, Shape2N, Shape3, Shape3N, Shape4, Shape4N |
| 4 | +from ._shape import AnyShape, Shape, Shape0, Shape1, Shape1N, Shape2, Shape2N, Shape3, Shape3N, Shape4, Shape4N |
5 | 5 |
|
6 | 6 | __all__ = [
|
7 | 7 | "HasInnerShape",
|
@@ -29,21 +29,24 @@ _Shape04: TypeAlias = _Shape03 | Shape4
|
29 | 29 |
|
30 | 30 | ###
|
31 | 31 |
|
32 |
| -_UpperT = TypeVar("_UpperT", bound=Shape) |
33 |
| -_LowerT = TypeVar("_LowerT", bound=Shape) |
| 32 | +# TODO(jorenham): remove `| Rank0 | Rank` once python/mypy#19110 is fixed |
| 33 | +_UpperT = TypeVar("_UpperT", bound=Shape | Rank0 | Rank) |
| 34 | +_LowerT = TypeVar("_LowerT", bound=Shape | Rank0 | Rank) |
34 | 35 | _RankT = TypeVar("_RankT", bound=Shape, default=Any)
|
35 | 36 |
|
36 |
| -_RankLE: TypeAlias = _CanBroadcast[Any, _UpperT, _RankT] |
37 |
| -_RankGE: TypeAlias = _CanBroadcast[_LowerT, Any, _RankT] |
| 37 | +# TODO(jorenham): remove `| Rank0 | Rank` once python/mypy#19110 is fixed |
| 38 | +_RankLE: TypeAlias = _CanBroadcast[Any, _UpperT, _RankT] | Shape0 | Rank0 | Rank |
| 39 | +# TODO(jorenham): remove `| Rank` once python/mypy#19110 is fixed |
| 40 | +_RankGE: TypeAlias = _CanBroadcast[_LowerT, Any, _RankT] | _LowerT | Rank |
38 | 41 |
|
39 | 42 | HasRankLE = TypeAliasType(
|
40 | 43 | "HasRankLE",
|
41 |
| - _HasInnerShape[Shape0 | _RankLE[_UpperT, _RankT]], |
| 44 | + _HasInnerShape[_RankLE[_UpperT, _RankT]], |
42 | 45 | type_params=(_UpperT, _RankT),
|
43 | 46 | )
|
44 | 47 | HasRankGE = TypeAliasType(
|
45 | 48 | "HasRankGE",
|
46 |
| - _HasInnerShape[_LowerT | _RankGE[_LowerT, _RankT]], |
| 49 | + _HasInnerShape[_RankGE[_LowerT, _RankT]], |
47 | 50 | type_params=(_LowerT, _RankT),
|
48 | 51 | )
|
49 | 52 |
|
@@ -136,30 +139,30 @@ class Rank4(BaseRank[int, int, int, int]):
|
136 | 139 | @type_check_only
|
137 | 140 | class Rank(BaseRank[*tuple[int, ...]]):
|
138 | 141 | @override
|
139 |
| - def __broadcast__(self, from_: Shape0N, to: tuple[*_Ts], /) -> Self: ... |
| 142 | + def __broadcast__(self, from_: AnyShape, to: tuple[*_Ts], /) -> Self: ... |
140 | 143 |
|
141 | 144 | @final
|
142 | 145 | @type_check_only
|
143 | 146 | class Rank1N(BaseRank[int, *tuple[int, ...]]):
|
144 | 147 | @override
|
145 |
| - def __broadcast__(self, from_: Shape0N, to: Shape1N, /) -> Self: ... |
| 148 | + def __broadcast__(self, from_: AnyShape, to: Shape1N, /) -> Self: ... |
146 | 149 |
|
147 | 150 | @final
|
148 | 151 | @type_check_only
|
149 | 152 | class Rank2N(BaseRank[int, int, *tuple[int, ...]]):
|
150 | 153 | @override
|
151 |
| - def __broadcast__(self, from_: Shape0N, to: Shape2N, /) -> Self: ... |
| 154 | + def __broadcast__(self, from_: AnyShape, to: Shape2N, /) -> Self: ... |
152 | 155 |
|
153 | 156 | @final
|
154 | 157 | @type_check_only
|
155 | 158 | class Rank3N(BaseRank[int, int, int, *tuple[int, ...]]):
|
156 | 159 | @override
|
157 |
| - def __broadcast__(self, from_: Shape0N, to: Shape3N, /) -> Self: ... |
| 160 | + def __broadcast__(self, from_: AnyShape, to: Shape3N, /) -> Self: ... |
158 | 161 |
|
159 | 162 | @final
|
160 | 163 | @type_check_only
|
161 | 164 | class Rank4N(BaseRank[int, int, int, int, *tuple[int, ...]]):
|
162 | 165 | @override
|
163 |
| - def __broadcast__(self, from_: Shape0N, to: Shape4N, /) -> Self: ... |
| 166 | + def __broadcast__(self, from_: AnyShape, to: Shape4N, /) -> Self: ... |
164 | 167 |
|
165 | 168 | Rank0N: TypeAlias = Rank
|
0 commit comments