Skip to content

Commit 1fa31c6

Browse files
committed
ENH: array types: add dask.array support
1 parent 20ece11 commit 1fa31c6

File tree

10 files changed

+64
-26
lines changed

10 files changed

+64
-26
lines changed

.github/workflows/array_api.yml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ jobs:
4141
name: Get commit message
4242
uses: ./.github/workflows/commit_message.yml
4343

44-
pytorch_cpu:
45-
name: Linux PyTorch/JAX/xp-strict CPU
44+
xp_cpu:
45+
name: Linux PyTorch/JAX/Dask/xp-strict CPU
4646
needs: get_commit_message
4747
if: >
4848
needs.get_commit_message.outputs.message == 1
@@ -84,6 +84,10 @@ jobs:
8484
run: |
8585
python -m pip install "jax[cpu]"
8686
87+
- name: Install Dask
88+
run: |
89+
python -m pip install git+https://github.com/dask/dask.git
90+
8791
- name: Prepare compiler cache
8892
id: prep-ccache
8993
shell: bash

dev.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,8 @@ class Test(Task):
710710
multiple=True,
711711
help=(
712712
"Array API backend "
713-
"('all', 'numpy', 'torch', 'cupy', 'array_api_strict', 'jax.numpy')."
713+
"('all', 'numpy', 'torch', 'cupy', 'array_api_strict',"
714+
" 'jax.numpy', 'dask.array')."
714715
)
715716
)
716717
# Argument can't have `help=`; used to consume all of `-- arg1 arg2 arg3`

scipy/_lib/_array_api.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
is_cupy_namespace as is_cupy,
2727
is_torch_namespace as is_torch,
2828
is_jax_namespace as is_jax,
29+
is_dask_namespace as is_dask,
2930
is_array_api_strict_namespace as is_array_api_strict
3031
)
3132

