Skip to content

Commit d68ae05

Browse files
committed
✨ HasShape
1 parent 687881b commit d68ae05

File tree

4 files changed

+34
-0
lines changed

4 files changed

+34
-0
lines changed

src/array_api_typing/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"HasDType",
77
"HasMatrixTranspose",
88
"HasNDim",
9+
"HasShape",
910
"__version__",
1011
"__version_tuple__",
1112
)
@@ -16,5 +17,6 @@
1617
HasDType,
1718
HasMatrixTranspose,
1819
HasNDim,
20+
HasShape,
1921
)
2022
from ._version import version as __version__, version_tuple as __version_tuple__

src/array_api_typing/_array.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"HasDevice",
66
"HasMatrixTranspose",
77
"HasNDim",
8+
"HasShape",
89
)
910

1011
from types import ModuleType
@@ -115,12 +116,34 @@ def ndim(self) -> int:
115116
...
116117

117118

119+
class HasShape(Protocol):
120+
"""Protocol for array classes that have a shape attribute."""
121+
122+
@property
123+
def shape(self) -> tuple[int | None, ...]:
124+
"""Shape of the array.
125+
126+
Returns:
127+
tuple[int | None, ...]: array dimensions. An array dimension must be None
128+
if and only if a dimension is unknown.
129+
130+
Notes:
131+
For array libraries having graph-based computational models, array
132+
dimensions may be unknown due to data-dependent operations (e.g.,
133+
boolean indexing; `A[:, B > 0]`) and thus cannot be statically
134+
resolved without knowing array contents.
135+
136+
"""
137+
...
138+
139+
118140
class Array(
119141
# ------ Attributes -------
120142
HasDType[DTypeT_co],
121143
HasDevice,
122144
HasMatrixTranspose,
123145
HasNDim,
146+
HasShape,
124147
# ------- Methods ---------
125148
HasArrayNamespace[NamespaceT_co],
126149
# -------------------------

tests/integration/test_numpy1p0.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,7 @@ assert_type(x_i32.mT, xpt.Array[dtype[Any]])
6767
# Check Attribute `.ndim`
6868
assert_type(x_f32.ndim, int)
6969
assert_type(x_i32.ndim, int)
70+
71+
# Check Attribute `.shape`
72+
assert_type(x_f32.shape, tuple[int | None, ...])
73+
assert_type(x_i32.shape, tuple[int | None, ...])

tests/integration/test_numpy2p0.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,8 @@ assert_type(x_b.mT, xpt.Array[np.dtype[B]])
7676
assert_type(x_f32.ndim, int)
7777
assert_type(x_i32.ndim, int)
7878
assert_type(x_b.ndim, int)
79+
80+
# Check Attribute `.shape`
81+
assert_type(x_f32.shape, tuple[int | None, ...])
82+
assert_type(x_i32.shape, tuple[int | None, ...])
83+
assert_type(x_b.shape, tuple[int | None, ...])

0 commit comments

Comments
 (0)