Skip to content

Commit fd7a799

Browse files
committed
ENH: support PyTorch device='meta'
1 parent e4ecb82 commit fd7a799

File tree

8 files changed

+71
-20
lines changed

8 files changed

+71
-20
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def _apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01
153153
) -> Array:
154154
"""Helper of `apply_where`. On Dask, this runs on a single chunk."""
155155

156-
if not capabilities(xp)["boolean indexing"]:
156+
if not capabilities(xp, device=_compat.device(cond))["boolean indexing"]:
157157
# jax.jit does not support assignment by boolean mask
158158
return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value)
159159

@@ -716,7 +716,7 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
716716
# 2. backend has unique_counts and it returns a None-sized array;
717717
# e.g. Dask, ndonnx
718718
# 3. backend does not have unique_counts; e.g. wrapped JAX
719-
if capabilities(xp)["data-dependent shapes"]:
719+
if capabilities(xp, device=_compat.device(x))["data-dependent shapes"]:
720720
# xp has unique_counts; O(n) complexity
721721
_, counts = xp.unique_counts(x)
722722
n = _compat.size(counts)

src/array_api_extra/_lib/_testing.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,11 @@ def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]:
100100
return array.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
101101

102102
if is_torch_namespace(xp):
103-
array = to_device(array, "cpu")
103+
if array.device.type == "meta": # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
104+
# Can't materialize; generate dummy data instead
105+
array = xp.zeros_like(array, device="cpu")
106+
else:
107+
array = to_device(array, "cpu")
104108
if is_array_api_strict_namespace(xp):
105109
cpu: Device = xp.Device("CPU_DEVICE")
106110
array = to_device(array, cpu)

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@
2929
is_jax_namespace,
3030
is_numpy_array,
3131
is_pydata_sparse_namespace,
32+
is_torch_namespace,
3233
)
33-
from ._typing import Array
34+
from ._typing import Array, Device
3435

3536
if TYPE_CHECKING: # pragma: no cover
3637
# TODO import from typing (requires Python >=3.12 and >=3.13)
@@ -300,7 +301,7 @@ def meta_namespace(
300301
return array_namespace(*metas)
301302

302303

303-
def capabilities(xp: ModuleType) -> dict[str, int]:
304+
def capabilities(xp: ModuleType, *, device: Device | None = None) -> dict[str, int]:
304305
"""
305306
Return patched ``xp.__array_namespace_info__().capabilities()``.
306307
@@ -311,6 +312,8 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
311312
----------
312313
xp : array_namespace
313314
The standard-compatible namespace.
315+
device : Device, optional
316+
The device to use.
314317
315318
Returns
316319
-------
@@ -326,6 +329,13 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
326329
# Fixed in jax >=0.6.0
327330
out = out.copy()
328331
out["boolean indexing"] = False
332+
if is_torch_namespace(xp):
333+
# FIXME https://github.com/data-apis/array-api/issues/945
334+
device = xp.get_default_device() if device is None else xp.device(device)
335+
if cast(Any, device).type == "meta": # type: ignore[explicit-any]
336+
out = out.copy()
337+
out["boolean indexing"] = False
338+
out["data-dependent shapes"] = False
329339
return out
330340

331341

tests/conftest.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,9 @@ def device(
211211
Where possible, return a device that is not the default one.
212212
"""
213213
if library == Backend.ARRAY_API_STRICT:
214-
d = xp.Device("device1")
215-
assert get_device(xp.empty(0)) != d
216-
return d
214+
return xp.Device("device1")
215+
if library == Backend.TORCH:
216+
return xp.device("meta")
217+
if library == Backend.TORCH_GPU:
218+
return xp.device("cpu")
217219
return get_device(xp.empty(0))

tests/test_funcs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -731,9 +731,6 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool):
731731
b = xp.asarray([1e-9, 1e-4, xp.nan], device=device)
732732
res = isclose(a, b, equal_nan=equal_nan)
733733
assert get_device(res) == device
734-
xp_assert_equal(
735-
isclose(a, b, equal_nan=equal_nan), xp.asarray([True, False, equal_nan])
736-
)
737734

