Skip to content

Commit c1391c0

Browse files
authored
Expose the linalg namespace and include in status page (#581)
* Add matmul, matrix_transpose, tensordot, vecdot to linalg namespace * Move outer to linalg namespace * Remove flip from list of unimplemented functions since it was added in #114 * Remove unstack from list of unimplemented functions since it was added in #575 * Add link to cumulative_sum PR * Add linalg table to status page
1 parent d5b40b3 commit c1391c0

File tree

7 files changed

+50
-14
lines changed

7 files changed

+50
-14
lines changed

api_status.md

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## Array API Coverage Implementation Status
22

3-
Cubed supports version [2022.12](https://data-apis.org/array-api/2022.12/index.html) of the Python array API standard, with a few exceptions noted below. The [linear algebra extensions](https://data-apis.org/array-api/2022.12/extensions/linear_algebra_functions.html) and [Fourier transform functions](https://data-apis.org/array-api/2022.12/extensions/fourier_transform_functions.html) are *not* supported.
3+
Cubed supports version [2022.12](https://data-apis.org/array-api/2022.12/index.html) of the Python array API standard, with a few exceptions noted below. The [Fourier transform functions](https://data-apis.org/array-api/2022.12/extensions/fourier_transform_functions.html) are *not* supported.
44

55
Support for version [2023.12](https://data-apis.org/array-api/2023.12/index.html) is tracked in Cubed issue [#438](https://github.com/cubed-dev/cubed/issues/438).
66

@@ -67,7 +67,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
6767
| | `squeeze` | :white_check_mark: | | |
6868
| | `stack` | :white_check_mark: | | |
6969
| | `tile` | :x: | 2023.12 | |
70-
| | `unstack` | :x: | 2023.12 | |
70+
| | `unstack` | :white_check_mark: | 2023.12 | |
7171
| Searching Functions | `argmax` | :white_check_mark: | | |
7272
| | `argmin` | :white_check_mark: | | |
7373
| | `nonzero` | :x: | | Shape is data dependent |
@@ -79,7 +79,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
7979
| | `unique_values` | :x: | | Shape is data dependent |
8080
| Sorting Functions | `argsort` | :x: | | Not in Dask |
8181
| | `sort` | :x: | | Not in Dask |
82-
| Statistical Functions | `cumulative_sum` | :x: | 2023.12 | |
82+
| Statistical Functions | `cumulative_sum` | :x: | 2023.12 | WIP [#531](https://github.com/cubed-dev/cubed/pull/531) |
8383
| | `max` | :white_check_mark: | | |
8484
| | `mean` | :white_check_mark: | | |
8585
| | `min` | :white_check_mark: | | |
@@ -89,3 +89,33 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
8989
| | `var` | :x: | | Like `mean`, [#29](https://github.com/cubed-dev/cubed/issues/29) |
9090
| Utility Functions | `all` | :white_check_mark: | | |
9191
| | `any` | :white_check_mark: | | |
92+
93+
### Linear Algebra Extension
94+
95+
A few of the [linear algebra extension](https://data-apis.org/array-api/2022.12/extensions/linear_algebra_functions.html) functions are supported, as indicated in this table.
96+
97+
| Category | Object/Function | Implemented | Version | Notes |
98+
| ------------------------ | ------------------- | ------------------ | ---------- | ---------------------------- |
99+
| Linear Algebra Functions | `cholesky` | :x: | | |
100+
| | `cross` | :x: | | |
101+
| | `det` | :x: | | |
102+
| | `diagonal` | :x: | | |
103+
| | `eigh` | :x: | | |
104+
| | `eigvalsh` | :x: | | |
105+
| | `inv` | :x: | | |
106+
| | `matmul` | :white_check_mark: | | |
107+
| | `matrix_norm` | :x: | | |
108+
| | `matrix_power` | :x: | | |
109+
| | `matrix_rank` | :x: | | |
110+
| | `matrix_transpose` | :white_check_mark: | | |
111+
| | `outer` | :white_check_mark: | | |
112+
| | `pinv` | :x: | | |
113+
| | `qr` | :white_check_mark: | | |
114+
| | `slogdet` | :x: | | |
115+
| | `solve` | :x: | | |
116+
| | `svd` | :x: | | |
117+
| | `svdvals` | :x: | | |
118+
| | `tensordot` | :white_check_mark: | | |
119+
| | `trace` | :x: | | |
120+
| | `vecdot` | :white_check_mark: | | |
121+
| | `vectornorm` | :x: | | |

cubed/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,12 +267,11 @@
267267
from .array_api.linear_algebra_functions import (
268268
matmul,
269269
matrix_transpose,
270-
outer,
271270
tensordot,
272271
vecdot,
273272
)
274273

275-
__all__ += ["matmul", "matrix_transpose", "outer", "tensordot", "vecdot"]
274+
__all__ += ["matmul", "matrix_transpose", "tensordot", "vecdot"]
276275

277276
from .array_api.manipulation_functions import (
278277
broadcast_arrays,

cubed/array_api/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,9 @@
212212

213213
__all__ += ["take"]
214214

215-
from .linear_algebra_functions import matmul, matrix_transpose, outer, tensordot, vecdot
215+
from .linear_algebra_functions import matmul, matrix_transpose, tensordot, vecdot
216216

217-
__all__ += ["matmul", "matrix_transpose", "outer", "tensordot", "vecdot"]
217+
__all__ += ["matmul", "matrix_transpose", "tensordot", "vecdot"]
218218

219219
from .manipulation_functions import (
220220
broadcast_arrays,

cubed/array_api/linalg.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,23 @@
11
from typing import NamedTuple
22

33
from cubed.array_api.array_object import Array
4+
5+
# These functions are in both the main and linalg namespaces
6+
from cubed.array_api.linear_algebra_functions import ( # noqa: F401
7+
matmul,
8+
matrix_transpose,
9+
tensordot,
10+
vecdot,
11+
)
412
from cubed.backend_array_api import namespace as nxp
5-
from cubed.core.ops import general_blockwise, map_direct, merge_chunks
13+
from cubed.core.ops import blockwise, general_blockwise, map_direct, merge_chunks
614
from cubed.utils import array_memory, get_item
715

816

17+
def outer(x1, x2, /):
18+
return blockwise(nxp.linalg.outer, "ij", x1, "i", x2, "j", dtype=x1.dtype)
19+
20+
921
class QRResult(NamedTuple):
1022
Q: Array
1123
R: Array

cubed/array_api/linear_algebra_functions.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,6 @@ def matrix_transpose(x, /):
9595
return permute_dims(x, axes)
9696

9797

98-
def outer(x1, x2, /):
99-
return blockwise(nxp.linalg.outer, "ij", x1, "i", x2, "j", dtype=x1.dtype)
100-
101-
10298
def tensordot(x1, x2, /, *, axes=2, use_new_impl=True, split_every=None):
10399
from cubed.array_api.statistical_functions import sum
104100

cubed/tests/test_array_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def test_matmul_modal(modal_executor):
417417
def test_outer(spec, executor):
418418
a = xp.asarray([0, 1, 2], chunks=2, spec=spec)
419419
b = xp.asarray([10, 50, 100], chunks=2, spec=spec)
420-
c = xp.outer(a, b)
420+
c = xp.linalg.outer(a, b)
421421
assert_array_equal(c.compute(executor=executor), np.outer([0, 1, 2], [10, 50, 100]))
422422

423423

docs/array-api.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ The following parts of the standard are not implemented:
1515
| Array object | In-place Ops |
1616
| Creation Functions | `from_dlpack` |
1717
| Indexing | Boolean array |
18-
| Manipulation Functions | `flip` |
1918
| Searching Functions | `nonzero` |
2019
| Set Functions | `unique_all` |
2120
| | `unique_counts` |

0 commit comments

Comments
 (0)