Skip to content

Commit 25db39c

Browse files
committed
✨ HasShape
1 parent 58698bb commit 25db39c

File tree

4 files changed

+39
-0
lines changed

4 files changed

+39
-0
lines changed

src/array_api_typing/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"HasDevice",
88
"HasMatrixTranspose",
99
"HasNDim",
10+
"HasShape",
1011
"__version__",
1112
"__version_tuple__",
1213
)
@@ -18,5 +19,6 @@
1819
HasDType,
1920
HasMatrixTranspose,
2021
HasNDim,
22+
HasShape,
2123
)
2224
from ._version import version as __version__, version_tuple as __version_tuple__

src/array_api_typing/_array.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"HasDevice",
66
"HasMatrixTranspose",
77
"HasNDim",
8+
"HasShape",
89
)
910

1011
from types import ModuleType
@@ -116,6 +117,27 @@ def ndim(self) -> int:
116117
...
117118

118119

120+
class HasShape(Protocol):
121+
"""Protocol for array classes that have a shape attribute."""
122+
123+
@property
124+
def shape(self) -> tuple[int | None, ...]:
125+
"""Shape of the array.
126+
127+
Returns:
128+
tuple[int | None, ...]: array dimensions. An array dimension must be None
129+
if and only if a dimension is unknown.
130+
131+
Notes:
132+
For array libraries having graph-based computational models, array
133+
dimensions may be unknown due to data-dependent operations (e.g.,
134+
boolean indexing; `A[:, B > 0]`) and thus cannot be statically
135+
resolved without knowing array contents.
136+
137+
"""
138+
...
139+
140+
119141
class Array(
120142
# ------ Attributes -------
121143
HasDType[DTypeT_co],

tests/integration/test_numpy1p0.pyi

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ _: xpt.HasNDim = nparr
6060
_: xpt.HasNDim = nparr_i32
6161
_: xpt.HasNDim = nparr_f32
6262

63+
# =========================================================
64+
# `xpt.HasShape`
65+
66+
_: xpt.HasShape = nparr
67+
_: xpt.HasShape = nparr_i32
68+
_: xpt.HasShape = nparr_f32
69+
6370
# =========================================================
6471
# `xpt.Array`
6572

tests/integration/test_numpy2p0.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,14 @@ _: xpt.HasNDim = nparr_i32
6969
_: xpt.HasNDim = nparr_f32
7070
_: xpt.HasNDim = nparr_b
7171

72+
# =========================================================
73+
# `xpt.HasShape`
74+
75+
_: xpt.HasShape = nparr
76+
_: xpt.HasShape = nparr_i32
77+
_: xpt.HasShape = nparr_f32
78+
_: xpt.HasShape = nparr_b
79+
7280
# =========================================================
7381
# `xpt.Array`
7482

0 commit comments

Comments
 (0)