Skip to content

Commit 0a1d1fb

Browse files
committed
♻️ prefer associated-ish over generic types rank types
1 parent 887b85b commit 0a1d1fb

File tree

3 files changed

+96
-39
lines changed

3 files changed

+96
-39
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from typing import assert_type
2+
3+
import _numtype as _nt
4+
5+
a0: _nt.Array0D
6+
assert_type(a0.__inner_shape__, _nt.Rank0)
7+
r0: _nt.Rank0 = a0.shape # type: ignore[assignment] # pyright: ignore[reportAssignmentType]
8+
s0: _nt.Shape0 = a0.shape
9+
10+
a1: _nt.Array1D
11+
assert_type(a1.__inner_shape__, _nt.Rank1)
12+
r1: _nt.Rank1 = a1.shape # type: ignore[assignment] # pyright: ignore[reportAssignmentType]
13+
s1: _nt.Shape1 = a1.shape
14+
15+
a2: _nt.Array2D
16+
assert_type(a2.__inner_shape__, _nt.Rank2)
17+
r2: _nt.Rank2 = a2.shape # type: ignore[assignment] # pyright: ignore[reportAssignmentType]
18+
s2: _nt.Shape2 = a2.shape
19+
20+
a3: _nt.Array3D
21+
assert_type(a3.__inner_shape__, _nt.Rank3)
22+
r3: _nt.Rank3 = a3.shape # type: ignore[assignment] # pyright: ignore[reportAssignmentType]
23+
s3: _nt.Shape3 = a3.shape
24+
25+
a4: _nt.Array4D
26+
assert_type(a4.__inner_shape__, _nt.Rank4)
27+
r4: _nt.Rank4 = a4.shape # type: ignore[assignment] # pyright: ignore[reportAssignmentType]
28+
s4: _nt.Shape4 = a4.shape

