Clip value of a variable during optimization step? #5671
Unanswered
adam-hartshorne asked this question in Q&A
I think you need to do something like this:

```python
import jax
import jax.numpy as jnp

params = get_params(opt_state)
# Clip every leaf of the params pytree to [min_val, max_val].
params = jax.tree_util.tree_map(lambda x: jnp.clip(x, min_val, max_val), params)
# set params back into opt_state
```

Which framework are you using? I think this is a bit nicer to do in …
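Here is one way to fill in the `# set params back into opt_state` step for the optimizers in `jax.example_libraries.optimizers` (formerly `jax.experimental.optimizers`). Judging by `get_params(opt_state)`, that appears to be the API the question uses. This is a sketch under stated assumptions and not necessarily what the reply had in mind. It uses `unpack_optimizer_state` / `pack_optimizer_state`, which the library provides mainly for serialization. It also assumes that Adam's per-parameter state is the tuple `(x, m, v)`, which is an internal detail that could change. `clip_param`, `loss`, and the parameter names `w` / `scale` are hypothetical. Rewriting the state this way keeps Adam's moment estimates, whereas calling `init_fun` on the clipped params would reset them.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers


def clip_param(opt_state, name, lo, hi):
    """Clip one named parameter inside an Adam OptimizerState (hypothetical helper).

    Assumption: params are a flat dict, and Adam's per-parameter state is the
    tuple (x, m, v). This relies on library internals, not a documented contract.
    """
    # Unpack into a dict of JoinPoint markers, one per parameter.
    marked = optimizers.unpack_optimizer_state(opt_state)
    x, m, v = marked[name].subtree
    # Clip only the parameter value; keep Adam's moment estimates unchanged.
    marked[name] = optimizers.JoinPoint((jnp.clip(x, lo, hi), m, v))
    return optimizers.pack_optimizer_state(marked)


init_fun, update_fun, get_params = optimizers.adam(step_size=0.5)

# Hypothetical parameters: "scale" must stay within [0, 1].
params = {"w": jnp.zeros(2), "scale": jnp.array(0.9)}
opt_state = init_fun(params)


def loss(p):
    # Hypothetical loss whose gradient keeps pushing "scale" above 1.
    return -p["scale"] + jnp.sum((p["w"] - 1.0) ** 2)


@jax.jit
def step(i, opt_state):
    grads = jax.grad(loss)(get_params(opt_state))
    opt_state = update_fun(i, grads, opt_state)
    # Clip right after the gradients have been applied.
    return clip_param(opt_state, "scale", 0.0, 1.0)


for i in range(10):
    opt_state = step(i, opt_state)

print(get_params(opt_state)["scale"])
```

Because `clip_param` only rebuilds the pytree structure, it can be called inside the jitted step as shown. To clip every parameter, map the same `JoinPoint` rewrite over all entries of `marked`.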
I have set up an optimization function. At the point indicated in my code, once the gradients have been applied, I would like to clip a particular variable's values to a certain range. get_params(opt_state) gives me a dictionary of the variable values, which I can clip, but I don't know how (or whether it is even possible) to write the updated dictionary back into opt_state.