@@ -129,6 +129,11 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
129129 When running inside `jax.jit`, `shape` must be fully known, i.e. it cannot
130130 contain any `None` elements.
131131
132+ .. warning::
133+
134+ `func` must never raise if it's run inside `jax.jit`, as its behavior is
135+ undefined.
136+
132137 Using this with `as_numpy=False` is particularly useful to apply non-jittable
133138 JAX functions to arrays on GPU devices.
134139 If `as_numpy=True`, the :doc:`jax:transfer_guard` may prevent arrays on a GPU
@@ -254,12 +259,11 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
254259 for i , (shape , dtype ) in enumerate (zip (shapes , dtypes , strict = True ))
255260 )
256261
257- elif is_jax_namespace (xp ):
258- # If we're inside jax.jit, we can't eagerly convert
259- # the JAX tracer objects to numpy.
260- # Instead, we delay calling wrapped, which will receive
261- # as arguments and will return JAX eager arrays.
262-
262+ elif is_jax_namespace (xp ) and _is_jax_jit_enabled (xp ):
263+ # Delay calling func with jax.pure_callback, which will forward to func eager
264+ # JAX arrays. Do not use jax.pure_callback when running outside of the JIT,
265+ # as it does not support raising exceptions:
266+ # https://github.com/jax-ml/jax/issues/26102
263267 import jax
264268
265269 # Shield eager kwargs from being coerced into JAX arrays.
@@ -276,27 +280,19 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
276280 wrapped = _lazy_apply_wrapper (
277281 partial (func , ** eager_kwargs ), as_numpy , multi_output , xp
278282 )
279-
280- if any (s is None for shape in shapes for s in shape ):
281- # Unknown output shape. Won't work with jax.jit, but it
282- # can work with eager jax.
283- # Raises jax.errors.TracerArrayConversionError if we're inside jax.jit.
284- out = wrapped (* args , ** lazy_kwargs )
285-
286- else :
287- # suppress unused-ignore to run mypy in -e lint as well as -e dev
288- out = cast ( # type: ignore[bad-cast,unused-ignore]
289- tuple [Array , ...],
290- jax .pure_callback (
291- wrapped ,
292- tuple (
293- jax .ShapeDtypeStruct (shape , dtype ) # pyright: ignore[reportUnknownArgumentType]
294- for shape , dtype in zip (shapes , dtypes , strict = True )
295- ),
296- * args ,
297- ** lazy_kwargs ,
283+ # suppress unused-ignore to run mypy in -e lint as well as -e dev
284+ out = cast ( # type: ignore[bad-cast,unused-ignore]
285+ tuple [Array , ...],
286+ jax .pure_callback (
287+ wrapped ,
288+ tuple (
289+ jax .ShapeDtypeStruct (shape , dtype ) # pyright: ignore[reportUnknownArgumentType]
290+ for shape , dtype in zip (shapes , dtypes , strict = True )
298291 ),
299- )
292+ * args ,
293+ ** lazy_kwargs ,
294+ ),
295+ )
300296
301297 else :
302298 # Eager backends
@@ -306,6 +302,17 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
306302 return out if multi_output else out [0 ]
307303
308304
305+ def _is_jax_jit_enabled (xp : ModuleType ) -> bool : # numpydoc ignore=PR01,RT01
306+ """Return True if this function is being called inside ``jax.jit``."""
307+ import jax # pylint: disable=import-outside-toplevel
308+
309+ x = xp .asarray (False )
310+ try :
311+ return bool (x )
312+ except jax .errors .TracerArrayConversionError :
313+ return True
314+
315+
309316def _contains_jax_arrays (x : object ) -> bool : # numpydoc ignore=PR01,RT01
310317 """
311318 Test if x is a JAX array or a nested collection with any JAX arrays in it.
0 commit comments