jax.lax.stop_gradient does exactly this for automatic differentiation: g still runs on the value of x, but no gradient flows back to x through it, e.g., _ = g(jax.lax.stop_gradient(x)).
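A minimal sketch of the effect, assuming a hypothetical g (here 3·sin(x)) whose result is actually used in a loss. The forward value is unchanged, but jax.grad only sees the x**2 term once g's input is wrapped in stop_gradient:

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical auxiliary function whose gradient we want to ignore.
    return 3.0 * jnp.sin(x)

def loss_with_grad_through_g(x):
    return x ** 2 + g(x)

def loss_without_grad_through_g(x):
    # stop_gradient is the identity on values, but autodiff treats its output
    # as a constant, so no gradient reaches x through g.
    return x ** 2 + g(jax.lax.stop_gradient(x))

x = 1.0
print(loss_with_grad_through_g(x), loss_without_grad_through_g(x))  # same forward value
print(jax.grad(loss_with_grad_through_g)(x))     # 2*x + 3*cos(x)
print(jax.grad(loss_without_grad_through_g)(x))  # 2*x only
```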

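If you also want to avoid tracing g under other JAX transformations (such as jit or vmap), you may want to experiment with host_callback: https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html. Note that jax.experimental.host_callback has since been deprecated in favor of jax.pure_callback and jax.experimental.io_callback.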
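A sketch of the callback approach, swapping in jax.pure_callback because host_callback is deprecated. The names g_host and g are illustrative. The body of g_host runs on the host with a concrete NumPy array, so JAX never traces it; JAX only needs the declared output shape and dtype. This assumes g is pure and that you do not need gradients through it (pure_callback has no differentiation rule by default):

```python
import jax
import jax.numpy as jnp
import numpy as np

def g_host(x):
    # Runs on the host with a concrete NumPy array; JAX does not trace this body.
    return np.sin(np.asarray(x))

def g(x):
    # Declare the output shape/dtype so JAX can stage the call without tracing g_host.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(g_host, out_spec, x)

@jax.jit
def f(x):
    return x * 2.0 + g(x)

x = jnp.arange(3.0)
print(f(x))
```

If g has side effects that must happen (logging, I/O), jax.experimental.io_callback is the variant intended for that case.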

Note: I'm presuming that you use the result of g(x) in some way, rather than just throwing it away. If not, you may run into other trouble: in general, JAX expects transformations to be applied to pure functions, and a function whose return value is never used is typically being called for its side effects.
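A small illustration of why side effects and discarded results don't mix well with transformations, using a hypothetical impure_g. Under jax.jit, Python side effects run only while the function is traced, and a call whose result is discarded contributes nothing to the compiled computation:

```python
import jax

calls = []

def impure_g(x):
    calls.append("g called")  # Python side effect: runs only during tracing
    return x + 1.0

@jax.jit
def f(x):
    impure_g(x)  # result discarded
    return x * 2.0

f(1.0)  # first call: traces f, so impure_g's side effect runs once
f(2.0)  # same input shape/dtype: reuses the compiled version, no retrace
print(len(calls))  # 1
```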

Answer selected by fishjojo
Category
Q&A
2 participants