Efficiently differentiate through lax.scatter ops? #13051
stefanozampini asked this question in Q&A (unanswered, 0 replies)
Consider this code, which computes the gradient with respect to a subset of the parameters. If we inspect the output, we can see that the gradient computation of `subd` actually skips some computations. Great! However, when I use the index formulation, the full gradient is computed in between the scatter/gather ops. Is there a way to achieve the same goal (i.e., eliminate useless computations in the gradients of subfunctions) using indices?
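Below is a minimal sketch of the two formulations being compared. It rests on several assumptions: the objective `f`, the array `x`, and the subset (entries 2 to 4) are hypothetical stand-ins and are not the code from the question. The slice version writes the subset into `x` with `jax.lax.dynamic_update_slice`. The index version uses `x.at[idx].set(...)`, which lowers to a scatter. Printing the jaxpr of each gradient, plus the optimized HLO obtained from `jax.jit(...).lower(...).compile().as_text()`, is one way to check which computations actually remain in each gradient program.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical objective (not from the question): elementwise work over
# every entry of x, followed by a reduction.
def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.arange(8.0)

# Hypothetical subset of parameters: entries 2, 3, 4 of x.
start, size = 2, 3
idx = jnp.arange(start, start + size)

# Slice formulation: write `sub` into a contiguous block of x.
def f_slice(sub, x):
    return f(lax.dynamic_update_slice(x, sub, (start,)))

# Index formulation: write `sub` into the entries selected by an index
# array; this goes through a scatter, and its gradient w.r.t. `sub`
# typically involves a gather of the cotangent.
def f_index(sub, x):
    return f(x.at[idx].set(sub))

sub0 = x[start:start + size]
g_slice = jax.grad(f_slice)(sub0, x)
g_index = jax.grad(f_index)(sub0, x)

# Unoptimized programs: compare which primitives appear and on which shapes.
print(jax.make_jaxpr(jax.grad(f_slice))(sub0, x))
print(jax.make_jaxpr(jax.grad(f_index))(sub0, x))

# Optimized programs after XLA compilation (ahead-of-time lowering API).
print(jax.jit(jax.grad(f_slice)).lower(sub0, x).compile().as_text())
print(jax.jit(jax.grad(f_index)).lower(sub0, x).compile().as_text())
```

Both gradients should agree numerically (here each equals `sin(2 * sub)`). Only the generated programs differ, so any difference in skipped work would have to be read from the printed jaxprs or HLO.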