Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def _apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01
) -> Array:
"""Helper of `apply_where`. On Dask, this runs on a single chunk."""

if not capabilities(xp)["boolean indexing"]:
if not capabilities(xp, device=_compat.device(cond))["boolean indexing"]:
# jax.jit does not support assignment by boolean mask
return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value)

Expand Down Expand Up @@ -716,7 +716,7 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
# 2. backend has unique_counts and it returns a None-sized array;
# e.g. Dask, ndonnx
# 3. backend does not have unique_counts; e.g. wrapped JAX
if capabilities(xp)["data-dependent shapes"]:
if capabilities(xp, device=_compat.device(x))["data-dependent shapes"]:
# xp has unique_counts; O(n) complexity
_, counts = xp.unique_counts(x)
n = _compat.size(counts)
Expand Down
6 changes: 5 additions & 1 deletion src/array_api_extra/_lib/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,11 @@ def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]:
return array.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]

if is_torch_namespace(xp):
array = to_device(array, "cpu")
if array.device.type == "meta": # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
# Can't materialize; generate dummy data instead
array = xp.zeros_like(array, device="cpu")
else:
array = to_device(array, "cpu")
if is_array_api_strict_namespace(xp):
cpu: Device = xp.Device("CPU_DEVICE")
array = to_device(array, cpu)
Expand Down
14 changes: 12 additions & 2 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
is_jax_namespace,
is_numpy_array,
is_pydata_sparse_namespace,
is_torch_namespace,
)
from ._typing import Array
from ._typing import Array, Device

if TYPE_CHECKING: # pragma: no cover
# TODO import from typing (requires Python >=3.12 and >=3.13)
Expand Down Expand Up @@ -300,7 +301,7 @@ def meta_namespace(
return array_namespace(*metas)


def capabilities(xp: ModuleType) -> dict[str, int]:
def capabilities(xp: ModuleType, *, device: Device | None = None) -> dict[str, int]:
"""
Return patched ``xp.__array_namespace_info__().capabilities()``.

Expand All @@ -311,6 +312,8 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
----------
xp : array_namespace
The standard-compatible namespace.
device : Device, optional
The device to use.

Returns
-------
Expand All @@ -326,6 +329,13 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
# Fixed in jax >=0.6.0
out = out.copy()
out["boolean indexing"] = False
if is_torch_namespace(xp):
# FIXME https://github.com/data-apis/array-api/issues/945
device = xp.get_default_device() if device is None else xp.device(device)
if cast(Any, device).type == "meta": # type: ignore[explicit-any]
out = out.copy()
out["boolean indexing"] = False
out["data-dependent shapes"] = False
return out


Expand Down
8 changes: 5 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ def device(
Where possible, return a device that is not the default one.
"""
if library == Backend.ARRAY_API_STRICT:
d = xp.Device("device1")
assert get_device(xp.empty(0)) != d
return d
return xp.Device("device1")
if library == Backend.TORCH:
return xp.device("meta")
if library == Backend.TORCH_GPU:
return xp.device("cpu")
return get_device(xp.empty(0))
6 changes: 3 additions & 3 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,9 +731,6 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool):
b = xp.asarray([1e-9, 1e-4, xp.nan], device=device)
res = isclose(a, b, equal_nan=equal_nan)
assert get_device(res) == device
xp_assert_equal(
isclose(a, b, equal_nan=equal_nan), xp.asarray([True, False, equal_nan])
)


class TestKron:
Expand Down Expand Up @@ -996,6 +993,9 @@ def test_all_python_scalars(self, assume_unique: bool):
_ = setdiff1d(0, 0, assume_unique=assume_unique)

@assume_unique
@pytest.mark.skip_xp_backend(
Backend.TORCH, reason="device='meta' does not support unknown shapes"
)
def test_device(self, xp: ModuleType, device: Device, assume_unique: bool):
x1 = xp.asarray([3, 8, 20], device=device)
x2 = xp.asarray([2, 3, 4], device=device)
Expand Down
30 changes: 25 additions & 5 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,31 @@ def test_xp(self, xp: ModuleType):
assert meta_namespace(*args, xp=xp) in (xp, np_compat)


def test_capabilities(xp: ModuleType):
expect = {"boolean indexing", "data-dependent shapes"}
if xp.__array_api_version__ >= "2024.12":
expect.add("max dimensions")
assert capabilities(xp).keys() == expect
class TestCapabilities:
def test_basic(self, xp: ModuleType):
expect = {"boolean indexing", "data-dependent shapes"}
if xp.__array_api_version__ >= "2024.12":
expect.add("max dimensions")
assert capabilities(xp).keys() == expect

def test_device(self, xp: ModuleType, library: Backend, device: Device):
expect_keys = {"boolean indexing", "data-dependent shapes"}
if xp.__array_api_version__ >= "2024.12":
expect_keys.add("max dimensions")
assert capabilities(xp, device=device).keys() == expect_keys

if library.like(Backend.TORCH):
# The output of capabilities is device-specific.

# Test that device=None gets the current default device.
expect = capabilities(xp, device=device)
with xp.device(device):
actual = capabilities(xp)
assert actual == expect

# Test that we're accepting anything that is accepted by the
# device= parameter in other functions
actual = capabilities(xp, device=device.type) # type: ignore[attr-defined] # pyright: ignore[reportUnknownArgumentType,reportAttributeAccessIssue]


class Wrapper(Generic[T]):
Expand Down
3 changes: 3 additions & 0 deletions tests/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,9 @@ def test_lazy_apply_none_shape_broadcast(xp: ModuleType):
Backend.ARRAY_API_STRICT, reason="device->host copy"
),
pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host copy"),
pytest.mark.skip_xp_backend(
Backend.TORCH, reason="materialize 'meta' device"
),
pytest.mark.skip_xp_backend(
Backend.TORCH_GPU, reason="device->host copy"
),
Expand Down
20 changes: 16 additions & 4 deletions tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,22 @@
)


def test_as_numpy_array(xp: ModuleType, device: Device):
x = xp.asarray([1, 2, 3], device=device)
y = as_numpy_array(x, xp=xp)
assert isinstance(y, np.ndarray)
class TestAsNumPyArray:
def test_basic(self, xp: ModuleType):
x = xp.asarray([1, 2, 3])
y = as_numpy_array(x, xp=xp)
xp_assert_equal(y, np.asarray([1, 2, 3])) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]

def test_device(self, xp: ModuleType, library: Backend, device: Device):
x = xp.asarray([1, 2, 3], device=device)
actual = as_numpy_array(x, xp=xp)
if library is Backend.TORCH:
assert device.type == "meta" # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
expect = np.asarray([0, 0, 0])
else:
expect = np.asarray([1, 2, 3])

xp_assert_equal(actual, expect) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]


@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype", strict=False)
Expand Down