- You might find this response helpful. You'll likely want to pass in a tuple containing information about the gradients you wish to compute as a non-differentiable argument.
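A minimal sketch of this suggestion, assuming `jax.custom_vjp`: here `needed` is a hypothetical static tuple of booleans, passed as a non-differentiable argument, that tells the backward rule which cotangents to build. All names (`needed`, `f_fwd`, `f_bwd`) are illustrative, not from the thread.

```python
import jax
import jax.numpy as jnp
from functools import partial

# `needed` is a static (non-differentiable) tuple of booleans saying which
# cotangents the backward rule should actually construct.
@partial(jax.custom_vjp, nondiff_argnums=(0,))
def f(needed, x, y):
    return jnp.sin(x) * y

def f_fwd(needed, x, y):
    return f(needed, x, y), (x, y)

def f_bwd(needed, res, g):
    x, y = res
    # Skip building a cotangent that was not requested; because `needed` is
    # static, the untaken branch is never traced.
    gx = g * jnp.cos(x) * y if needed[0] else jnp.zeros_like(x)
    gy = g * jnp.sin(x) if needed[1] else jnp.zeros_like(y)
    return (gx, gy)

f.defvjp(f_fwd, f_bwd)

# Ask only for d/dx; the y-cotangent branch is skipped at trace time.
dx = jax.grad(f, argnums=1)((True, False), 2.0, 3.0)
```

Because `needed` is in `nondiff_argnums`, the `if` branches are resolved when the backward pass is traced, not at run time.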
- One thing to keep in mind is that the XLA compiler does automatic dead code elimination, so as long as you wrap your computation in
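The reply is cut off here. Assuming it refers to wrapping the computation in `jax.jit` (an assumption on my part), the point would look like this: under `jit`, XLA can eliminate the part of the backward pass whose output is never consumed, so requesting only one gradient should not pay for the other cotangent at run time.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sin(x) * y

# Under jit, XLA's dead code elimination can drop the portion of the
# backward pass that produces the unused y-cotangent.
@jax.jit
def grad_x(x, y):
    return jax.grad(f, argnums=0)(x, y)

dx = grad_x(2.0, 3.0)
```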
Dear developers,
I have a function which takes two arguments, e.g.,
But if I only need the derivative with respect to one of the arguments, is there a way to tell the backward pass not to compute the VJP for the other argument? I run into this in implicit differentiation, where one of the arguments may be closed over, yet both VJPs are computed, which introduces a lot of overhead. It would be nice to know which VJP is needed before the backward pass, e.g., after the forward pass I get a `flag`, and then maybe I can define the backward pass like the following.
Thanks
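The snippet the question referred to did not survive extraction. A sketch of the pattern being described, assuming `jax.custom_vjp`: `theta` plays the role of the closed-over argument, and marking it non-differentiable means the backward rule returns a cotangent for `x` only, so the VJP with respect to `theta` is never built. All names here are illustrative.

```python
import jax
import jax.numpy as jnp
from functools import partial

# `theta` is the closed-over argument; declaring it non-differentiable means
# the backward rule never constructs its VJP.
@partial(jax.custom_vjp, nondiff_argnums=(0,))
def solve(theta, x):
    # stand-in for the real forward computation (e.g. an implicit solve)
    return jnp.tanh(theta * x)

def solve_fwd(theta, x):
    y = solve(theta, x)
    return y, y  # residual: just the primal output here

def solve_bwd(theta, y, g):
    # d tanh(theta * x) / dx = theta * (1 - tanh(theta * x)**2)
    # Only the x-cotangent is returned; no theta VJP is ever computed.
    return (g * theta * (1.0 - y**2),)

solve.defvjp(solve_fwd, solve_bwd)

dx = jax.grad(solve, argnums=1)(0.5, 1.0)
```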