access the learning rate of an (arbitrary) optax optimizer #961
fabianp asked this question in Show and tell · Unanswered
The optimizer needs to be wrapped in `optax.inject_hyperparams` for the learning rate to be accessible from the state. For example, if your optimizer is defined as

```python
opt = optax.sgd(learning_rate=some_schedule)
```

then you need to replace it with

```python
opt = optax.inject_hyperparams(optax.sgd)(learning_rate=some_schedule)
```
Say that your optimizer is then used in the usual `init`/`update` loop; you can then read the current learning rate from the state as `state.hyperparams['learning_rate']`. Here is a minimal sketch (the schedule and parameter shapes are illustrative):
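```python
import jax.numpy as jnp
import optax

# Illustrative schedule and a single parameter vector.
some_schedule = optax.linear_schedule(
    init_value=0.1, end_value=0.01, transition_steps=100
)
opt = optax.inject_hyperparams(optax.sgd)(learning_rate=some_schedule)

params = jnp.ones(3)
state = opt.init(params)

grads = jnp.full(3, 0.5)  # stand-in for jax.grad(loss)(params)
updates, state = opt.update(grads, state)
params = optax.apply_updates(params, updates)

# inject_hyperparams stores the (possibly scheduled) hyperparameters
# in state.hyperparams, so the current learning rate is:
learning_rate = state.hyperparams['learning_rate']
print(learning_rate)
```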
Rather than accessing the learning rate as `state.hyperparams['learning_rate']`, you may also use the handy `optax.tree_utils.tree_get`, which can fetch any entry of a state:

```python
learning_rate = optax.tree_utils.tree_get(state, 'learning_rate')
```

(You still need the learning rate to be present in the state, so you need to have defined the optimizer through `optax.inject_hyperparams`.) You may have to specify that you are searching for an array in the state, in which case pass a filter:

```python
learning_rate = optax.tree_utils.tree_get(
    state, 'learning_rate',
    filtering=lambda path, value: isinstance(value, jnp.ndarray),
)
```

If you still get errors using `tree_get`, try `tree_get_all_with_path` to see all entries in the state that are called `'learning_rate'`. See https://optax.readthedocs.io/en/latest/api/utilities.html#optax.tree_utils.tree_get for the documentation of `tree_get`.
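As a quick illustration of that debugging path (a sketch, reusing the `state` from the example above):

```python
# List every entry named 'learning_rate' in the state, together with its
# path in the state tree; useful when tree_get finds multiple matches.
for path, value in optax.tree_utils.tree_get_all_with_path(state, 'learning_rate'):
    print(path, value)
```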
The `tree_get` logic can also streamline accessing the learning rate of an optimizer defined through a chain of transformations, as sketched below.
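For instance (a minimal sketch, assuming a chain that combines gradient clipping with Adam; the specific transformations are illustrative):

```python
import jax.numpy as jnp
import optax

# A chained optimizer: clipping followed by Adam with an injected learning rate.
opt = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.inject_hyperparams(optax.adam)(learning_rate=1e-3),
)

params = jnp.ones(3)
state = opt.init(params)

# tree_get searches the whole nested state tree, so there is no need to
# know which element of the chain holds the hyperparameters.
learning_rate = optax.tree_utils.tree_get(state, 'learning_rate')
print(learning_rate)  # 0.001
```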