With the (somewhat) recent changes to how JAX handles custom VJPs, a custom derivative rule can now call the very function whose derivative it defines. Because the VJP of a fixed-point solve can itself be computed by solving for another fixed point, a recursive implementation, in which the backward pass calls the same fixed-point solver, would allow higher-order differentiation with little to no extra complexity.
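To see why the backward pass is another fixed point, suppose $x^* = f(a, x^*)$. Differentiating both sides and applying a cotangent $\bar{x}$ gives $\bar{a} = (\partial_a f)^\top u$, where $u = \bar{x} + (\partial_x f)^\top u$. In other words, $u$ is the fixed point of the map $u \mapsto \bar{x} + (\partial_x f)^\top u$, assuming that map converges near $x^*$. The sketch below closely mirrors the fixed-point example in JAX's custom derivatives tutorial. The names `fixed_point`, `rev_iter` and `newton_sqrt` are illustrative, and the `1e-6` stopping tolerance is an assumption. The key point is that `fixed_point_rev` calls `fixed_point` itself, so the backward pass is differentiable by the same rule.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

# Illustrative names: fixed_point, rev_iter and newton_sqrt are hypothetical helpers.
TOL = 1e-6  # assumed convergence tolerance


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_guess):
    """Iterate x <- f(a, x) until successive iterates differ by at most TOL."""
    def cond_fun(carry):
        x_prev, x = carry
        return jnp.max(jnp.abs(x_prev - x)) > TOL

    def body_fun(carry):
        _, x = carry
        return x, f(a, x)

    _, x_star = lax.while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
    return x_star


def fixed_point_fwd(f, a, x_guess):
    x_star = fixed_point(f, a, x_guess)
    # Save a and the solution; the backward pass only needs the converged point.
    return x_star, (a, x_star)


def rev_iter(f, packed, u):
    # One step of u <- x_star_bar + (df/dx)^T u, evaluated at the solution x_star.
    a, x_star, x_star_bar = packed
    _, vjp_x = jax.vjp(partial(f, a), x_star)
    return x_star_bar + vjp_x(u)[0]


def fixed_point_rev(f, res, x_star_bar):
    a, x_star = res
    # Solve for u with a recursive call to fixed_point. Because fixed_point has
    # this same custom VJP, the backward pass can be differentiated again.
    u = fixed_point(partial(rev_iter, f), (a, x_star, x_star_bar), x_star_bar)
    # Pull u back through df/da to obtain the cotangent for a.
    _, vjp_a = jax.vjp(lambda a_: f(a_, x_star), a)
    (a_bar,) = vjp_a(u)
    # The initial guess does not affect the converged result: zero cotangent.
    return a_bar, jnp.zeros_like(x_star)


fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)


def newton_sqrt(a):
    # Newton's update for x**2 = a; its fixed point is sqrt(a).
    update = lambda a, x: 0.5 * (x + a / x)
    return fixed_point(update, a, a)


print(newton_sqrt(2.0))
print(jax.grad(newton_sqrt)(2.0))
print(jax.grad(jax.grad(newton_sqrt))(2.0))
```

The three printed values should be close to $\sqrt{2}$, $\tfrac{1}{2\sqrt{2}} \approx 0.3536$ and $-\tfrac{1}{4}\,2^{-3/2} \approx -0.0884$. No second-order rule was written by hand: the second derivative comes from differentiating `fixed_point_rev`, which in turn reuses the custom VJP of `fixed_point`.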