Skip to content

Commit 1f9799a

Browse files
committed
finish rename and fix linalg tests
1 parent df69086 commit 1f9799a

File tree

8 files changed

+70
-64
lines changed

8 files changed

+70
-64
lines changed

.github/workflows/array-api-tests-dask.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ jobs:
66
array-api-tests-dask:
77
uses: ./.github/workflows/array-api-tests.yml
88
with:
9-
package-name: dask
9+
package-name: dask.array
1010
extra-requires: numpy
1111
pytest-extra-args: --disable-deadline --max-examples=5

array_api_compat/common/_linalg.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ def matrix_rank(x: ndarray,
7474
# dimensional arrays.
7575
if x.ndim < 2:
7676
raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
77-
S = xp.linalg.svd(x, compute_uv=False, **kwargs)
77+
S = xp.linalg.svdvals(x, **kwargs)
78+
#S = xp.linalg.svd(x, compute_uv=False, **kwargs)
7879
if rtol is None:
7980
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
8081
else:

array_api_compat/dask/array/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@
44
from ._aliases import *
55

66
__array_api_version__ = '2022.12'
7+
8+
__import__(__package__ + '.linalg')

array_api_compat/dask/array/_aliases.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
# This arange func is modified from the common one to
4949
# not pass stop/step as keyword arguments, which will cause
5050
# an error with dask
51+
52+
# TODO: delete the xp stuff, it shouldn't be necessary
5153
def dask_arange(
5254
start: Union[int, float],
5355
/,
@@ -118,3 +120,4 @@ def dask_arange(
118120
concatenate as concat,
119121
)
120122

123+
del da

array_api_compat/dask/array/linalg.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,42 @@
1-
from dask.array.linalg import *
2-
from dask.array.linalg import __all__ as linalg_all
1+
from __future__ import annotations
32

3+
from dask.array.linalg import *
44
from ...common import _linalg
55
from ..._internal import get_xp
6-
from ._aliases import (matmul, matrix_transpose, tensordot, vecdot)
6+
from dask.array import matmul, tensordot, trace, outer
7+
from ._aliases import matrix_transpose, vecdot
78

89
import dask.array as da
910

10-
cross = get_xp(da)(_linalg.cross)
11-
outer = get_xp(da)(_linalg.outer)
11+
from typing import TYPE_CHECKING
12+
if TYPE_CHECKING:
13+
from typing import Optional, Union, Tuple
14+
from ...common._typing import ndarray, Device, Dtype
15+
16+
#cross = get_xp(da)(_linalg.cross)
17+
#outer = get_xp(da)(_linalg.outer)
1218
EighResult = _linalg.EighResult
1319
QRResult = _linalg.QRResult
1420
SlogdetResult = _linalg.SlogdetResult
1521
SVDResult = _linalg.SVDResult
16-
eigh = get_xp(da)(_linalg.eigh)
1722
qr = get_xp(da)(_linalg.qr)
18-
slogdet = get_xp(da)(_linalg.slogdet)
19-
svd = get_xp(da)(_linalg.svd)
23+
#svd = get_xp(da)(_linalg.svd)
2024
cholesky = get_xp(da)(_linalg.cholesky)
2125
matrix_rank = get_xp(da)(_linalg.matrix_rank)
22-
pinv = get_xp(da)(_linalg.pinv)
26+
#pinv = get_xp(da)(_linalg.pinv)
2327
matrix_norm = get_xp(da)(_linalg.matrix_norm)
24-
svdvals = get_xp(da)(_linalg.svdvals)
28+
29+
def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]:
30+
# TODO: can't avoid computing U or V for dask
31+
_, s, _ = svd(x)
32+
return s
33+
2534
vector_norm = get_xp(da)(_linalg.vector_norm)
2635
diagonal = get_xp(da)(_linalg.diagonal)
27-
trace = get_xp(da)(_linalg.trace)
2836

29-
__all__ = linalg_all + _linalg.__all__
37+
#__all__ = linalg_all + _linalg.__all__
3038

3139
del get_xp
3240
del da
33-
del linalg_all
41+
#del linalg_all
3442
del _linalg

dask-skips.txt

Lines changed: 0 additions & 2 deletions
This file was deleted.

dask.array-skips.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# FFT isn't conformant
2+
array_api_tests/test_fft.py
3+
4+
# slow and not implemented in dask
5+
array_api_tests/test_linalg.py::test_matrix_power

dask-xfails.txt renamed to dask.array-xfails.txt

Lines changed: 36 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -71,65 +71,51 @@ array_api_tests/test_set_functions.py::test_unique_inverse
7171
array_api_tests/test_set_functions.py::test_unique_values
7272

7373
# Linalg failures (signature failures/missing methods)
74-
array_api_tests/test_has_names.py::test_has_names[linalg-cross]
75-
array_api_tests/test_has_names.py::test_has_names[linalg-det]
76-
array_api_tests/test_has_names.py::test_has_names[linalg-diagonal]
77-
array_api_tests/test_has_names.py::test_has_names[linalg-eigh]
78-
array_api_tests/test_has_names.py::test_has_names[linalg-eigvalsh]
79-
array_api_tests/test_has_names.py::test_has_names[linalg-matmul]
80-
array_api_tests/test_has_names.py::test_has_names[linalg-matrix_norm]
81-
array_api_tests/test_has_names.py::test_has_names[linalg-matrix_power]
82-
array_api_tests/test_has_names.py::test_has_names[linalg-matrix_rank]
83-
array_api_tests/test_has_names.py::test_has_names[linalg-matrix_transpose]
84-
array_api_tests/test_has_names.py::test_has_names[linalg-outer]
85-
array_api_tests/test_has_names.py::test_has_names[linalg-pinv]
86-
array_api_tests/test_has_names.py::test_has_names[linalg-slogdet]
87-
array_api_tests/test_has_names.py::test_has_names[linalg-svdvals]
88-
array_api_tests/test_has_names.py::test_has_names[linalg-tensordot]
89-
array_api_tests/test_has_names.py::test_has_names[linalg-trace]
90-
array_api_tests/test_has_names.py::test_has_names[linalg-vecdot]
91-
array_api_tests/test_has_names.py::test_has_names[linalg-vector_norm]
92-
array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack]
93-
array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__]
94-
array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__]
95-
array_api_tests/test_linalg.py::test_cross
96-
array_api_tests/test_linalg.py::test_det
97-
array_api_tests/test_linalg.py::test_diagonal
98-
array_api_tests/test_linalg.py::test_eigvalsh
99-
array_api_tests/test_linalg.py::test_matrix_norm
100-
array_api_tests/test_linalg.py::test_matrix_rank
101-
array_api_tests/test_linalg.py::test_outer
102-
array_api_tests/test_linalg.py::test_pinv
103-
array_api_tests/test_linalg.py::test_slogdet
74+
75+
76+
# fails for ndim > 2
10477
array_api_tests/test_linalg.py::test_svdvals
78+
array_api_tests/test_linalg.py::test_cholesky
79+
# dtype mismatch got uint64, but should be uint8, NPY_PROMOTION_STATE=weak doesn't help :(
10580
array_api_tests/test_linalg.py::test_tensordot
81+
# probably same reason for failing as numpy
10682
array_api_tests/test_linalg.py::test_trace
107-
array_api_tests/test_linalg.py::test_cholesky
108-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cholesky]
83+
84+
# Linalg - these don't exist in dask
10985
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cross]
11086
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.det]
111-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.diagonal]
11287
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.eigh]
11388
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.eigvalsh]
114-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matmul]
115-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_norm]
11689
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_power]
117-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_rank]
118-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_transpose]
119-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.outer]
12090
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.pinv]
121-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.qr]
12291
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.slogdet]
123-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.svd]
124-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.svdvals]
125-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.tensordot]
126-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.trace]
127-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot]
128-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vector_norm]
129-
# errors
130-
array_api_tests/test_linalg.py::test_matrix_power
92+
array_api_tests/test_linalg.py::test_cross
93+
array_api_tests/test_linalg.py::test_det
94+
array_api_tests/test_linalg.py::test_eigvalsh
95+
array_api_tests/test_linalg.py::test_pinv
96+
array_api_tests/test_linalg.py::test_slogdet
97+
array_api_tests/test_has_names.py::test_has_names[linalg-cross]
98+
array_api_tests/test_has_names.py::test_has_names[linalg-det]
99+
array_api_tests/test_has_names.py::test_has_names[linalg-eigh]
100+
array_api_tests/test_has_names.py::test_has_names[linalg-eigvalsh]
101+
array_api_tests/test_has_names.py::test_has_names[linalg-matrix_power]
102+
array_api_tests/test_has_names.py::test_has_names[linalg-pinv]
103+
array_api_tests/test_has_names.py::test_has_names[linalg-slogdet]
104+
105+
array_api_tests/test_linalg.py::test_matrix_norm
106+
array_api_tests/test_linalg.py::test_matrix_rank
107+
108+
# missing mode kw
109+
# https://github.com/dask/dask/issues/10388
131110
array_api_tests/test_linalg.py::test_qr
111+
112+
# Constructing the input arrays fails to a weird shape error...
132113
array_api_tests/test_linalg.py::test_solve
114+
115+
# missing full_matrics kw
116+
# https://github.com/dask/dask/issues/10389
117+
# also only supports 2-d inputs
118+
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.svd]
133119
array_api_tests/test_linalg.py::test_svd
134120

135121
# Missing dlpack stuff
@@ -138,6 +124,9 @@ array_api_tests/test_signatures.py::test_array_method_signature[__array_namespac
138124
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
139125
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__]
140126
array_api_tests/test_signatures.py::test_array_method_signature[to_device]
127+
array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack]
128+
array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__]
129+
array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__]
141130

142131
# Some cases unsupported by dask
143132
array_api_tests/test_manipulation_functions.py::test_roll

0 commit comments

Comments
 (0)