
Hi,

This may not fully answer your question, but here are my thoughts.

  1. I believe the code is correct (my justification follows).

  2. To understand what x_star_bar is, we need to take a small detour. We begin with $x^* = f(a, x^*)$ (after solving the fixed-point equation). We then take the output and denote it $F := f(a, x^*)$; at the fixed point, $F = x^*$. This output is used to compute a loss, say $\mathcal{L}$. Running the backward pass from the loss to $F$ gives $\frac{\partial \mathcal{L}}{\partial F}$. Mathematically, x_star_bar is exactly this cotangent $\frac{\partial \mathcal{L}}{\partial F}$, and it serves as the initial value when solving the fixed-point equation of the backward pass (which is done through the custom VJP instead of letting JAX autodiff through the forward iterations).
    If you find this explanation is s…
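The code being discussed is not reproduced in this thread, so the sketch below is modeled on the `fixed_point` example from the JAX "Custom derivative rules" tutorial; treat that as an assumption about the code in question. The helper names `fixed_point`, `rev_iter`, and `newton_sqrt` belong to this sketch, not to the thread. It shows where x_star_bar enters: differentiating $x^* = f(a, x^*)$ gives $\bar a = u\,\partial_a f(a, x^*)$ with $u = \bar{x}^* (I - \partial_x f(a, x^*))^{-1}$, and the backward pass finds $u$ by iterating $u \leftarrow \bar{x}^* + u\,\partial_x f(a, x^*)$, starting from $u = \bar{x}^*$ (x_star_bar).

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_guess):
    """Iterate x <- f(a, x) until successive iterates differ by at most 1e-6."""
    def cond_fun(carry):
        x_prev, x = carry
        return jnp.abs(x_prev - x) > 1e-6

    def body_fun(carry):
        _, x = carry
        return x, f(a, x)

    _, x_star = jax.lax.while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
    return x_star


def fixed_point_fwd(f, a, x_guess):
    x_star = fixed_point(f, a, x_guess)
    # Save only what the backward pass needs: the parameter and the solution.
    return x_star, (a, x_star)


def rev_iter(f, packed, u):
    # One backward step: u <- x_star_bar + u * df/dx, evaluated at x_star.
    a, x_star, x_star_bar = packed
    _, vjp_x = jax.vjp(partial(f, a), x_star)
    return x_star_bar + vjp_x(u)[0]


def fixed_point_rev(f, res, x_star_bar):
    # x_star_bar is the incoming cotangent dL/dF, i.e. dL/dx_star at the fixed point.
    a, x_star = res
    _, vjp_a = jax.vjp(lambda a: f(a, x_star), a)
    # Solve the backward fixed-point equation, starting the iteration at x_star_bar.
    u = fixed_point(partial(rev_iter, f), (a, x_star, x_star_bar), x_star_bar)
    (a_bar,) = vjp_a(u)
    # The initial guess does not change the solution, so its cotangent is zero.
    return a_bar, jnp.zeros_like(x_star)


fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)


def newton_sqrt(a):
    # Newton's update for sqrt(a); its fixed point is x* = sqrt(a).
    update = lambda a, x: 0.5 * (x + a / x)
    return fixed_point(update, a, a)


print(newton_sqrt(2.0))            # approx. 1.41421
print(jax.grad(newton_sqrt)(2.0))  # approx. 0.35355 (= 0.5 / sqrt(2))
```

For the Newton update, $\partial_x f$ vanishes at the fixed point, so the backward loop returns x_star_bar almost immediately. For maps where it does not vanish, the backward iteration converges only if the spectral radius of $\partial_x f(a, x^*)$ is below 1, essentially the same local contraction condition the forward loop relies on.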

Replies: 1 comment 1 reply

1 reply
@SNMS95

Answer selected by SNMS95
Category
Q&A
2 participants