src/_numtype/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ from ._nep50 import (
8686
CastsWithScalar as CastsWithScalar,
8787
)
8888
from ._rank import (
89+
HasInnerShape as HasInnerShape,
8990
HasRankGE as HasRankGE,
9091
HasRankLE as HasRankLE,
9192
Rank as Rank,

src/_numtype/_rank.pyi

Lines changed: 67 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from typing import Any, Generic, Protocol, Self, TypeAlias, final, type_check_only
2-
from typing_extensions import TypeAliasType, TypeVar
2+
from typing_extensions import TypeAliasType, TypeVar, TypeVarTuple, override
33

44
from ._shape import Shape, Shape0, Shape0N, Shape1, Shape1N, Shape2, Shape2N, Shape3, Shape3N, Shape4, Shape4N
55

66
__all__ = [
7+
"HasInnerShape",
78
"HasRankGE",
89
"HasRankLE",
910
"Rank",
@@ -21,8 +22,7 @@ __all__ = [
2122

2223
###
2324

24-
_Shape00: TypeAlias = Shape0
25-
_Shape01: TypeAlias = _Shape00 | Shape1
25+
_Shape01: TypeAlias = Shape0 | Shape1
2626
_Shape02: TypeAlias = _Shape01 | Shape2
2727
_Shape03: TypeAlias = _Shape02 | Shape3
2828
_Shape04: TypeAlias = _Shape03 | Shape4
@@ -33,44 +33,57 @@ _UpperT = TypeVar("_UpperT", bound=Shape)
3333
_LowerT = TypeVar("_LowerT", bound=Shape)
3434
_RankT = TypeVar("_RankT", bound=Shape, default=Any)
3535

36+
_RankLE: TypeAlias = _CanBroadcast[Any, _UpperT, _RankT]
37+
_RankGE: TypeAlias = _CanBroadcast[_LowerT, Any, _RankT]
38+
3639
HasRankLE = TypeAliasType(
3740
"HasRankLE",
38-
_HasShape[Shape0 | _HasOwnShape[_UpperT] | _CanBroadcast[Any, _UpperT, _RankT]],
41+
_HasInnerShape[Shape0 | _RankLE[_UpperT, _RankT]],
3942
type_params=(_UpperT, _RankT),
4043
)
4144
HasRankGE = TypeAliasType(
4245
"HasRankGE",
43-
_HasShape[_LowerT | _CanBroadcast[_LowerT, Any, _RankT]],
46+
_HasInnerShape[_LowerT | _RankGE[_LowerT, _RankT]],
4447
type_params=(_LowerT, _RankT),
4548
)
4649

47-
###
50+
_ShapeT = TypeVar("_ShapeT", bound=Shape)
4851

49-
_ShapeT_co = TypeVar("_ShapeT_co", bound=Shape | _HasOwnShape | _CanBroadcast, covariant=True)
52+
# for unwrapping potential rank types as shape tuples
53+
HasInnerShape = TypeAliasType(
54+
"HasInnerShape",
55+
_HasInnerShape[_HasOwnShape[Any, _ShapeT]],
56+
type_params=(_ShapeT,),
57+
)
5058

51-
@type_check_only
52-
class _HasShape(Protocol[_ShapeT_co]):
53-
@property
54-
def shape(self, /) -> _ShapeT_co: ...
59+
###
60+
61+
_ShapeLikeT_co = TypeVar("_ShapeLikeT_co", bound=Shape | _HasOwnShape | _CanBroadcast[Any, Any], covariant=True)
5562

56-
_FromT_contra = TypeVar("_FromT_contra", default=Any, contravariant=True)
57-
_ToT_contra = TypeVar("_ToT_contra", bound=Shape, default=Any, contravariant=True)
63+
_FromT_contra = TypeVar("_FromT_contra", contravariant=True)
64+
_ToT_contra = TypeVar("_ToT_contra", bound=tuple[Any, ...], contravariant=True)
5865
_EquivT_co = TypeVar("_EquivT_co", bound=Shape, default=Any, covariant=True)
5966

67+
# __broadcast__ is the type-check-only interface order of ranks
6068
@final
6169
@type_check_only
6270
class _CanBroadcast(Protocol[_FromT_contra, _ToT_contra, _EquivT_co]):
6371
def __broadcast__(self, from_: _FromT_contra, to: _ToT_contra, /) -> _EquivT_co: ...
6472

73+
# __inner_shape__ is similar to `shape`, but directly exposes the `Rank` type.
74+
@final
75+
@type_check_only
76+
class _HasInnerShape(Protocol[_ShapeLikeT_co]):
77+
@property
78+
def __inner_shape__(self, /) -> _ShapeLikeT_co: ...
79+
80+
_OwnShapeT_contra = TypeVar("_OwnShapeT_contra", bound=tuple[Any, ...], default=Any, contravariant=True)
81+
_OwnShapeT_co = TypeVar("_OwnShapeT_co", bound=Shape, default=_OwnShapeT_contra, covariant=True)
82+
6583
# This double shape-type parameter is a sneaky way to annotate a doubly-bound nominal type range,
6684
# e.g. `_HasOwnShape[Shape2N, Shape0N]` accepts `Shape2N`, `Shape1N`, and `Shape0N`, but
6785
# rejects `Shape3N` and `Shape1`. Besides brevity, it also works around several mypy bugs that
6886
# are related to "unions vs joins".
69-
70-
_OwnShapeT_contra = TypeVar("_OwnShapeT_contra", bound=Shape, default=Any, contravariant=True)
71-
_OwnShapeT_co = TypeVar("_OwnShapeT_co", bound=Shape, default=_OwnShapeT_contra, covariant=True)
72-
_OwnShapeT = TypeVar("_OwnShapeT", bound=tuple[Any, ...], default=Any)
73-
7487
@final
7588
@type_check_only
7689
class _HasOwnShape(Protocol[_OwnShapeT_contra, _OwnShapeT_co]):
@@ -79,59 +92,74 @@ class _HasOwnShape(Protocol[_OwnShapeT_contra, _OwnShapeT_co]):
7992
###
8093
# TODO(jorenham): embed the array-like types, e.g. `Sequence[Sequence[T]]`
8194

82-
@type_check_only
83-
class _BaseRank(Generic[_FromT_contra, _OwnShapeT, _ToT_contra]):
84-
def __broadcast__(self, from_: _FromT_contra, to: _ToT_contra, /) -> Self: ...
85-
def __own_shape__(self, shape: _OwnShapeT, /) -> _OwnShapeT: ...
95+
_Ts = TypeVarTuple("_Ts") # should only contain `int`s
8696

97+
# https://github.com/python/mypy/issues/19093
8798
@type_check_only
88-
class _BaseRankM(
89-
_BaseRank[_FromT_contra | _HasOwnShape[_ToT_contra, Shape], _OwnShapeT, _ToT_contra],
90-
Generic[_FromT_contra, _OwnShapeT, _ToT_contra],
91-
): ...
99+
class BaseRank(tuple[*_Ts], Generic[*_Ts]):
100+
def __broadcast__(self, from_: tuple[*_Ts], to: tuple[*_Ts], /) -> Self: ...
101+
def __own_shape__(self, shape: tuple[*_Ts], /) -> tuple[*_Ts]: ...
92102

93103
@final
94104
@type_check_only
95-
class Rank0(_BaseRankM[_Shape00, Shape0, Shape0N], tuple[()]): ...
105+
class Rank0(BaseRank[*tuple[()]]):
106+
@override
107+
def __broadcast__(self, from_: Shape0 | _HasOwnShape[Shape, Any], to: Shape, /) -> Self: ...
96108

97109
@final
98110
@type_check_only
99-
class Rank1(_BaseRankM[_Shape01, Shape1, Shape1N], tuple[int]): ...
111+
class Rank1(BaseRank[int]):
112+
@override
113+
def __broadcast__(self, from_: _Shape01 | _HasOwnShape[Shape1N, Any], to: Shape1N, /) -> Self: ...
100114

101115
@final
102116
@type_check_only
103-
class Rank2(_BaseRankM[_Shape02, Shape2, Shape2N], tuple[int, int]): ...
117+
class Rank2(BaseRank[int, int]):
118+
@override
119+
def __broadcast__(self, from_: _Shape02 | _HasOwnShape[Shape2N, Any], to: Shape2N, /) -> Self: ...
104120

105121
@final
106122
@type_check_only
107-
class Rank3(_BaseRankM[_Shape03, Shape3, Shape3N], tuple[int, int, int]): ...
123+
class Rank3(BaseRank[int, int, int]):
124+
@override
125+
def __broadcast__(self, from_: _Shape03 | _HasOwnShape[Shape3N, Any], to: Shape3N, /) -> Self: ...
108126

109127
@final
110128
@type_check_only
111-
class Rank4(_BaseRankM[_Shape04, Shape4, Shape4N], tuple[int, int, int, int]): ...
129+
class Rank4(BaseRank[int, int, int, int]):
130+
@override
131+
def __broadcast__(self, from_: _Shape04 | _HasOwnShape[Shape4N, Any], to: Shape4N, /) -> Self: ...
112132

113-
# this emulates `AnyOf`, rather than a `Union`.
114-
@type_check_only
115-
class _BaseRankMToN(_BaseRank[Shape0N, _OwnShapeT, _OwnShapeT], Generic[_OwnShapeT]): ...
133+
# these emulates `AnyOf` (gradual union), rather than a `Union`.
116134

117135
@final
118136
@type_check_only
119-
class Rank(_BaseRankMToN[Shape0N], tuple[int, ...]): ...
137+
class Rank(BaseRank[*tuple[int, ...]]):
138+
@override
139+
def __broadcast__(self, from_: Shape0N, to: tuple[*_Ts], /) -> Self: ...
120140

121141
@final
122142
@type_check_only
123-
class Rank1N(_BaseRankMToN[Shape1N], tuple[int, *tuple[int, ...]]): ...
143+
class Rank1N(BaseRank[int, *tuple[int, ...]]):
144+
@override
145+
def __broadcast__(self, from_: Shape0N, to: Shape1N, /) -> Self: ...
124146

125147
@final
126148
@type_check_only
127-
class Rank2N(_BaseRankMToN[Shape2N], tuple[int, int, *tuple[int, ...]]): ...
149+
class Rank2N(BaseRank[int, int, *tuple[int, ...]]):
150+
@override
151+
def __broadcast__(self, from_: Shape0N, to: Shape2N, /) -> Self: ...
128152

129153
@final
130154
@type_check_only
131-
class Rank3N(_BaseRankMToN[Shape3N], tuple[int, int, int, *tuple[int, ...]]): ...
155+
class Rank3N(BaseRank[int, int, int, *tuple[int, ...]]):
156+
@override
157+
def __broadcast__(self, from_: Shape0N, to: Shape3N, /) -> Self: ...
132158

133159
@final
134160
@type_check_only
135-
class Rank4N(_BaseRankMToN[Shape4N], tuple[int, int, int, int, *tuple[int, ...]]): ...
161+
class Rank4N(BaseRank[int, int, int, int, *tuple[int, ...]]):
162+
@override
163+
def __broadcast__(self, from_: Shape0N, to: Shape4N, /) -> Self: ...
136164

137165
Rank0N: TypeAlias = Rank

0 commit comments

Comments
 (0)