Skip to content

Commit 8feff4c

Browse files
committed
WIP
1 parent dd361d0 commit 8feff4c

File tree

3 files changed

+114
-26
lines changed

3 files changed

+114
-26
lines changed

src/array_api_extra/_lib/_lazy.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from types import ModuleType
1010
from typing import TYPE_CHECKING, Any, cast, overload
1111

12+
from ._utils import _compat
1213
from ._utils._compat import (
1314
array_namespace,
1415
is_array_api_obj,
@@ -295,7 +296,7 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
295296
)
296297

297298
else:
298-
# Eager backends
299+
# Eager backends, including non-jitted JAX
299300
wrapped = _lazy_apply_wrapper(func, as_numpy, multi_output, xp)
300301
out = wrapped(*args, **kwargs)
301302

@@ -309,7 +310,7 @@ def _is_jax_jit_enabled(xp: ModuleType) -> bool: # numpydoc ignore=PR01,RT01
309310
x = xp.asarray(False)
310311
try:
311312
return bool(x)
312-
except jax.errors.TracerArrayConversionError:
313+
except jax.errors.TracerBoolConversionError:
313314
return True
314315

315316

@@ -362,14 +363,16 @@ def _lazy_apply_wrapper( # type: ignore[no-any-explicit] # numpydoc ignore=PR0
362363
def wrapper( # type: ignore[no-any-decorated,no-any-explicit]
363364
*args: Array, **kwargs: Any
364365
) -> tuple[Array, ...]: # numpydoc ignore=GL08
366+
device = _compat.device(args[0]) if args else None
367+
365368
if as_numpy:
366369
args = _as_numpy(args)
367370
kwargs = _as_numpy(kwargs)
368371
out = func(*args, **kwargs)
369372

370373
if multi_output:
371374
assert isinstance(out, Sequence)
372-
return tuple(xp.asarray(o) for o in out)
373-
return (xp.asarray(out),)
375+
return tuple(xp.asarray(o, device=device) for o in out)
376+
return (xp.asarray(out, device=device),)
374377

375378
return wrapper

tests/conftest.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,11 @@ def xp(
110110
if library == Backend.NUMPY_READONLY:
111111
return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType]
112112
xp = pytest.importorskip(library.value)
113+
# Possibly wrap module with array_api_compat
114+
xp = array_namespace(xp.empty(0))
113115

116+
# On Dask and JAX, monkey-patch all functions tagged by `lazy_xp_function`
117+
# in the global scope of the module containing the test function.
114118
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
115119

116120
if library == Backend.JAX:
@@ -119,8 +123,18 @@ def xp(
119123
# suppress unused-ignore to run mypy in -e lint as well as -e dev
120124
jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore]
121125

122-
# Possibly wrap module with array_api_compat
123-
return array_namespace(xp.empty(0))
126+
return xp
127+
128+
129+
@pytest.fixture(params=[Backend.DASK]) # Select the test with `pytest -k dask`
130+
def da(
131+
request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
132+
) -> ModuleType: # numpydoc ignore=PR01,RT01
133+
"""Fixture that returns dask.array (wrapped by array-api-compat)."""
134+
xp = pytest.importorskip("dask.array")
135+
xp = array_namespace(xp.empty(0))
136+
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
137+
return xp
124138

125139

126140
@pytest.fixture

tests/test_lazy.py

Lines changed: 91 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,51 +10,120 @@
1010
from array_api_extra._lib._utils._typing import Array
1111
from array_api_extra.testing import lazy_xp_function
1212

13-
skip_as_numpy = [
14-
pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host transfer"),
15-
pytest.mark.skip_xp_backend(Backend.SPARSE, reason="densification"),
16-
]
13+
as_numpy = pytest.mark.parametrize(
14+
"as_numpy",
15+
[
16+
False,
17+
pytest.param(
18+
True,
19+
marks=[
20+
pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host copy"),
21+
pytest.mark.skip_xp_backend(Backend.SPARSE, reason="densification"),
22+
],
23+
),
24+
],
25+
)
26+
27+
28+
@as_numpy
29+
def test_lazy_apply_simple(xp: ModuleType, as_numpy: bool):
30+
pytest.skip("TODO")
31+
32+
33+
@as_numpy
34+
def test_lazy_apply_broadcast(xp: ModuleType, as_numpy: bool):
35+
pytest.skip("TODO")
36+
37+
38+
@as_numpy
39+
def test_lazy_apply_multi_output(xp: ModuleType, as_numpy: bool):
40+
pytest.skip("TODO")
41+
42+
43+
def test_lazy_apply_core_indices(da: ModuleType):
44+
"""Test that a func that performs reductions along axes does so
45+
globally and not locally to each Dask chunk.
46+
"""
47+
pytest.skip("TODO")
48+
49+
50+
def test_lazy_apply_dont_run_on_meta(da: ModuleType):
51+
"""Test that Dask won't try running func on the meta array,
52+
as it may have minimum size requirements.
53+
"""
54+
pytest.skip("TODO")
55+
56+
57+
def test_lazy_apply_none_shape(da: ModuleType):
58+
pytest.skip("TODO")
59+
60+
61+
@as_numpy
62+
def test_lazy_apply_device(xp: ModuleType, as_numpy: bool):
63+
pytest.skip("TODO")
64+
65+
66+
@as_numpy
67+
def test_lazy_apply_no_args(xp: ModuleType, as_numpy: bool):
68+
pytest.skip("TODO")
1769

1870

1971
class NT(NamedTuple):
2072
a: Array
2173

2274

23-
def check_lazy_apply_kwargs(x: Array, expect: type, as_numpy: bool) -> Array:
24-
def f(
75+
def check_lazy_apply_kwargs(x: Array, expect_cls: type, as_numpy: bool) -> Array:
76+
def eager(
2577
x: Array,
2678
z: dict[str, list[Array] | tuple[Array, ...] | NT],
2779
msg: str,
2880
msgs: list[str],
81+
scalar: int,
2982
) -> Array:
30-
assert isinstance(x, expect)
83+
assert isinstance(x, expect_cls)
84+
assert int(x) == 0 # JAX will crash if x isn't material
85+
# Did we re-wrap the namedtuple correctly, or did it get
86+
# accidentally changed to a basic tuple?
3187
assert isinstance(z["foo"], NT)
32-
assert isinstance(z["foo"].a, expect)
33-
assert isinstance(z["bar"][0], expect)
34-
assert isinstance(z["baz"][0], expect)
35-
assert msg == "Hello World"
36-
assert msgs[0] == "Hello World"
37-
return x + 1
88+
assert isinstance(z["foo"].a, expect_cls)
89+
assert isinstance(z["bar"][0], expect_cls) # list
90+
assert isinstance(z["baz"][0], expect_cls) # tuple
91+
assert msg == "Hello World" # must be hidden from JAX
92+
assert msgs[0] == "Hello World" # must be hidden from JAX
93+
assert isinstance(msg, str)
94+
assert isinstance(msgs[0], str)
95+
assert scalar == 1 # must be hidden from JAX
96+
assert isinstance(scalar, int)
97+
return x + 1 # type: ignore[operator]
3898

3999
return lazy_apply( # pyright: ignore[reportCallIssue]
40-
f,
100+
eager,
41101
x,
102+
# These kwargs can and should be passed through jax.pure_callback
42103
z={"foo": NT(x), "bar": [x], "baz": (x,)},
104+
# These can't
43105
msg="Hello World",
44106
msgs=["Hello World"],
107+
# This will be automatically cast to jax.Array if we don't wrap it
108+
scalar=1,
45109
shape=x.shape,
46110
dtype=x.dtype,
47111
as_numpy=as_numpy,
48112
)
49113

50-
lazy_xp_function(check_lazy_apply_kwargs, static_argnames=("expect", "as_numpy"))
114+
115+
lazy_xp_function(check_lazy_apply_kwargs, static_argnames=("expect_cls", "as_numpy"))
51116

52117

53-
@pytest.mark.parametrize("as_numpy", [False, pytest.param(True, marks=skip_as_numpy)])
118+
@as_numpy
54119
def test_lazy_apply_kwargs(xp: ModuleType, library: Backend, as_numpy: bool) -> None:
55-
expect = np.ndarray if as_numpy or library is Backend.DASK else type(xp.asarray(0))
120+
"""When as_numpy=True, search and replace arrays in the (nested) keywords arguments
121+
with numpy arrays, and leave the rest untouched."""
122+
expect_cls = (
123+
np.ndarray if as_numpy or library is Backend.DASK else type(xp.asarray(0))
124+
)
56125
x = xp.asarray(0)
57-
actual = check_lazy_apply_kwargs(x, expect, as_numpy)
126+
actual = check_lazy_apply_kwargs(x, expect_cls, as_numpy) # pyright: ignore[reportUnknownArgumentType]
58127
xp_assert_equal(actual, x + 1)
59128

60129

@@ -69,10 +138,12 @@ def eager(_: Array) -> Array:
69138

70139
return lazy_apply(eager, x, shape=x.shape, dtype=x.dtype)
71140

72-
lazy_xp_function(raises)
141+
142+
# jax.pure_callback does not support raising
143+
# https://github.com/jax-ml/jax/issues/26102
144+
lazy_xp_function(raises, jax_jit=False)
73145

74146

75-
@pytest.mark.skip_xp_backend(Backend.JAX_JIT, reason="no exception support")
76147
def test_lazy_apply_raises(xp: ModuleType) -> None:
77148
x = xp.asarray(0)
78149

0 commit comments

Comments
 (0)