Skip to content

Commit ba2f76b

Browse files
committed
✨ HasMatrixTranspose
1 parent 3741de7 commit ba2f76b

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

src/array_api_typing/_array.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
)
55

66
from types import ModuleType
7-
from typing import Literal, Protocol
7+
from typing import Literal, Protocol, Self
88
from typing_extensions import TypeVar
99

1010
NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType)
@@ -76,10 +76,32 @@ def device(self) -> object: # TODO: more specific type
7676
...
7777

7878

79+
class HasMatrixTranspose(Protocol):
80+
"""Protocol for array classes that have a matrix transpose attribute."""
81+
82+
@property
83+
def mT(self) -> Self: # noqa: N802
84+
"""Transpose of a matrix (or a stack of matrices).
85+
86+
If an array instance has fewer than two dimensions, an error should be
87+
raised.
88+
89+
Returns:
90+
Self: array whose last two dimensions (axes) are permuted in reverse
91+
order relative to original array (i.e., for an array instance
92+
having shape `(..., M, N)`, the returned array must have shape
93+
`(..., N, M))`. The returned array must have the same data type
94+
as the original array.
95+
96+
"""
97+
...
98+
99+
79100
class Array(
80101
# ------ Attributes -------
81102
HasDType[DTypeT_co],
82103
HasDevice,
104+
HasMatrixTranspose,
83105
# ------- Methods ---------
84106
HasArrayNamespace[NamespaceT_co],
85107
# -------------------------

0 commit comments

Comments
 (0)