Hi,
You can see that …

Thanks.
`jax.lax.stop_gradient` does exactly this for automatic differentiation, e.g., `_ = g(jax.lax.stop_gradient(x))`.
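A minimal sketch of what `stop_gradient` does under `jax.grad`. The function `g` here is a hypothetical stand-in for the one in the question, and the result of `g` is used (not discarded) so the effect on the gradient is visible. `g` still runs on the primal value, but no derivative flows back through it.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical stand-in for the question's g.
    return jnp.sin(x) ** 2

def f_plain(x):
    # g contributes to the gradient as usual.
    return 3.0 * x + g(x)

def f_stopped(x):
    # stop_gradient cuts the differentiation path into g:
    # g is evaluated on the primal value, but its derivative is treated as zero.
    return 3.0 * x + g(jax.lax.stop_gradient(x))

x = 1.0
print(jax.grad(f_plain)(x))    # 3 + sin(2x), about 3.909
print(jax.grad(f_stopped)(x))  # only the 3 * x term contributes: 3.0
```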
If you also want to avoid tracing under other JAX transformations, you may want to experiment with `host_callback`: https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html

Note: I'm presuming that you use the result of `g(x)` in some way, rather than just throwing it away. If not, you may run into other trouble, because JAX generally expects transformations to be applied to pure functions, and a function whose return value is unused is typically being called for its side effects.
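For the callback route: the `host_callback` module linked above was later deprecated in JAX, and newer releases point to callbacks such as `jax.pure_callback` instead. The sketch below swaps in `jax.pure_callback`, assuming `g` can be written in plain NumPy; `g_host` is a hypothetical name. JAX only needs the declared output shape and dtype to trace `f`, so the body of `g_host` is never traced. Because `pure_callback` is treated as pure, JAX may drop the call if its result is unused, which is one concrete form of the trouble described in the note above. It is also not differentiable by default, so under `jax.grad` it would typically be paired with `jax.lax.stop_gradient` or a custom JVP.

```python
import jax
import jax.numpy as jnp
import numpy as np

def g_host(x):
    # Hypothetical NumPy version of g; runs on the host and is not traced by JAX.
    return np.asarray(np.sin(x) ** 2, dtype=x.dtype)

@jax.jit
def f(x):
    # Declare the callback's result shape/dtype so JAX can trace f
    # without looking inside g_host.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    aux = jax.pure_callback(g_host, out_spec, x)
    # The result is used; an unused pure_callback result may be eliminated.
    return 3.0 * x + aux

print(f(jnp.float32(1.0)))  # 3 + sin(1)**2
```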