@@ -463,7 +463,7 @@ def pickle_unflatten(instances: Iterable[object], rest: FlattenRest) -> Any: #
463463 Notes
464464 -----
465465 The `instances` iterable must yield at least the same number of elements as the ones
466- returned by ``pickle_without ``, but the elements do not need to be the same objects
466+ returned by ``pickle_flatten ``, but the elements do not need to be the same objects
467467 or even the same types of objects. Excess elements, if any, will be left untouched.
468468 """
469469 iters = iter (instances ), iter (rest )
@@ -540,6 +540,25 @@ def jax_autojit(
540540 See Also
541541 --------
542542 jax.jit : JAX JIT compilation function.
543+
544+ Notes
545+ -----
546+ These are useful choices *for testing purposes only*, which is how this function is
547+ intended to be used. The output of ``jax.jit`` is a C++ level callable, that
548+ directly dispatches to the compiled kernel after the initial call. In comparison,
549+ ``jax_autojit`` incurs a much higher dispatch time.
550+
551+ Additionally, consider::
552+
553+ def f(x: Array, y: float, plus: bool) -> Array:
554+ return x + y if plus else x - y
555+
556+ j1 = jax.jit(f, static_argnames="plus")
557+ j2 = jax_autojit(f)
558+
559+ In the above example, ``j2`` requires a lot less setup to be tested effectively than
560+ ``j1``, but on the flip side it means that it will be re-traced for every different
561+ value of ``y``, which likely makes it not fit for purpose in production.
543562 """
544563 import jax
545564
0 commit comments