Skip to content

Commit 06afd75

Browse files
committed
fix arange test
1 parent 6305d7e commit 06afd75

File tree

7 files changed

+47
-14
lines changed

7 files changed

+47
-14
lines changed

array_api_compat/_dask_ci_shim.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""
22
A little CI shim for the dask backend that
33
disables the dask scheduler
4+
5+
It also lets you see the dask dashboard for debugging
6+
at http://127.0.0.1:8787/status
47
"""
58
import dask
69
dask.config.set(scheduler='synchronous')
710

811
from dask.distributed import Client
912
_client = Client()
10-
print(_client.dashboard_link)
1113

1214
from .dask import *
1315
from .dask import __array_api_version__

array_api_compat/common/_helpers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def _is_dask_array(x):
4747

4848
import dask.array
4949

50-
# TODO: Should we reject ndarray subclasses?
5150
return isinstance(x, dask.array.Array)
5251

5352
def is_array_api_obj(x):

array_api_compat/dask/_aliases.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
from __future__ import annotations
2+
13
from ..common import _aliases
4+
from ..common._helpers import _check_device
25

36
from .._internal import get_xp
47

@@ -30,13 +33,44 @@
3033
result_type,
3134
)
3235

36+
from typing import TYPE_CHECKING
37+
if TYPE_CHECKING:
38+
from typing import Optional, Union
39+
from ..common._typing import ndarray, Device, Dtype
40+
3341
import dask.array as da
3442

3543
isdtype = get_xp(np)(_aliases.isdtype)
3644
astype = _aliases.astype
3745

3846
# Common aliases
39-
arange = get_xp(da)(_aliases.arange)
47+
48+
# This arange func is modified from the common one to
49+
# not pass stop/step as keyword arguments, which will cause
50+
# an error with dask
51+
def dask_arange(
52+
start: Union[int, float],
53+
/,
54+
stop: Optional[Union[int, float]] = None,
55+
step: Union[int, float] = 1,
56+
*,
57+
xp,
58+
dtype: Optional[Dtype] = None,
59+
device: Optional[Device] = None,
60+
**kwargs
61+
) -> ndarray:
62+
_check_device(xp, device)
63+
args = [start]
64+
if stop is not None:
65+
args.append(stop)
66+
else:
67+
# stop is None, so start is actually stop
68+
# prepend the default value for start which is 0
69+
args.insert(0, 0)
70+
args.append(step)
71+
return xp.arange(*args, dtype=dtype, **kwargs)
72+
73+
arange = get_xp(da)(dask_arange)
4074
eye = get_xp(da)(_aliases.eye)
4175

4276
from functools import partial

dask-skips.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,2 @@
11
# FFT isn't conformant
22
array_api_tests/test_fft.py
3-
4-
# Errors with dask, also makes Dask go OOM
5-
array_api_tests/test_creation_functions.py::test_arange

dask-xfails.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ array_api_tests/test_creation_functions.py::test_eye
2525
# finfo(float32).eps returns float32 but should return float
2626
array_api_tests/test_data_type_functions.py::test_finfo[float32]
2727

28-
# shape mismatch
28+
# out[-1]=dask.aray<getitem ...> but should be some floating number
29+
# (I think the test is not forcing the op to be computed?)
2930
array_api_tests/test_creation_functions.py::test_linspace
3031

3132
# out=-0, but should be +0

tests/test_vendoring.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,6 @@ def test_vendoring_torch():
1818
from vendor_test import uses_torch
1919
uses_torch._test_torch()
2020

21-
def test_vendoring_torch():
22-
from vendor_test import uses_torch
23-
uses_torch._test_torch()
21+
def test_vendoring_dask():
22+
from vendor_test import uses_dask
23+
uses_dask._test_dask()

vendor_test/uses_dask.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
import dask.array as da
66
import numpy as np
77

8-
def _test_numpy():
8+
def _test_dask():
99
a = dask_compat.asarray([1., 2., 3.])
1010
b = dask_compat.arange(3, dtype=dask_compat.float32)
1111

1212
# np.pow does not exist. Update this to use something else if it is added
1313
res = dask_compat.pow(a, b)
1414
assert res.dtype == dask_compat.float64 == np.float64
15-
assert isinstance(a, da.array)
16-
assert isinstance(b, da.array)
17-
assert isinstance(res, da.array)
15+
assert isinstance(a, da.Array)
16+
assert isinstance(b, da.Array)
17+
assert isinstance(res, da.Array)
1818

1919
np.testing.assert_allclose(res, [1., 2., 9.])

0 commit comments

Comments
 (0)