Skip to content

Commit b99ca91

Browse files
committed
Don't run jax.pure_callback unless necessary
1 parent de67dc1 commit b99ca91

File tree

2 files changed

+36
-33
lines changed

2 files changed

+36
-33
lines changed

src/array_api_extra/_lib/_lazy.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,11 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
129129
When running inside `jax.jit`, `shape` must be fully known, i.e. it cannot
130130
contain any `None` elements.
131131
132+
.. warning::
133+
134+
`func` must never raise if it's run inside `jax.jit`, as its behavior is
135+
undefined.
136+
132137
Using this with `as_numpy=False` is particularly useful to apply non-jittable
133138
JAX functions to arrays on GPU devices.
134139
If `as_numpy=True`, the :doc:`jax:transfer_guard` may prevent arrays on a GPU
@@ -254,12 +259,11 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
254259
for i, (shape, dtype) in enumerate(zip(shapes, dtypes, strict=True))
255260
)
256261

257-
elif is_jax_namespace(xp):
258-
# If we're inside jax.jit, we can't eagerly convert
259-
# the JAX tracer objects to numpy.
260-
# Instead, we delay calling wrapped, which will receive
261-
# as arguments and will return JAX eager arrays.
262-
262+
elif is_jax_namespace(xp) and _is_jax_jit_enabled(xp):
263+
# Delay calling func with jax.pure_callback, which will forward to func eager
264+
# JAX arrays. Do not use jax.pure_callback when running outside of the JIT,
265+
# as it does not support raising exceptions:
266+
# https://github.com/jax-ml/jax/issues/26102
263267
import jax
264268

265269
# Shield eager kwargs from being coerced into JAX arrays.
@@ -276,27 +280,19 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
276280
wrapped = _lazy_apply_wrapper(
277281
partial(func, **eager_kwargs), as_numpy, multi_output, xp
278282
)
279-
280-
if any(s is None for shape in shapes for s in shape):
281-
# Unknown output shape. Won't work with jax.jit, but it
282-
# can work with eager jax.
283-
# Raises jax.errors.TracerArrayConversionError if we're inside jax.jit.
284-
out = wrapped(*args, **lazy_kwargs)
285-
286-
else:
287-
# suppress unused-ignore to run mypy in -e lint as well as -e dev
288-
out = cast( # type: ignore[bad-cast,unused-ignore]
289-
tuple[Array, ...],
290-
jax.pure_callback(
291-
wrapped,
292-
tuple(
293-
jax.ShapeDtypeStruct(shape, dtype) # pyright: ignore[reportUnknownArgumentType]
294-
for shape, dtype in zip(shapes, dtypes, strict=True)
295-
),
296-
*args,
297-
**lazy_kwargs,
283+
# suppress unused-ignore to run mypy in -e lint as well as -e dev
284+
out = cast( # type: ignore[bad-cast,unused-ignore]
285+
tuple[Array, ...],
286+
jax.pure_callback(
287+
wrapped,
288+
tuple(
289+
jax.ShapeDtypeStruct(shape, dtype) # pyright: ignore[reportUnknownArgumentType]
290+
for shape, dtype in zip(shapes, dtypes, strict=True)
298291
),
299-
)
292+
*args,
293+
**lazy_kwargs,
294+
),
295+
)
300296

301297
else:
302298
# Eager backends
@@ -306,6 +302,17 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
306302
return out if multi_output else out[0]
307303

308304

305+
def _is_jax_jit_enabled(xp: ModuleType) -> bool: # numpydoc ignore=PR01,RT01
306+
"""Return True if this function is being called inside ``jax.jit``."""
307+
import jax # pylint: disable=import-outside-toplevel
308+
309+
x = xp.asarray(False)
310+
try:
311+
return bool(x)
312+
except jax.errors.TracerArrayConversionError:
313+
return True
314+
315+
309316
def _contains_jax_arrays(x: object) -> bool: # numpydoc ignore=PR01,RT01
310317
"""
311318
Test if x is a JAX array or a nested collection with any JAX arrays in it.

tests/test_lazy.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,14 @@ def eager(_: Array) -> Array:
6363

6464
return lazy_apply(eager, x, shape=x.shape, dtype=x.dtype)
6565

66-
6766
lazy_xp_function(raises)
6867

6968

70-
def test_lazy_apply_raises(xp: ModuleType, library: Backend) -> None:
69+
@pytest.mark.skip_xp_backend(Backend.JAX_JIT, reason="no exception support")
70+
def test_lazy_apply_raises(xp: ModuleType) -> None:
7171
x = xp.asarray(0)
7272

73-
with pytest.raises(
74-
# FIXME https://github.com/jax-ml/jax/issues/26102
75-
RuntimeError if library is Backend.JAX else CustomError,
76-
match="Hello World",
77-
):
73+
with pytest.raises(CustomError, match="Hello World"):
7874
# Here we are disregarding the return value, which would
7975
# normally cause the graph not to materialize and the
8076
# exception not to be raised.

0 commit comments

Comments
 (0)