Is there a way to calculate cotangents with respect to particular outputs? #12267
Suppose you have some opaque, expensive function:

```python
def f(v, w) -> tuple[Scalar, Scalar]:
    return loss_a, loss_b
```

Given some inputs `v` and `w`, you could take gradients with respect to each output separately:

```python
def f_a(v_variable):
    return f(v_variable, w)[0]

def f_b(w_variable):
    return f(v, w_variable)[1]

v_bar = grad(f_a)(v)
w_bar = grad(f_b)(w)
```

(This is just typed by intuition, so apologies if I've made an error.) Is it possible to only call `f` once, rather than re-running the expensive forward pass for each `grad`?
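To make the redundancy concrete, here is a runnable sketch using a hypothetical cheap `f` as a stand-in for the opaque expensive one (not from the question itself); the call counter just makes the repeated forward passes visible:

```python
import jax
import jax.numpy as jnp

calls = {"n": 0}

# Hypothetical cheap stand-in for the opaque, expensive f in the question;
# the counter records how many times the forward pass runs.
def f(v, w):
    calls["n"] += 1
    loss_a = jnp.sum(v ** 2) * jnp.sum(w)
    loss_b = jnp.sum(v) * jnp.sum(w ** 2)
    return loss_a, loss_b

v = jnp.array([1.0, 2.0])
w = jnp.array([3.0, 4.0])

# Each grad call traces f separately, so the forward work runs twice.
v_bar = jax.grad(lambda v_: f(v_, w)[0])(v)
w_bar = jax.grad(lambda w_: f(v, w_)[1])(w)
```

After this runs, `calls["n"]` is 2: the expensive forward computation was repeated once per `grad` call.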
Replies: 1 comment 2 replies
Good question, as always!

One way to reduce redundant work here (without relying on `jax.jit` to do common-subexpression elimination) is to use `jax.vjp` directly (around which `grad` is a thin wrapper):

```python
(loss_a, loss_b), f_vjp = jax.vjp(f, v, w)
v_bar, _ = f_vjp((jnp.ones_like(loss_a), jnp.zeros_like(loss_b)))
_, w_bar = f_vjp((jnp.zeros_like(loss_a), jnp.ones_like(loss_b)))
```

That will only run the forward pass once (on the line which calls `jax.vjp`) and then run two separate backward passes.

Another variant would be just to call `jax.jacrev(f, argnums=(0, 1))(v, w)`, then extract the entries of the result that you want (corresponding to v-input-a-output and w-input-b-output). That'll also run the forward pass just once.

WDYT?
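Both variants can be sketched end-to-end. The toy `f` below is a hypothetical stand-in for the opaque function in the question; everything else uses the public `jax.vjp` and `jax.jacrev` APIs as described above:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the opaque, expensive f in the question.
def f(v, w):
    return jnp.sum(v ** 2) * jnp.sum(w), jnp.sum(v) * jnp.sum(w ** 2)

v = jnp.array([1.0, 2.0])
w = jnp.array([3.0, 4.0])

# Variant 1: one forward pass via jax.vjp, then two backward passes.
# f_vjp takes a single cotangent argument matching f's output structure,
# so the per-output cotangents are packed into a tuple.
(loss_a, loss_b), f_vjp = jax.vjp(f, v, w)
v_bar, _ = f_vjp((jnp.ones_like(loss_a), jnp.zeros_like(loss_b)))
_, w_bar = f_vjp((jnp.zeros_like(loss_a), jnp.ones_like(loss_b)))

# Variant 2: full Jacobian via jax.jacrev, then slice out the blocks we want.
# jac[i][j] holds d(output i)/d(argument j).
jac = jax.jacrev(f, argnums=(0, 1))(v, w)
v_bar_2 = jac[0][0]  # d loss_a / d v
w_bar_2 = jac[1][1]  # d loss_b / d w
```

Variant 2 computes the off-diagonal blocks (`d loss_a / d w`, `d loss_b / d v`) as well, so variant 1 does strictly less backward-pass work when you only need the two diagonal blocks.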