|
17 | 17 | from array_api_compat import ( |
18 | 18 | device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device |
19 | 19 | ) |
| 20 | +from array_api_compat.common._helpers import _DASK_DEVICE |
20 | 21 | from ._helpers import all_libraries, import_, wrapped_libraries, xfail |
21 | 22 |
|
22 | 23 |
|
@@ -189,23 +190,26 @@ class C: |
189 | 190 |
|
190 | 191 |
|
191 | 192 | @pytest.mark.parametrize("library", all_libraries) |
192 | | -def test_device(library, request): |
| 193 | +def test_device_to_device(library, request): |
193 | 194 | if library == "ndonnx": |
194 | | - xfail(request, reason="Needs ndonnx >=0.9.4") |
| 195 | + xfail(request, reason="Stub raises ValueError") |
| 196 | + if library == "sparse": |
| 197 | + xfail(request, reason="No __array_namespace_info__()") |
195 | 198 |
|
196 | 199 | xp = import_(library, wrapper=True) |
| 200 | + devices = xp.__array_namespace_info__().devices() |
197 | 201 |
|
198 | | - # We can't test much for device() and to_device() other than that |
199 | | - # x.to_device(x.device) works. |
200 | | - |
| 202 | + # Default device |
201 | 203 | x = xp.asarray([1, 2, 3]) |
202 | 204 | dev = device(x) |
203 | 205 |
|
204 | | - x2 = to_device(x, dev) |
205 | | - assert device(x2) == device(x) |
206 | | - |
207 | | - x3 = xp.asarray(x, device=dev) |
208 | | - assert device(x3) == device(x) |
| 206 | + for dev in devices: |
| 207 | + if dev is None: # JAX >=0.5.3 |
| 208 | + continue |
| 209 | + if dev is _DASK_DEVICE: # TODO this needs a better design |
| 210 | + continue |
| 211 | + y = to_device(x, dev) |
| 212 | + assert device(y) == dev |
209 | 213 |
|
210 | 214 |
|
211 | 215 | @pytest.mark.parametrize("library", wrapped_libraries) |
|
0 commit comments