Skip to content

Commit 4c3963b

Browse files
committed
✨ HasMatrixTranspose
Signed-off-by: nstarman <[email protected]>
1 parent d786599 commit 4c3963b

File tree

4 files changed

+40
-1
lines changed

4 files changed

+40
-1
lines changed

src/array_api_typing/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"HasArrayNamespace",
66
"HasDType",
77
"HasDevice",
8+
"HasMatrixTranspose",
89
"__version__",
910
"__version_tuple__",
1011
)
@@ -14,5 +15,6 @@
1415
HasArrayNamespace,
1516
HasDevice,
1617
HasDType,
18+
HasMatrixTranspose,
1719
)
1820
from ._version import version as __version__, version_tuple as __version_tuple__

src/array_api_typing/_array.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
"HasArrayNamespace",
44
"HasDType",
55
"HasDevice",
6+
"HasMatrixTranspose",
67
)
78

89
from types import ModuleType
9-
from typing import Literal, Protocol
10+
from typing import Literal, Protocol, Self
1011
from typing_extensions import TypeVar
1112

1213
NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType)
@@ -79,6 +80,27 @@ def device(self) -> DeviceT_co:
7980
...
8081

8182

83+
class HasMatrixTranspose(Protocol):
84+
"""Protocol for array classes that have a matrix transpose attribute."""
85+
86+
@property
87+
def mT(self) -> Self: # noqa: N802
88+
"""Transpose of a matrix (or a stack of matrices).
89+
90+
If an array instance has fewer than two dimensions, an error should be
91+
raised.
92+
93+
Returns:
94+
Self: array whose last two dimensions (axes) are permuted in reverse
95+
order relative to original array (i.e., for an array instance
96+
having shape `(..., M, N)`, the returned array must have shape
97+
`(..., N, M))`. The returned array must have the same data type
98+
as the original array.
99+
100+
"""
101+
...
102+
103+
82104
class Array(
83105
# ------ Attributes -------
84106
HasDType[DTypeT_co],

tests/integration/test_numpy1p0.pyi

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ _: xpt.HasDevice = nparr
4646
_: xpt.HasDevice = nparr_i32
4747
_: xpt.HasDevice = nparr_f32
4848

49+
# =========================================================
50+
# `xpt.HasMatrixTranspose`
51+
52+
_: xpt.HasMatrixTranspose = nparr
53+
_: xpt.HasMatrixTranspose = nparr_i32
54+
_: xpt.HasMatrixTranspose = nparr_f32
55+
4956
# =========================================================
5057
# `xpt.Array`
5158

tests/integration/test_numpy2p0.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ _: xpt.HasDevice = nparr_i32
5353
_: xpt.HasDevice = nparr_f32
5454
_: xpt.HasDevice = nparr_b
5555

56+
# =========================================================
57+
# `xpt.HasMatrixTranspose`
58+
59+
_: xpt.HasMatrixTranspose = nparr
60+
_: xpt.HasMatrixTranspose = nparr_i32
61+
_: xpt.HasMatrixTranspose = nparr_f32
62+
_: xpt.HasMatrixTranspose = nparr_b
63+
5664
# =========================================================
5765
# `xpt.Array`
5866

0 commit comments

Comments
 (0)