@@ -250,6 +251,9 @@ def _strict_check(actual, desired, xp, *,
250251
assert actual.dtype == desired.dtype, _msg
251252

252253
if check_shape:
254+
if is_dask(xp):
255+
actual.compute_chunk_sizes()
256+
desired.compute_chunk_sizes()
253257
_msg = f"Shapes do not match.\nActual: {actual.shape}\nDesired: {desired.shape}"
254258
assert actual.shape == desired.shape, _msg
255259

scipy/_lib/tests/test_array_api.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from scipy.conftest import array_api_compatible
55
from scipy._lib._array_api import (
66
_GLOBAL_CONFIG, array_namespace, _asarray, xp_copy, xp_assert_equal, is_numpy,
7-
xp_create_diagonal
7+
xp_create_diagonal, is_dask,
88
)
99
from scipy._lib._array_api_no_0d import xp_assert_equal as xp_assert_equal_no_0d
1010
import scipy._lib.array_api_compat.numpy as np_compat
@@ -69,6 +69,10 @@ def test_copy(self, xp):
6969
x[1] = 11
7070
x[2] = 12
7171

72+
if is_dask(xp):
73+
x.compute()
74+
y.compute()
75+
7276
assert x[0] != y[0]
7377
assert x[1] != y[1]
7478
assert x[2] != y[2]

scipy/conftest.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from scipy._lib._fpumode import get_fpu_mode
1414
from scipy._lib._testutils import FPUModeChangeWarning
15-
from scipy._lib._array_api import SCIPY_ARRAY_API, SCIPY_DEVICE
15+
from scipy._lib._array_api import SCIPY_ARRAY_API, SCIPY_DEVICE, xp_device
1616
from scipy._lib import _pep440
1717

1818
try:
@@ -153,6 +153,12 @@ def check_fpu_mode(request):
153153
except ImportError:
154154
pass
155155

156+
try:
157+
import dask.array # type: ignore[import-not-found]
158+
xp_available_backends.update({'dask.array': dask.array})
159+
except ImportError:
160+
pass
161+
156162
# by default, use all available backends
157163
if SCIPY_ARRAY_API.lower() not in ("1", "true"):
158164
SCIPY_ARRAY_API_ = json.loads(SCIPY_ARRAY_API)
@@ -318,6 +324,9 @@ def skip_or_xfail_xp_backends(xp, backends, kwargs, skip_or_xfail='skip'):
318324
for d in xp.empty(0).devices():
319325
if 'cpu' not in d.device_kind:
320326
skip_or_xfail(reason=reason)
327+
elif xp.__name__ == 'dask.array':
328+
if xp_device(xp.empty(0)) != 'cpu':
329+
skip_or_xfail(reason=reason)
321330

322331
if backends is not None:
323332
for i, backend in enumerate(backends):

scipy/ndimage/tests/test_filters.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def test_correlate01(self, xp):
191191

192192
@xfail_xp_backends('cupy', reason="Differs by a factor of two?")
193193
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
194+
@skip_xp_backends("dask.array", reason="output array is read-only.")
194195
def test_correlate01_overlap(self, xp):
195196
array = xp.reshape(xp.arange(256), (16, 16))
196197
weights = xp.asarray([2])
@@ -541,6 +542,7 @@ def test_correlate22(self, dtype_array, dtype_output, xp):
541542
assert_array_almost_equal(output, expected)
542543

543544
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
545+
@skip_xp_backends("dask.array", reason="output array is read-only.")
544546
@pytest.mark.parametrize('dtype_array', types)
545547
@pytest.mark.parametrize('dtype_output', types)
546548
def test_correlate23(self, dtype_array, dtype_output, xp):
@@ -560,6 +562,7 @@ def test_correlate23(self, dtype_array, dtype_output, xp):
560562
assert_array_almost_equal(output, expected)
561563

562564
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
565+
@skip_xp_backends("dask.array", reason="output array is read-only.")
563566
@pytest.mark.parametrize('dtype_array', types)
564567
@pytest.mark.parametrize('dtype_output', types)
565568
def test_correlate24(self, dtype_array, dtype_output, xp):
@@ -580,6 +583,7 @@ def test_correlate24(self, dtype_array, dtype_output, xp):
580583
assert_array_almost_equal(output, tcov)
581584

582585
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
586+
@skip_xp_backends("dask.array", reason="output array is read-only.")
583587
@pytest.mark.parametrize('dtype_array', types)
584588
@pytest.mark.parametrize('dtype_output', types)
585589
def test_correlate25(self, dtype_array, dtype_output, xp):
@@ -875,6 +879,7 @@ def test_gauss06(self, xp):
875879
assert_array_almost_equal(output1, output2)
876880

877881
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
882+
@skip_xp_backends("dask.array", reason="output array is read-only.")
878883
def test_gauss_memory_overlap(self, xp):
879884
input = xp.arange(100 * 100, dtype=xp.float32)
880885
input = xp.reshape(input, (100, 100))
@@ -1234,6 +1239,7 @@ def test_prewitt01(self, dtype, xp):
12341239
assert_array_almost_equal(t, output)
12351240

12361241
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
1242+
@skip_xp_backends("dask.array", reason="output array is read-only.")
12371243
@pytest.mark.parametrize('dtype', types + complex_types)
12381244
def test_prewitt02(self, dtype, xp):
12391245
if is_torch(xp) and dtype in ("uint16", "uint32", "uint64"):
@@ -1296,6 +1302,7 @@ def test_sobel01(self, dtype, xp):
12961302
assert_array_almost_equal(t, output)
12971303

12981304
@skip_xp_backends("jax.numpy", reason="output array is read-only.",)
1305+
@skip_xp_backends("dask.array", reason="output array is read-only.")
12991306
@pytest.mark.parametrize('dtype', types + complex_types)
13001307
def test_sobel02(self, dtype, xp):
13011308
if is_torch(xp) and dtype in ("uint16", "uint32", "uint64"):
@@ -1356,6 +1363,7 @@ def test_laplace01(self, dtype, xp):
13561363
assert_array_almost_equal(tmp1 + tmp2, output)
13571364

13581365
@skip_xp_backends("jax.numpy", reason="output array is read-only",)
1366+
@skip_xp_backends("dask.array", reason="output array is read-only.")
13591367
@pytest.mark.parametrize('dtype',
13601368
["int32", "float32", "float64",
13611369
"complex64", "complex128"])
@@ -1386,6 +1394,7 @@ def test_gaussian_laplace01(self, dtype, xp):
13861394
assert_array_almost_equal(tmp1 + tmp2, output)
13871395

13881396
@skip_xp_backends("jax.numpy", reason="output array is read-only")
1397+
@skip_xp_backends("dask.array", reason="output array is read-only.")
13891398
@pytest.mark.parametrize('dtype',
13901399
["int32", "float32", "float64",
13911400
"complex64", "complex128"])
@@ -1402,6 +1411,7 @@ def test_gaussian_laplace02(self, dtype, xp):
14021411
assert_array_almost_equal(tmp1 + tmp2, output)
14031412

14041413
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
1414+
@skip_xp_backends("dask.array", reason="output array is read-only.")
14051415
@pytest.mark.parametrize('dtype', types + complex_types)
14061416
def test_generic_laplace01(self, dtype, xp):
14071417
if is_torch(xp) and dtype in ("uint16", "uint32", "uint64"):
@@ -1427,6 +1437,7 @@ def derivative2(input, axis, output, mode, cval, a, b):
14271437
assert_array_almost_equal(tmp, output)
14281438

14291439
@skip_xp_backends("jax.numpy", reason="output array is read-only")
1440+
@skip_xp_backends("dask.array", reason="output array is read-only.")
14301441
@pytest.mark.parametrize('dtype',
14311442
["int32", "float32", "float64",
14321443
"complex64", "complex128"])
@@ -1448,6 +1459,7 @@ def test_gaussian_gradient_magnitude01(self, dtype, xp):
14481459
xp_assert_close(output, expected, rtol=1e-6, atol=1e-6)
14491460

14501461
@skip_xp_backends("jax.numpy", reason="output array is read-only")
1462+
@skip_xp_backends("dask.array", reason="output array is read-only.")
14511463
@pytest.mark.parametrize('dtype',
14521464
["int32", "float32", "float64",
14531465
"complex64", "complex128"])
@@ -2659,6 +2671,7 @@ def test_gaussian_radius_invalid(xp):
26592671

26602672

26612673
@skip_xp_backends("jax.numpy", reason="output array is read-only")
2674+
@skip_xp_backends("dask.array", reason="output array is read-only.")
26622675
class TestThreading:
26632676
def check_func_thread(self, n, fun, args, out):
26642677
from threading import Thread

scipy/ndimage/tests/test_morphology.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2265,6 +2265,7 @@ def test_grey_erosion01(self, xp):
22652265
[5, 5, 3, 3, 1]]))
22662266

