Skip to content

Commit 0126331

Browse files
committed
✨ HasMatrixTranspose
1 parent fbbc9e1 commit 0126331

File tree

4 files changed

+34
-1
lines changed

4 files changed

+34
-1
lines changed

src/array_api_typing/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"HasArrayNamespace",
66
"HasDType",
77
"HasDevice",
8+
"HasMatrixTranspose",
89
"__version__",
910
"__version_tuple__",
1011
)
@@ -14,5 +15,6 @@
1415
HasArrayNamespace,
1516
HasDevice,
1617
HasDType,
18+
HasMatrixTranspose,
1719
)
1820
from ._version import version as __version__, version_tuple as __version_tuple__

src/array_api_typing/_array.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
"HasArrayNamespace",
44
"HasDType",
55
"HasDevice",
6+
"HasMatrixTranspose",
67
)
78

89
from types import ModuleType
9-
from typing import Literal, Protocol
10+
from typing import Literal, Protocol, Self
1011
from typing_extensions import TypeVar
1112

1213
NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType)
@@ -78,6 +79,27 @@ def device(self) -> object: # TODO: more specific type
7879
...
7980

8081

82+
class HasMatrixTranspose(Protocol):
83+
"""Protocol for array classes that have a matrix transpose attribute."""
84+
85+
@property
86+
def mT(self) -> Self: # noqa: N802
87+
"""Transpose of a matrix (or a stack of matrices).
88+
89+
If an array instance has fewer than two dimensions, an error should be
90+
raised.
91+
92+
Returns:
93+
Self: array whose last two dimensions (axes) are permuted in reverse
94+
order relative to original array (i.e., for an array instance
95+
having shape `(..., M, N)`, the returned array must have shape
96+
`(..., N, M))`. The returned array must have the same data type
97+
as the original array.
98+
99+
"""
100+
...
101+
102+
81103
class Array(
82104
# ------ Attributes -------
83105
HasDType[DTypeT_co],

tests/integration/test_numpy1p0.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,7 @@ x_i32: xpt.Array[dtype[Any]] = nparr_i32
6262
# Check Attribute `.dtype`
6363
assert_type(x_f32.dtype, dtype[Any])
6464
assert_type(x_i32.dtype, dtype[Any])
65+
66+
# Check Attribute `.mT`
67+
assert_type(x_f32.mT, xpt.Array[dtype[Any]])
68+
assert_type(x_i32.mT, xpt.Array[dtype[Any]])

tests/integration/test_numpy2p0.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,8 @@ x_b: xpt.Array[np.dtype[B]] = nparr_b
6969
assert_type(x_f32.dtype, np.dtype[F32])
7070
assert_type(x_i32.dtype, np.dtype[I32])
7171
assert_type(x_b.dtype, np.dtype[B])
72+
73+
# Check Attribute `.mT`
74+
assert_type(x_f32.mT, xpt.Array[np.dtype[F32]])
75+
assert_type(x_i32.mT, xpt.Array[np.dtype[I32]])
76+
assert_type(x_b.mT, xpt.Array[np.dtype[B]])

0 commit comments

Comments
 (0)