|
4 | 4 | )
|
5 | 5 |
|
6 | 6 | from types import ModuleType
|
7 |
| -from typing import Literal, Protocol |
| 7 | +from typing import Literal, Protocol, Self |
8 | 8 | from typing_extensions import TypeVar
|
9 | 9 |
|
10 | 10 | NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType)
|
@@ -76,10 +76,32 @@ def device(self) -> object: # TODO: more specific type
|
76 | 76 | ...
|
77 | 77 |
|
78 | 78 |
|
| 79 | +class HasMatrixTranspose(Protocol): |
| 80 | + """Protocol for array classes that have a matrix transpose attribute.""" |
| 81 | + |
| 82 | + @property |
| 83 | + def mT(self) -> Self: # noqa: N802 |
| 84 | + """Transpose of a matrix (or a stack of matrices). |
| 85 | +
|
| 86 | + If an array instance has fewer than two dimensions, an error should be |
| 87 | + raised. |
| 88 | +
|
| 89 | + Returns: |
| 90 | + Self: array whose last two dimensions (axes) are permuted in reverse |
| 91 | + order relative to original array (i.e., for an array instance |
| 92 | + having shape `(..., M, N)`, the returned array must have shape |
| 93 | + `(..., N, M))`. The returned array must have the same data type |
| 94 | + as the original array. |
| 95 | +
|
| 96 | + """ |
| 97 | + ... |
| 98 | + |
| 99 | + |
79 | 100 | class Array(
|
80 | 101 | # ------ Attributes -------
|
81 | 102 | HasDType[DTypeT_co],
|
82 | 103 | HasDevice,
|
| 104 | + HasMatrixTranspose, |
83 | 105 | # ------- Methods ---------
|
84 | 106 | HasArrayNamespace[NamespaceT_co],
|
85 | 107 | # -------------------------
|
|
0 commit comments