22672267
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
2268+
@skip_xp_backends("dask.array", reason="output array is read-only.")
22682269
@xfail_xp_backends("cupy", reason="https://github.com/cupy/cupy/issues/8398")
22692270
def test_grey_erosion01_overlap(self, xp):
22702271

@@ -2460,6 +2461,7 @@ def test_morphological_laplace02(self, xp):
24602461
assert_array_almost_equal(expected, output)
24612462

24622463
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
2464+
@skip_xp_backends("dask.array", reason="output array is read-only.")
24632465
def test_white_tophat01(self, xp):
24642466
array = xp.asarray([[3, 2, 5, 1, 4],
24652467
[7, 6, 9, 3, 5],
@@ -2513,6 +2515,7 @@ def test_white_tophat03(self, xp):
25132515
xp_assert_equal(expected, output)
25142516

25152517
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
2518+
@skip_xp_backends("dask.array", reason="output array is read-only.")
25162519
def test_white_tophat04(self, xp):
25172520
array = np.eye(5, dtype=bool)
25182521
structure = np.ones((3, 3), dtype=bool)
@@ -2525,6 +2528,7 @@ def test_white_tophat04(self, xp):
25252528
ndimage.white_tophat(array, structure=structure, output=output)
25262529

25272530
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
2531+
@skip_xp_backends("dask.array", reason="output array is read-only.")
25282532
def test_black_tophat01(self, xp):
25292533
array = xp.asarray([[3, 2, 5, 1, 4],
25302534
[7, 6, 9, 3, 5],
@@ -2578,6 +2582,7 @@ def test_black_tophat03(self, xp):
25782582
xp_assert_equal(expected, output)
25792583

25802584
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
2585+
@skip_xp_backends("dask.array", reason="output array is read-only.")
25812586
def test_black_tophat04(self, xp):
25822587
array = xp.asarray(np.eye(5, dtype=bool))
25832588
structure = xp.asarray(np.ones((3, 3), dtype=bool))

scipy/special/tests/test_support_alternative_backends.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from scipy.conftest import array_api_compatible
66
from scipy import special
77
from scipy._lib._array_api_no_0d import xp_assert_close
8-
from scipy._lib._array_api import is_jax, is_torch, SCIPY_DEVICE
8+
from scipy._lib._array_api import is_jax, is_torch, SCIPY_DEVICE, is_dask
99
from scipy._lib.array_api_compat import numpy as np
1010

1111
try:
@@ -64,6 +64,9 @@ def test_support_alternative_backends(xp, f_name_n_args, dtype, shapes):
6464
):
6565
pytest.skip(f"`{f_name}` does not have an array-agnostic implementation "
6666
f"and cannot delegate to PyTorch.")
67+
68+
if is_dask(xp) and f_name == 'rel_entr':
69+
pytest.skip("boolean index assignment")
6770

6871
shapes = shapes[:n_args]
6972
f = getattr(special, f_name)

scipy/stats/tests/test_entropy.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111
from scipy._lib._array_api_no_0d import (xp_assert_close, xp_assert_equal,
1212
xp_assert_less)
1313

14+
@pytest.mark.skip_xp_backends("dask.array", reason="boolean index assignment")
15+
@pytest.mark.usefixtures("skip_xp_backends")
16+
@array_api_compatible
1417
class TestEntropy:
15-
@array_api_compatible
18+
1619
def test_entropy_positive(self, xp):
1720
# See ticket #497
1821
pk = xp.asarray([0.5, 0.2, 0.3])
@@ -22,7 +25,6 @@ def test_entropy_positive(self, xp):
2225
xp_assert_equal(eself, xp.asarray(0.))
2326
xp_assert_less(-edouble, xp.asarray(0.))
2427

25-
@array_api_compatible
2628
def test_entropy_base(self, xp):
2729
pk = xp.ones(16)
2830
S = stats.entropy(pk, base=2.)
@@ -34,21 +36,18 @@ def test_entropy_base(self, xp):
3436
S2 = stats.entropy(pk, qk, base=2.)
3537
xp_assert_less(xp.abs(S/S2 - math.log(2.)), xp.asarray(1.e-5))
3638

37-
@array_api_compatible
3839
def test_entropy_zero(self, xp):
3940
# Test for PR-479
4041
x = xp.asarray([0., 1., 2.])
4142
xp_assert_close(stats.entropy(x),
4243
xp.asarray(0.63651416829481278))
4344

44-
@array_api_compatible
4545
def test_entropy_2d(self, xp):
4646
pk = xp.asarray([[0.1, 0.2], [0.6, 0.3], [0.3, 0.5]])
4747
qk = xp.asarray([[0.2, 0.1], [0.3, 0.6], [0.5, 0.3]])
4848
xp_assert_close(stats.entropy(pk, qk),
4949
xp.asarray([0.1933259, 0.18609809]))
5050

51-
@array_api_compatible
5251
def test_entropy_2d_zero(self, xp):
5352
pk = xp.asarray([[0.1, 0.2], [0.6, 0.3], [0.3, 0.5]])
5453
qk = xp.asarray([[0.0, 0.1], [0.3, 0.6], [0.5, 0.3]])
@@ -59,54 +58,46 @@ def test_entropy_2d_zero(self, xp):
5958
xp_assert_close(stats.entropy(pk, qk),
6059
xp.asarray([0.17403988, 0.18609809]))
6160

62-
@array_api_compatible
6361
def test_entropy_base_2d_nondefault_axis(self, xp):
6462
pk = xp.asarray([[0.1, 0.2], [0.6, 0.3], [0.3, 0.5]])
6563
xp_assert_close(stats.entropy(pk, axis=1),
6664
xp.asarray([0.63651417, 0.63651417, 0.66156324]))
6765

68-
@array_api_compatible
6966
def test_entropy_2d_nondefault_axis(self, xp):
7067
pk = xp.asarray([[0.1, 0.2], [0.6, 0.3], [0.3, 0.5]])
7168
qk = xp.asarray([[0.2, 0.1], [0.3, 0.6], [0.5, 0.3]])
7269
xp_assert_close(stats.entropy(pk, qk, axis=1),
7370
xp.asarray([0.23104906, 0.23104906, 0.12770641]))
7471

75-
@array_api_compatible
7672
def test_entropy_raises_value_error(self, xp):
7773
pk = xp.asarray([[0.1, 0.2], [0.6, 0.3], [0.3, 0.5]])
7874
qk = xp.asarray([[0.1, 0.2], [0.6, 0.3]])
7975
message = "Array shapes are incompatible for broadcasting."
8076
with pytest.raises(ValueError, match=message):
8177
stats.entropy(pk, qk)
8278

83-
@array_api_compatible
8479
def test_base_entropy_with_axis_0_is_equal_to_default(self, xp):
8580
pk = xp.asarray([[0.1, 0.2], [0.6, 0.3], [0.3, 0.5]])
8681
xp_assert_close(stats.entropy(pk, axis=0),
8782
stats.entropy(pk))
8883

89-
@array_api_compatible
9084
def test_entropy_with_axis_0_is_equal_to_default(self, xp):
9185
pk = xp.asarray([[0.1, 0.2], [0.6, 0.3], [0.3, 0.5]])
9286
qk = xp.asarray([[0.2, 0.1], [0.3, 0.6], [0.5, 0.3]])
9387
xp_assert_close(stats.entropy(pk, qk, axis=0),
9488
stats.entropy(pk, qk))
9589

96-
@array_api_compatible
9790
def test_base_entropy_transposed(self, xp):
9891
pk = xp.asarray([[0.1, 0.2], [0.6, 0.3], [0.3, 0.5]])
9992
xp_assert_close(stats.entropy(pk.T),
10093
stats.entropy(pk, axis=1))
10194

102-
@array_api_compatible
10395
def test_entropy_transposed(self, xp):
10496
pk = xp.asarray([[0.1, 0.2], [0.6, 0.3], [0.3, 0.5]])
10597
qk = xp.asarray([[0.2, 0.1], [0.3, 0.6], [0.5, 0.3]])
10698
xp_assert_close(stats.entropy(pk.T, qk.T),
10799
stats.entropy(pk, qk, axis=1))
108100

109-
@array_api_compatible
110101
def test_entropy_broadcasting(self, xp):
111102
rng = np.random.default_rng(74187315492831452)
112103
x = xp.asarray(rng.random(3))
@@ -115,22 +106,21 @@ def test_entropy_broadcasting(self, xp):
115106
xp_assert_equal(res[0], stats.entropy(x, y[0, ...]))
116107
xp_assert_equal(res[1], stats.entropy(x, y[1, ...]))
117108

118-
@array_api_compatible
119109
def test_entropy_shape_mismatch(self, xp):
120110
x = xp.ones((10, 1, 12))
121111
y = xp.ones((11, 2))
122112
message = "Array shapes are incompatible for broadcasting."
123113
with pytest.raises(ValueError, match=message):
124114
stats.entropy(x, y)
125115

126-
@array_api_compatible
127116
def test_input_validation(self, xp):
128117
x = xp.ones(10)
129118
message = "`base` must be a positive number."
130119
with pytest.raises(ValueError, match=message):
131120
stats.entropy(x, base=-2)
132121

133122

123+
@pytest.mark.skip_xp_backends("dask.array", reason="No sorting in Dask")
134124
@array_api_compatible
135125
@pytest.mark.usefixtures("skip_xp_backends")
136126
class TestDifferentialEntropy:

0 commit comments

Comments
 (0)