Hi everyone, I am trying to get a better grasp of automatic differentiation. I was looking at the implicit differentiation section of the official docs. The example given is clear and informative, but the code that goes with it is a bit hard for me to digest, so I wrote alternative code for the same example.
It gives the same result for the first- and second-order gradients.
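For context, here is a minimal sketch of the kind of setup I am referring to, modeled on the docs' fixed-point example (this is not my exact alternate code, and `newton_sqrt` is just an illustrative fixed-point map):

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax, vjp


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_guess):
    # Iterate x <- f(a, x) until the iterates stop changing.
    def cond_fun(carry):
        x_prev, x = carry
        return jnp.abs(x_prev - x) > 1e-6

    def body_fun(carry):
        _, x = carry
        return x, f(a, x)

    _, x_star = lax.while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
    return x_star


def fixed_point_fwd(f, a, x_guess):
    x_star = fixed_point(f, a, x_guess)
    return x_star, (a, x_star)


def rev_iter(f, packed, u):
    # One step of the adjoint fixed-point iteration: u <- x_star_bar + (df/dx)^T u
    a, x_star, x_star_bar = packed
    _, vjp_x = vjp(lambda x: f(a, x), x_star)
    return x_star_bar + vjp_x(u)[0]


def fixed_point_rev(f, res, x_star_bar):
    a, x_star = res
    # Solve the adjoint fixed point, then pull the result back onto a.
    u_star = fixed_point(partial(rev_iter, f), (a, x_star, x_star_bar), x_star_bar)
    _, vjp_a = vjp(lambda a: f(a, x_star), a)
    a_bar, = vjp_a(u_star)
    return a_bar, jnp.zeros_like(x_star)


fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)


def newton_sqrt(a):
    # Newton's iteration for sqrt(a); its fixed point is x = sqrt(a).
    update = lambda a, x: 0.5 * (x + a / x)
    return fixed_point(update, a, a)


print(jax.grad(newton_sqrt)(2.0))            # ~ 0.5 / sqrt(2)   = 0.3535...
print(jax.grad(jax.grad(newton_sqrt))(2.0))  # ~ -0.25 * 2**-1.5 = -0.0883...
```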
I have three questions regarding this:
Hi,
This may not completely help you, but here is my thought.
I believe the code is correct; my justification follows.
To understand what `x_star_bar` is, we may need to take a detour. We begin with $x = f(a, x)$ (after solving the fixed-point equation). We get the output and denote it $F := f(a, x)$. This output will be used to compute a loss, let's say $\mathcal{L}$. Now we perform the backward pass from the loss to $F$, and we obtain $\frac{\partial \mathcal{L}}{\partial F}$. Mathematically, `x_star_bar` is $\frac{\partial \mathcal{L}}{\partial F}$, and it serves as the initial value for solving the fixed point of the backward pass (using the custom VJP instead of JAX autodiff).
If you find this explanation is s…
I hope this helps clarify the concept for you.
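To make this concrete, here is a tiny sketch (the names `identity_like` and `loss` are hypothetical stand-ins, not from the docs example) showing that the cotangent handed to a `custom_vjp` backward rule is exactly $\frac{\partial \mathcal{L}}{\partial F}$:

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def identity_like(x):
    # Stand-in for the solver output F = f(a, x_star); it just returns its input.
    return x


def identity_like_fwd(x):
    return identity_like(x), None


def identity_like_rev(res, x_star_bar):
    # x_star_bar is the cotangent dL/dF flowing back from the loss.
    # Passing it through unchanged means grad(loss) below reports dL/dF itself.
    return (x_star_bar,)


identity_like.defvjp(identity_like_fwd, identity_like_rev)


def loss(x):
    F = identity_like(x)
    return jnp.sum(F ** 2)  # L = sum(F^2), hence dL/dF = 2 * F


x = jnp.array([1.0, 2.0, 3.0])
print(jax.grad(loss)(x))  # [2. 4. 6.] == 2 * x, i.e. dL/dF evaluated at F = x
```

In the docs' fixed-point rule, this same `x_star_bar` is not returned directly; it is fed in as the starting value of the adjoint fixed-point iteration, whose solution is then pulled back onto `a`.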