
Option 1

model_output = model.apply(params, model_input)
detached_model_output = jax.lax.stop_gradient(model_output)
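A minimal runnable sketch of this option (using a hypothetical apply_model in place of model.apply): under jax.grad, the detached output behaves like a constant, so only the non-detached uses of the output contribute gradients.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for model.apply: a scalar linear model.
def apply_model(params, x):
    return params["w"] * x + params["b"]

def loss(params, x):
    y = apply_model(params, x)
    detached = jax.lax.stop_gradient(y)  # treated as a constant under grad
    # Gradients flow only through the non-detached factor y.
    return jnp.sum(detached * y)

params = {"w": jnp.array(1.0), "b": jnp.array(0.0)}
g = jax.grad(loss)(params, jnp.array(3.0))
# g["w"] == 9.0 and g["b"] == 3.0: the detached factor acted as the constant 3.0.
```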

Option 2

detached_params = jax.lax.stop_gradient(params)
model_output = model.apply(detached_params, model_input)
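A sketch of this option under the same assumptions (hypothetical apply_model standing in for model.apply): detaching the whole params pytree zeroes every gradient that would otherwise flow into it.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for model.apply: a scalar linear model.
def apply_model(params, x):
    return params["w"] * x + params["b"]

def loss(params, x):
    detached = jax.lax.stop_gradient(params)  # detach the entire params pytree
    return jnp.sum(apply_model(detached, x) ** 2)

params = {"w": jnp.array(2.0), "b": jnp.array(1.0)}
g = jax.grad(loss)(params, jnp.array(3.0))
# Both gradients are zero: no gradient reaches the detached params.
```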

Note that params can be a pytree, i.e. a nested dict/tuple/list or a custom pytree type.
This works with both flax and haiku: init returns params as a pytree, and apply expects a pytree of params.
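For example, jax.lax.stop_gradient maps over the leaves of an arbitrary pytree and preserves its structure, so a nested flax/haiku-style params dict can be passed in whole (the params below are made up for illustration):

```python
import jax
import jax.numpy as jnp

# A nested params pytree, similar in shape to what flax/haiku init returns.
params = {
    "dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros(2)},
    "scales": (jnp.array(1.0), jnp.array(2.0)),
}

detached = jax.lax.stop_gradient(params)

# stop_gradient is applied leaf-wise; the tree structure is unchanged.
same_structure = (jax.tree_util.tree_structure(detached)
                  == jax.tree_util.tree_structure(params))
```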
Option 3
Use https://github.com/patrick-kidger/equinox, so that your model itself is a pytree; Option 2 then applies to the model object directly.
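The idea can be illustrated without installing equinox: below is a hypothetical toy module registered as a pytree (not the real equinox API), showing that stop_gradient can then be applied to the model object itself.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

# Toy module registered as a pytree (hypothetical; equinox does this for you).
@register_pytree_node_class
class Linear:
    def __init__(self, w, b):
        self.w, self.b = w, b

    def __call__(self, x):
        return self.w @ x + self.b

    def tree_flatten(self):
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

model = Linear(jnp.ones((2, 2)), jnp.zeros(2))

def loss(model, x):
    frozen = jax.lax.stop_gradient(model)  # detach the whole model pytree
    return jnp.sum(frozen(x) ** 2)

grads = jax.grad(loss)(model, jnp.ones(2))
# All gradients are zero because the model was detached.
```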

Answer selected by maobenz