Skip to content

Commit 5490016

Browse files
committed
ndonnx support
1 parent ddb03e6 commit 5490016

File tree

3 files changed

+21
-14
lines changed

3 files changed

+21
-14
lines changed

array_api_compat/common/_helpers.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -627,17 +627,13 @@ def device(x: Array, /) -> Device:
627627
to_device : Move array data to a different device.
628628
629629
"""
630-
if is_numpy_array(x):
630+
if is_numpy_array(x) or is_ndonnx_array(x):
631631
return "cpu"
632632
elif is_dask_array(x):
633633
# Peek at the metadata of the jax array to determine type
634-
try:
635-
import numpy as np
636-
if isinstance(x._meta, np.ndarray):
637-
# Must be on CPU since backed by numpy
638-
return "cpu"
639-
except ImportError:
640-
pass
634+
if is_numpy_array(x._meta):
635+
# Must be on CPU since backed by numpy
636+
return "cpu"
641637
return _DASK_DEVICE
642638
elif is_jax_array(x):
643639
# JAX has .device() as a method, but it is being deprecated so that it
@@ -758,7 +754,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
758754
device : Hardware device the array data resides on.
759755
760756
"""
761-
if is_numpy_array(x):
757+
if is_numpy_array(x) or is_ndonnx_array(x):
762758
if stream is not None:
763759
raise ValueError("The stream argument to to_device() is not supported")
764760
if device == 'cpu':
@@ -780,7 +776,6 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
780776
if not hasattr(x, "__array_namespace__"):
781777
# In JAX v0.4.31 and older, this import adds to_device method to x.
782778
import jax.experimental.array_api # noqa: F401
783-
return x.to_device(device, stream=stream)
784779
elif is_pydata_sparse_array(x) and device == _device(x):
785780
# Perform trivial check to return the same array if
786781
# device is same instead of err-ing.

tests/test_array_namespace.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@
1414

1515
@pytest.mark.parametrize("use_compat", [True, False, None])
1616
@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12", "2023.12"])
17-
@pytest.mark.parametrize("library", all_libraries + ['array_api_strict'])
17+
@pytest.mark.parametrize("library", all_libraries)
1818
def test_array_namespace(library, api_version, use_compat):
1919
xp = import_(library)
2020

2121
array = xp.asarray([1.0, 2.0, 3.0])
22-
if use_compat is True and library in {'array_api_strict', 'jax.numpy', 'sparse'}:
22+
if use_compat and library in {'array_api_strict', 'jax.numpy', 'ndonnx', 'sparse'}:
2323
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
2424
return
25+
if library == "ndonnx" and api_version in ("2021.12", "2022.12"):
26+
pytest.skip("Unsupported API version")
27+
2528
namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)
2629

2730
if use_compat is False or use_compat is None and library not in wrapped_libraries:

tests/test_common.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,12 +144,21 @@ def test_asarray_cross_library(source_library, target_library, request):
144144

145145
# TODO: remove xfail once
146146
# https://github.com/dask/dask/issues/8260 is resolved
147-
request.node.add_marker(pytest.mark.xfail(reason="Bug in dask raising error on conversion"))
148-
if source_library == "cupy" and target_library != "cupy":
147+
request.node.add_marker(
148+
pytest.mark.xfail(reason="Bug in dask raising error on conversion")
149+
)
150+
elif source_library == "ndonnx" and target_library not in ("numpy", "array_api_strict"):
151+
request.node.add_marker(
152+
pytest.mark.xfail(
153+
reason="The truth value of lazy Array Array(dtype=Boolean) is unknown"
154+
)
155+
)
156+
elif source_library == "cupy" and target_library != "cupy":
149157
# cupy explicitly disallows implicit conversions to CPU
150158
pytest.skip(reason="cupy does not support implicit conversion to CPU")
151159
elif source_library == "sparse" and target_library != "sparse":
152160
pytest.skip(reason="`sparse` does not allow implicit densification")
161+
153162
src_lib = import_(source_library, wrapper=True)
154163
tgt_lib = import_(target_library, wrapper=True)
155164
is_tgt_type = globals()[is_array_functions[target_library]]

0 commit comments

Comments
 (0)