Is there a method for accessing the learning rate being used by the optimizer at each step during training from the schedule.py schedules?

Replies: 1 comment
To access the learning rate during training you could do one of two things. The first option is to read the current step count out of the optimizer state and evaluate the schedule at that step:
```python
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)
  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      count = opt_state.inner_state[0].count  # current step count held by the optimizer
      lr = schedule(count)                    # evaluate the schedule at that step
      print(f'Step {i:3}, Loss: {loss_value:.3f}, Learning rate: {lr:.9f}')
  return params

params = fit(initial_params, optimizer)
```
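A note on the state access: the names `schedule`, `optimizer`, `step`, `TRAINING_DATA`, `LABELS` and `initial_params` are not defined in this thread (they presumably come from the surrounding example), and the `opt_state.inner_state[0].count` path matches an optimizer wrapped with `optax.inject_hyperparams` as in the second option below; with a plain `optax.adamw(schedule)` the count sits at `opt_state[0].count` instead. Below is a minimal, self-contained sketch of the idea, using an illustrative schedule and dummy gradients that are assumptions rather than part of the original discussion:

```python
# Minimal sketch: read the step count from the optimizer state and evaluate
# the schedule at it. The schedule, optimizer and gradients are illustrative
# assumptions, not taken from the original discussion.
import jax.numpy as jnp
import optax

schedule = optax.exponential_decay(
    init_value=1e-2, transition_steps=100, decay_rate=0.9)
optimizer = optax.adamw(learning_rate=schedule)

params = {'w': jnp.zeros(3)}
opt_state = optimizer.init(params)
grads = {'w': jnp.ones(3)}  # stand-in gradients for illustration

for _ in range(5):
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)

# With a plain optax.adamw, ScaleByAdamState is the first element of the
# chained state, so the count lives at opt_state[0].count.
count = opt_state[0].count
lr = schedule(count)
print(f'steps taken: {count}, schedule value: {float(lr):.6f}')
```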
The second option is to wrap the optimizer with `optax.inject_hyperparams`, which stores the hyperparameters (including the scheduled learning rate) in the optimizer state:

```python
# Wrap the optimizer to inject the hyperparameters
optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=schedule)

def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)
  # Since we injected hyperparams, we can access them directly here
  print(f'Available hyperparams: {" ".join(opt_state.hyperparams.keys())}\n')
  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      # Get the updated learning rate
      lr = opt_state.hyperparams['learning_rate']
      print(f'Step {i:3}, Loss: {loss_value:.3f}, Learning rate: {lr:.3f}')
  return params

params = fit(initial_params, optimizer)
```
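Here is a similarly hedged, self-contained sketch of the `inject_hyperparams` option, again with an assumed schedule and dummy gradients instead of the thread's actual training setup:

```python
# Minimal sketch of the inject_hyperparams approach. The schedule and
# gradients are illustrative assumptions; only the hyperparams access
# mirrors the snippet above.
import jax.numpy as jnp
import optax

schedule = optax.linear_schedule(
    init_value=1e-2, end_value=1e-4, transition_steps=1_000)
optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=schedule)

params = {'w': jnp.zeros(3)}
opt_state = optimizer.init(params)
grads = {'w': jnp.ones(3)}  # stand-in gradients for illustration

for i in range(3):
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  lr = opt_state.hyperparams['learning_rate']  # the rate used for this update
  print(f'step {i}, learning rate: {float(lr):.6f}')
```

Note that `opt_state.hyperparams` also holds the other numeric hyperparameters of the wrapped optimizer (for `adamw`, values such as `b1`, `b2` and `weight_decay`), which is what the `Available hyperparams:` line above prints.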
For further discussion, see #206.