738735

739736
class TestKron:
@@ -996,6 +993,9 @@ def test_all_python_scalars(self, assume_unique: bool):
996993
_ = setdiff1d(0, 0, assume_unique=assume_unique)
997994

998995
@assume_unique
996+
@pytest.mark.skip_xp_backend(
997+
Backend.TORCH, reason="device='meta' does not support unknown shapes"
998+
)
999999
def test_device(self, xp: ModuleType, device: Device, assume_unique: bool):
10001000
x1 = xp.asarray([3, 8, 20], device=device)
10011001
x2 = xp.asarray([2, 3, 4], device=device)

tests/test_helpers.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,31 @@ def test_xp(self, xp: ModuleType):
212212
assert meta_namespace(*args, xp=xp) in (xp, np_compat)
213213

214214

215-
def test_capabilities(xp: ModuleType):
216-
expect = {"boolean indexing", "data-dependent shapes"}
217-
if xp.__array_api_version__ >= "2024.12":
218-
expect.add("max dimensions")
219-
assert capabilities(xp).keys() == expect
215+
class TestCapabilities:
216+
def test_basic(self, xp: ModuleType):
217+
expect = {"boolean indexing", "data-dependent shapes"}
218+
if xp.__array_api_version__ >= "2024.12":
219+
expect.add("max dimensions")
220+
assert capabilities(xp).keys() == expect
221+
222+
def test_device(self, xp: ModuleType, library: Backend, device: Device):
223+
expect_keys = {"boolean indexing", "data-dependent shapes"}
224+
if xp.__array_api_version__ >= "2024.12":
225+
expect_keys.add("max dimensions")
226+
assert capabilities(xp, device=device).keys() == expect_keys
227+
228+
if library.like(Backend.TORCH):
229+
# The output of capabilities is device-specific.
230+
231+
# Test that device=None gets the current default device.
232+
expect = capabilities(xp, device=device)
233+
with xp.device(device):
234+
actual = capabilities(xp)
235+
assert actual == expect
236+
237+
# Test that we're accepting anything that is accepted by the
238+
# device= parameter in other functions
239+
actual = capabilities(xp, device=device.type) # type: ignore[attr-defined] # pyright: ignore[reportUnknownArgumentType,reportAttributeAccessIssue]
220240

221241

222242
class Wrapper(Generic[T]):

tests/test_lazy.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,9 @@ def test_lazy_apply_none_shape_broadcast(xp: ModuleType):
278278
Backend.ARRAY_API_STRICT, reason="device->host copy"
279279
),
280280
pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host copy"),
281+
pytest.mark.skip_xp_backend(
282+
Backend.TORCH, reason="materialize 'meta' device"
283+
),
281284
pytest.mark.skip_xp_backend(
282285
Backend.TORCH_GPU, reason="device->host copy"
283286
),

tests/test_testing.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,22 @@
3939
)
4040

4141

42-
def test_as_numpy_array(xp: ModuleType, device: Device):
43-
x = xp.asarray([1, 2, 3], device=device)
44-
y = as_numpy_array(x, xp=xp)
45-
assert isinstance(y, np.ndarray)
42+
class TestAsNumPyArray:
43+
def test_basic(self, xp: ModuleType):
44+
x = xp.asarray([1, 2, 3])
45+
y = as_numpy_array(x, xp=xp)
46+
xp_assert_equal(y, np.asarray([1, 2, 3])) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
47+
48+
def test_device(self, xp: ModuleType, library: Backend, device: Device):
49+
x = xp.asarray([1, 2, 3], device=device)
50+
actual = as_numpy_array(x, xp=xp)
51+
if library is Backend.TORCH:
52+
assert device.type == "meta" # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
53+
expect = np.asarray([0, 0, 0])
54+
else:
55+
expect = np.asarray([1, 2, 3])
56+
57+
xp_assert_equal(actual, expect) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
4658

4759

4860
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype", strict=False)

0 commit comments

Comments
 (0)