Hi, I'm having a hard time understanding part of the syntax of the following custom_vjp example:

    import jax.numpy as jnp
    from jax import custom_vjp

    @custom_vjp
    def f(x, y):
        return jnp.sin(x) * y

    def f_fwd(x, y):
        # Returns primal output and residuals to be used in backward pass by f_bwd.
        return f(x, y), (jnp.cos(x), jnp.sin(x), y)

    def f_bwd(res, g):
        cos_x, sin_x, y = res  # Gets residuals computed in f_fwd
        return (cos_x * g * y, sin_x * g)

    f.defvjp(f_fwd, f_bwd)

For the forward and backward functions, what are the "residuals" and how are they defined? I see that they are equal to ...
Replies: 1 comment
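From my understanding, the residuals are simply any values that you want to cache during your forward pass.
These cached values are used in the backward pass to save computational time (although you do lose memory by storing these results!). In your example, f_fwd returns them as the second element of its output, (jnp.cos(x), jnp.sin(x), y), and f_bwd receives that same tuple as its first argument, res.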
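To make the trade-off concrete, here is a rough sketch (not taken from the JAX docs) with two variants of the function from the question. The names `f_cached` and `f_recompute` are made up for illustration. The first stores `cos(x)` and `sin(x)` as residuals, as in the question; the second stores only the inputs and recomputes the trig values in the backward pass. As far as I can tell, both should give the same gradients, and they differ only in what is kept in memory between the two passes.

```python
import jax
import jax.numpy as jnp
from jax import custom_vjp


# Variant 1: cache the trig values in the forward pass (more memory, less recompute).
@custom_vjp
def f_cached(x, y):
    return jnp.sin(x) * y

def f_cached_fwd(x, y):
    # Second return value is the residuals: anything the backward pass needs.
    return f_cached(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_cached_bwd(res, g):
    cos_x, sin_x, y = res  # unpack exactly what f_cached_fwd stored
    return (cos_x * g * y, sin_x * g)

f_cached.defvjp(f_cached_fwd, f_cached_bwd)


# Variant 2: store only the inputs and recompute in the backward pass.
@custom_vjp
def f_recompute(x, y):
    return jnp.sin(x) * y

def f_recompute_fwd(x, y):
    return f_recompute(x, y), (x, y)

def f_recompute_bwd(res, g):
    x, y = res
    return (jnp.cos(x) * g * y, jnp.sin(x) * g)

f_recompute.defvjp(f_recompute_fwd, f_recompute_bwd)


x, y = 1.5, 2.0
grads_cached = jax.grad(f_cached, argnums=(0, 1))(x, y)
grads_recompute = jax.grad(f_recompute, argnums=(0, 1))(x, y)
# Reference: let JAX differentiate the plain expression automatically.
grads_reference = jax.grad(lambda x, y: jnp.sin(x) * y, argnums=(0, 1))(x, y)
```

So the residuals are not a fixed quantity defined by JAX; they are whatever the forward function chooses to hand over, and the backward function has to unpack them in the same structure.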