Replies: 1 comment
-
Hi, thanks for the question. The short answer is that there isn't really an easy way to do this; see the discussion at #1563 for some background. You might be able to make some progress using something like https://github.com/mfschubert/sparsejac, but I'm not aware of any concrete examples.
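Here is a rough illustration of the sparsity idea that coloring-based tools such as sparsejac build on. It is a minimal sketch and does not use sparsejac's API. If you know that output element i depends only on input element i, the Jacobian is block diagonal. In that case, one forward pass via `jax.linearize` plus D1 Jacobian-vector products recovers all N diagonal blocks, where each product seeds the same input component in every set element. The function `g`, the weights `W`, and the sizes `D1`, `D2` are hypothetical.

```python
import jax
import jax.numpy as jnp

D1, D2 = 4, 3
W = jax.random.normal(jax.random.PRNGKey(0), (D1, D2))

def g(x):
    # Hypothetical set function with NO coupling across the set:
    # output element i depends only on input element i, so the
    # Jacobian is block diagonal.
    return jnp.tanh(x @ W)  # (N, D1) -> (N, D2)

def diag_blocks_compressed(fn, x):
    """Blocks d fn(x)[i] / d x[i], valid ONLY if fn has a block-diagonal
    Jacobian (no interaction between set elements)."""
    n, d1 = x.shape
    # One forward pass; fn_lin evaluates the JVP at x as a linear map.
    _, fn_lin = jax.linearize(fn, x)
    # Seed d sets component d of every set element to 1 at the same time.
    seeds = jnp.broadcast_to(jnp.eye(d1, dtype=x.dtype)[:, None, :], (d1, n, d1))
    # cols[d, i] = sum_j J[i, :, j, d], which equals J[i, :, i, d]
    # only when the off-diagonal blocks are zero.
    cols = jax.vmap(fn_lin)(seeds)  # (D1, N, D2)
    return jnp.transpose(cols, (1, 2, 0))  # (N, D2, D1)
```

With `x` of shape (N, D1), `diag_blocks_compressed(g, x)` returns an (N, D2, D1) array using D1 linearized evaluations regardless of N. For a function that mixes information across the set, the same call silently returns sums over all input elements j rather than only j == i. That limitation is why the general case is hard.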
-
Is there a way to get the diagonal of the Jacobian matrix in JAX without performing multiple forward passes or computing the full Jacobian?
More specifically, I have a function with input shape B x N x D1 and output shape B x N x D2. I can of course `vmap` over the batch dimension B, but not over the set dimension N.
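To pin down what "diagonal" means here: for a single sample, the full Jacobian of an (N, D1) -> (N, D2) function has shape (N, D2, N, D1). The quantity asked for is the N blocks `J[i, :, i, :]`, each of shape (D2, D1). The brute-force reference below computes the full Jacobian and slices out those blocks. It is exactly what the question wants to avoid, but it is useful for checking any faster method. This is a sketch that assumes a hypothetical set function `f`, which couples elements through a mean over the set; `W`, `D1`, and `D2` are made up for illustration.

```python
import jax
import jax.numpy as jnp

D1, D2 = 4, 3
W = jax.random.normal(jax.random.PRNGKey(0), (D1, D2))

def f(x):
    # Hypothetical per-sample set function: (N, D1) -> (N, D2).
    # The mean over axis 0 couples all set elements, so vmap over N
    # alone cannot reproduce it.
    h = jnp.tanh(x @ W)
    return h + h.mean(axis=0)

def diag_blocks_full(x):
    # Full Jacobian, shape (N, D2, N, D1).
    J = jax.jacfwd(f)(x)
    idx = jnp.arange(x.shape[0])
    # Advanced indexing on axes 0 and 2 keeps the blocks where the
    # output set index equals the input set index: shape (N, D2, D1).
    return J[idx, :, idx, :]
```

For a batch `xb` of shape (B, N, D1), `jax.vmap(diag_blocks_full)(xb)` gives an array of shape (B, N, D2, D1).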
I want to get the derivative of the i-th element of the output set with respect to the i-th element of the input set. The only way I could think of is to unroll the input and pass the N x D1 elements as separate arguments to an auxiliary function that returns the i-th element of the output. I could then take the Jacobian of that function with respect to the corresponding argument. However, this would require multiple forward passes. Is there any way to avoid this?
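The auxiliary-function approach can at least be written as a single vectorized call instead of a Python loop. This minimal sketch again assumes a hypothetical coupled set function `f`. For each i, it replaces element i of the input with a free argument, takes `jax.jacfwd` of output element i, and uses `jax.vmap` over i. It still does the work of N forward evaluations, just batched together, so it does not remove the cost the question is about. It only avoids the explicit loop and keeps just the diagonal blocks in the result.

```python
import jax
import jax.numpy as jnp

D1, D2 = 4, 3
W = jax.random.normal(jax.random.PRNGKey(0), (D1, D2))

def f(x):
    # Hypothetical set function: (N, D1) -> (N, D2); the mean couples elements.
    h = jnp.tanh(x @ W)
    return h + h.mean(axis=0)

def diag_block(i, x):
    # Jacobian of output element i with respect to input element i only,
    # holding the other N - 1 input elements fixed.
    def f_i(xi):
        return f(x.at[i].set(xi))[i]  # (D2,)
    return jax.jacfwd(f_i)(x[i])  # (D2, D1)

def diag_blocks(x):
    # vmap over the set index i; x itself is shared (in_axes=None).
    idx = jnp.arange(x.shape[0])
    return jax.vmap(diag_block, in_axes=(0, None))(idx, x)  # (N, D2, D1)
```

The batch dimension is handled the same way as before: `jax.vmap(diag_blocks)(xb)` for `xb` of shape (B, N, D1).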