Best practices for checkpointing under grad? #16804
Unanswered
paulbricman asked this question in Q&A
Replies: 1 comment
- Wrap your
Thanks for the awesome module.
I'm working on a project involving meta-learning, where I have an inner optimization loop and an outer one. In the outer one, I'm essentially evaluating
jax.value_and_grad(inner_loop)
and updating the outer-loop parameters based on the resulting gradients. However, I'd like to checkpoint the parameters being optimized as part of the inner loop. If I try this, I run into the following error:
Unfortunately, while these parameters are located outside the grad computation of the inner loop, they are located inside the grad computation of the outer one. In the docs, I learn that "Like pure_callback, io_callback fails under automatic differentiation if it is passed a differentiated variable."
In this context, what is the best way to do an io_callback from inside the grad of the outer loop while still passing a differentiated variable as an argument (for checkpointing)?
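A minimal sketch of the kind of setup I mean (names like inner_loop, save_ckpt, and the toy quadratic loss are just placeholders, not my real code): the inner parameters depend on the outer ones, so by the time io_callback sees them inside jax.value_and_grad they are differentiated variables, which is what triggers the failure.

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback


def save_ckpt(inner_params):
    # Host-side side effect; a real version would write to disk.
    print("checkpoint:", inner_params)


def inner_loop(outer_params):
    # Toy inner optimization: a few SGD steps on a quadratic loss,
    # with the inner parameters initialised from the outer ones.
    inner_params = outer_params
    for _ in range(3):
        grads = jax.grad(lambda p: jnp.sum(p ** 2))(inner_params)
        inner_params = inner_params - 0.1 * grads
        # inner_params is differentiated w.r.t. outer_params here, so this
        # io_callback fails once the outer value_and_grad is taken.
        io_callback(save_ckpt, None, inner_params)
    return jnp.sum(inner_params ** 2)


outer_loss, outer_grads = jax.value_and_grad(inner_loop)(jnp.ones(4))
```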
Potential options:
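One option I can think of (sketched below as a guess at a general workaround, not something confirmed in this thread): since checkpointing only needs the parameter values and never their gradients, hand the callback a detached copy via jax.lax.stop_gradient, so io_callback never receives a differentiated variable.

```python
import jax
from jax import lax
from jax.experimental import io_callback


def save_ckpt(params):
    # Placeholder writer; a real version would serialise to disk.
    print("checkpoint:", params)


def checkpoint(params):
    # stop_gradient cuts the AD connection without changing the values, so
    # io_callback no longer sees a differentiated variable.
    io_callback(save_ckpt, None, lax.stop_gradient(params))
```

Swapping the direct io_callback call in the sketch above for checkpoint(inner_params) should then let the outer jax.value_and_grad trace through, because no gradients ever need to flow into the checkpoint.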