Efficient computation of per-example model parameter gradients. #8264
Unanswered · JulienSiems asked this question in Q&A · Replies: 0 comments
Hey everyone!
I would like to obtain the individual parameter gradients for each example in a mini-batch in JAX.
In PyTorch this is only possible with hacks/packages such as autograd-hacks or backpack.
I saw the example in the docs explaining how to compute the gradient with respect to each batch element, but I noticed that in order to obtain these gradients the example applies the loss function per example (roughly the vmap-over-grad pattern sketched below). I was wondering whether there is a way to get the per-example parameter gradients directly via reverse-mode autodiff on the averaged loss function instead.
That would make it possible to obtain the per-example gradients during, e.g., a normal training routine without having to recompute the backward pass.
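For context, the per-example pattern from the docs looks roughly like the following minimal sketch (the toy linear loss, parameter names, and shapes are placeholders I made up for illustration):

```python
import jax
import jax.numpy as jnp

# Toy linear model and loss; names and shapes are placeholders for illustration.
def per_example_loss(params, x, y):
    pred = jnp.dot(x, params["w"]) + params["b"]
    return (pred - y) ** 2

def batch_loss(params, xs, ys):
    # The usual averaged loss used in a normal training step.
    return jnp.mean(jax.vmap(per_example_loss, in_axes=(None, 0, 0))(params, xs, ys))

params = {"w": jnp.ones(3), "b": jnp.zeros(())}
xs = jnp.ones((8, 3))  # mini-batch of 8 examples
ys = jnp.zeros(8)

# Per-example parameter gradients: vmap the single-example gradient over the
# batch axis of (x, y). Each leaf of the result gets a leading axis of size 8.
per_example_grads = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0, 0))(params, xs, ys)

# Averaging the per-example gradients recovers the usual batch gradient.
avg_grad = jax.grad(batch_loss)(params, xs, ys)
```

This works, but it runs its own (vectorized) backward pass on the per-example losses, rather than extracting the per-example gradients from the backward pass of the averaged loss that the training step already computes.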
Thanks a lot!
Julien