File tree Expand file tree Collapse file tree 2 files changed +20
-0
lines changed Expand file tree Collapse file tree 2 files changed +20
-0
lines changed Original file line number Diff line number Diff line change @@ -540,6 +540,25 @@ def jax_autojit(
540540 See Also
541541 --------
542542 jax.jit : JAX JIT compilation function.
543+
544+ Notes
545+ -----
546+ These are useful choices *for testing purposes only*, which is how this function is
547+ intended to be used. The output of ``jax.jit`` is a C++ level callable, that
548+ directly dispatches to the compiled kernel after the initial call. In comparison,
549+ ``jax_autojit`` incurs in a much higher dispatch time.
550+
551+ Additionally, consider::
552+
553+ def f(x: Array, y: float, plus: bool) -> Array:
554+ return x + y if plus else x - y
555+
556+ j1 = jax.jit(f, static_argnames="plus")
557+ j2 = jax_autojit(f)
558+
559+ In the above example, ``j2`` requires a lot less setup to be tested effectively than
560+ ``j1``, but on the flip side it means that it will be re-traced for every different
561+ value of ``y``, which likely makes it not fit for purpose in production.
543562 """
544563 import jax
545564
Original file line number Diff line number Diff line change @@ -96,6 +96,7 @@ def lazy_xp_function( # type: ignore[explicit-any]
9696 jax_jit : bool, optional
9797 Set to True to replace `func` with a smart variant of ``jax.jit(func)`` after
9898 calling the :func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``.
99+ This is the default behaviour.
99100 Set to False if `func` is only compatible with eager (non-jitted) JAX.
100101
101102 Unlike with vanilla ``jax.jit``, all arguments and return types that are not JAX
You can’t perform that action at this time.
0 commit comments