Skip to content

Commit 197cf6d

Browse files
committed
HasRank{LE,GE} type aliases for for broadcasting things with a shape
1 parent 460f64c commit 197cf6d

File tree

2 files changed

+29
-35
lines changed

2 files changed

+29
-35
lines changed

src/_numtype/__init__.pyi

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ from ._nep50 import (
8484
CastsWithScalar as CastsWithScalar,
8585
)
8686
from ._rank import (
87-
Broadcasts as Broadcasts,
88-
BroadcastsTo as BroadcastsTo,
87+
HasRankGE as HasRankGE,
88+
HasRankLE as HasRankLE,
8989
Rank as Rank,
9090
Rank0 as Rank0,
9191
Rank0N as Rank0N,
@@ -97,7 +97,6 @@ from ._rank import (
9797
Rank3N as Rank3N,
9898
Rank4 as Rank4,
9999
Rank4N as Rank4N,
100-
_BroadcastableShape as _BroadcastableShape,
101100
)
102101
from ._scalar import (
103102
inexact32 as inexact32,

src/_numtype/_rank.pyi

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ from typing_extensions import TypeAliasType, TypeVar
44
from ._shape import Shape, Shape0, Shape0N, Shape1, Shape1N, Shape2, Shape2N, Shape3, Shape3N, Shape4, Shape4N
55

66
__all__ = [
7-
"Broadcasts",
8-
"BroadcastsTo",
7+
"HasRankGE",
8+
"HasRankLE",
99
"Rank",
1010
"Rank0",
1111
"Rank0N",
@@ -29,22 +29,24 @@ _Shape04: TypeAlias = _Shape03 | Shape4
2929

3030
###
3131

32-
_ToT = TypeVar("_ToT", bound=Shape)
33-
_FromT = TypeVar("_FromT", bound=Shape)
32+
_UpperT = TypeVar("_UpperT", bound=Shape)
33+
_LowerT = TypeVar("_LowerT", bound=Shape)
3434
_RankT = TypeVar("_RankT", bound=Shape, default=Any)
3535

36-
_BroadcastableShape = TypeAliasType(
37-
"_BroadcastableShape",
38-
_FromT | _CanBroadcastFrom[_FromT, _RankT],
39-
type_params=(_FromT, _RankT),
36+
HasRankLE = TypeAliasType(
37+
"HasRankLE",
38+
_HasShape[_HasOwnShape[_UpperT] | _CanBroadcast[Any, _UpperT, _RankT]],
39+
type_params=(_UpperT, _RankT),
40+
)
41+
HasRankGE = TypeAliasType(
42+
"HasRankGE",
43+
_HasShape[_LowerT | _CanBroadcast[_LowerT, Any, _RankT]],
44+
type_params=(_LowerT, _RankT),
4045
)
41-
42-
BroadcastsTo = TypeAliasType("BroadcastsTo", _HasShape[_CanBroadcastTo[_ToT, _RankT]], type_params=(_ToT, _RankT))
43-
Broadcasts = TypeAliasType("Broadcasts", _HasShape[_BroadcastableShape[_FromT, _RankT]], type_params=(_FromT, _RankT))
4446

4547
###
4648

47-
_ShapeT_co = TypeVar("_ShapeT_co", bound=Shape | _HasOwnShape | _CanBroadcastFrom | _CanBroadcastTo, covariant=True)
49+
_ShapeT_co = TypeVar("_ShapeT_co", bound=Shape | _HasOwnShape | _CanBroadcast, covariant=True)
4850

4951
@type_check_only
5052
class _HasShape(Protocol[_ShapeT_co]):
@@ -53,17 +55,12 @@ class _HasShape(Protocol[_ShapeT_co]):
5355

5456
_FromT_contra = TypeVar("_FromT_contra", default=Any, contravariant=True)
5557
_ToT_contra = TypeVar("_ToT_contra", bound=Shape, default=Any, contravariant=True)
56-
_RankT_co = TypeVar("_RankT_co", bound=Shape, default=Any, covariant=True)
57-
58-
@final
59-
@type_check_only
60-
class _CanBroadcastFrom(Protocol[_FromT_contra, _RankT_co]):
61-
def __broadcast_from__(self, from_: _FromT_contra, /) -> _RankT_co: ...
58+
_EquivT_co = TypeVar("_EquivT_co", bound=Shape, default=Any, covariant=True)
6259

6360
@final
6461
@type_check_only
65-
class _CanBroadcastTo(Protocol[_ToT_contra, _RankT_co]):
66-
def __broadcast_to__(self, to: _ToT_contra, /) -> _RankT_co: ...
62+
class _CanBroadcast(Protocol[_FromT_contra, _ToT_contra, _EquivT_co]):
63+
def __broadcast__(self, from_: _FromT_contra, to: _ToT_contra, /) -> _EquivT_co: ...
6764

6865
# This double shape-type parameter is a sneaky way to annotate a doubly-bound nominal type range,
6966
# e.g. `_HasOwnShape[Shape2N, Shape0N]` accepts `Shape2N`, `Shape1N`, and `Shape0N`, but
@@ -72,6 +69,7 @@ class _CanBroadcastTo(Protocol[_ToT_contra, _RankT_co]):
7269

7370
_OwnShapeT_contra = TypeVar("_OwnShapeT_contra", bound=Shape, default=Any, contravariant=True)
7471
_OwnShapeT_co = TypeVar("_OwnShapeT_co", bound=Shape, default=_OwnShapeT_contra, covariant=True)
72+
_OwnShapeT = TypeVar("_OwnShapeT", bound=tuple[Any, ...], default=Any)
7573

7674
@final
7775
@type_check_only
@@ -81,39 +79,36 @@ class _HasOwnShape(Protocol[_OwnShapeT_contra, _OwnShapeT_co]):
8179
###
8280
# TODO(jorenham): embed the array-like types, e.g. `Sequence[Sequence[T]]`
8381

84-
_OwnShapeT = TypeVar("_OwnShapeT", bound=tuple[Any, ...], default=Any)
85-
8682
@type_check_only
87-
class _BaseRank(Generic[_FromT_contra, _ToT_contra, _OwnShapeT]):
88-
def __broadcast_from__(self, from_: _FromT_contra, /) -> Self: ...
89-
def __broadcast_to__(self, to: _ToT_contra, /) -> Self: ...
83+
class _BaseRank(Generic[_FromT_contra, _OwnShapeT, _ToT_contra]):
84+
def __broadcast__(self, from_: _FromT_contra, to: _ToT_contra, /) -> Self: ...
9085
def __own_shape__(self, shape: _OwnShapeT, /) -> _OwnShapeT: ...
9186

9287
@type_check_only
9388
class _BaseRankM(
94-
_BaseRank[_FromT_contra | _HasOwnShape[_ToT_contra, Shape], _ToT_contra, _OwnShapeT],
95-
Generic[_FromT_contra, _ToT_contra, _OwnShapeT],
89+
_BaseRank[_FromT_contra | _HasOwnShape[_ToT_contra, Shape], _OwnShapeT, _ToT_contra],
90+
Generic[_FromT_contra, _OwnShapeT, _ToT_contra],
9691
): ...
9792

9893
@final
9994
@type_check_only
100-
class Rank0(_BaseRankM[_Shape00, Shape0N, Shape0], tuple[()]): ...
95+
class Rank0(_BaseRankM[_Shape00, Shape0, Shape0N], tuple[()]): ...
10196

10297
@final
10398
@type_check_only
104-
class Rank1(_BaseRankM[_Shape01, Shape1N, Shape1], tuple[int]): ...
99+
class Rank1(_BaseRankM[_Shape01, Shape1, Shape1N], tuple[int]): ...
105100

106101
@final
107102
@type_check_only
108-
class Rank2(_BaseRankM[_Shape02, Shape2N, Shape2], tuple[int, int]): ...
103+
class Rank2(_BaseRankM[_Shape02, Shape2, Shape2N], tuple[int, int]): ...
109104

110105
@final
111106
@type_check_only
112-
class Rank3(_BaseRankM[_Shape03, Shape3N, Shape3], tuple[int, int, int]): ...
107+
class Rank3(_BaseRankM[_Shape03, Shape3, Shape3N], tuple[int, int, int]): ...
113108

114109
@final
115110
@type_check_only
116-
class Rank4(_BaseRankM[_Shape04, Shape4N, Shape4], tuple[int, int, int, int]): ...
111+
class Rank4(_BaseRankM[_Shape04, Shape4, Shape4N], tuple[int, int, int, int]): ...
117112

118113
# this emulates `AnyOf`, rather than a `Union`.
119114
@type_check_only

0 commit comments

Comments
 (0)