Skip to content

Commit 589fc89

Browse files
committed
✨ HasMatrixTranspose
1 parent 0cc3a30 commit 589fc89

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

src/array_api_typing/_array.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
)
66

77
from types import ModuleType
8-
from typing import Literal, Protocol
8+
from typing import Literal, Protocol, Self
99
from typing_extensions import TypeVar
1010

1111
NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType)
@@ -77,10 +77,32 @@ def device(self) -> object: # TODO: more specific type
7777
...
7878

7979

80+
class HasMatrixTranspose(Protocol):
81+
"""Protocol for array classes that have a matrix transpose attribute."""
82+
83+
@property
84+
def mT(self) -> Self: # noqa: N802
85+
"""Transpose of a matrix (or a stack of matrices).
86+
87+
If an array instance has fewer than two dimensions, an error should be
88+
raised.
89+
90+
Returns:
91+
Self: array whose last two dimensions (axes) are permuted in reverse
92+
order relative to original array (i.e., for an array instance
93+
having shape `(..., M, N)`, the returned array must have shape
94+
`(..., N, M))`. The returned array must have the same data type
95+
as the original array.
96+
97+
"""
98+
...
99+
100+
80101
class Array(
81102
# ------ Attributes -------
82103
HasDType[DTypeT_co],
83104
HasDevice,
105+
HasMatrixTranspose,
84106
# ------- Methods ---------
85107
HasArrayNamespace[NamespaceT_co],
86108
# -------------------------

tests/integration/test_numpy1p0.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,7 @@ _: dtype[Any] = x_i32.dtype
5959
# Check Attribute `.device`
6060
_: object = x_f32.device
6161
_: object = x_i32.device
62+
63+
# Check Attribute `.mT`
64+
_: xpt.Array[dtype[Any]] = x_f32.mT
65+
_: xpt.Array[dtype[Any]] = x_i32.mT

tests/integration/test_numpy2p0.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,8 @@ _: np.dtype[B] = x_b.dtype
6666
_: object = x_f32.device
6767
_: object = x_i32.device
6868
_: object = x_b.device
69+
70+
# Check Attribute `.mT`
71+
_: xpt.Array[np.dtype[F32]] = x_f32.mT
72+
_: xpt.Array[np.dtype[I32]] = x_i32.mT
73+
_: xpt.Array[np.dtype[B]] = x_b.mT

0 commit comments

Comments
 (0)