Good question, as always!

One way to reduce redundant work here (without relying on jax.jit to do common-subexpression elimination) is to use jax.vjp directly (around which grad is a thin wrapper):

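import jax
import jax.numpy as jnp

# f_vjp takes a single cotangent argument with the same structure as f's output,
# so the two cotangents are passed together as one tuple.
(loss_a, loss_b), f_vjp = jax.vjp(f, v, w)
v_bar, _ = f_vjp((jnp.ones_like(loss_a), jnp.zeros_like(loss_b)))  # gradient of loss_a w.r.t. v
_, w_bar = f_vjp((jnp.zeros_like(loss_a), jnp.ones_like(loss_b)))  # gradient of loss_b w.r.t. w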

That will only run the forward pass once (on the line which calls jax.vjp) and then run two separate backward passes.
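To make the jax.vjp approach concrete, here is a minimal self-contained sketch. The function f and the inputs v and w below are hypothetical stand-ins (the real ones come from the question); the check against two separate jax.grad calls is only there to show that the results agree.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the question's f: two scalar losses that
# share an intermediate computation.
def f(v, w):
    shared = jnp.tanh(v * w)       # work common to both losses
    loss_a = jnp.sum(shared ** 2)
    loss_b = jnp.sum(shared * w)
    return loss_a, loss_b

v = jnp.array([0.1, 0.2, 0.3])
w = jnp.array([1.0, -2.0, 0.5])

# The forward pass runs once, inside jax.vjp.
(loss_a, loss_b), f_vjp = jax.vjp(f, v, w)

# Each call below is one backward pass. The cotangent is a single tuple
# matching the structure of f's output.
v_bar, _ = f_vjp((jnp.ones_like(loss_a), jnp.zeros_like(loss_b)))
_, w_bar = f_vjp((jnp.zeros_like(loss_a), jnp.ones_like(loss_b)))

# Reference: two independent grad calls, each rerunning the forward pass.
v_bar_ref = jax.grad(lambda v, w: f(v, w)[0], argnums=0)(v, w)
w_bar_ref = jax.grad(lambda v, w: f(v, w)[1], argnums=1)(v, w)
```

Here v_bar is the gradient of loss_a with respect to v and w_bar is the gradient of loss_b with respect to w; the discarded entries (the `_` slots) are the gradients with respect to the other input.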

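Another variant would be just to call jax.jacrev(f, argnums=(0, 1))(v, w) (argnums is needed because jacrev differentiates only with respect to the first argument by default), then extract the entries of the result that you want (corresponding to v-input-a-output and w-input-b-output). That'll also run the forward pass just o…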
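A sketch of the jacrev variant, again with a hypothetical f, v, and w standing in for the ones in the question. Note that jax.jacrev differentiates only with respect to the first argument by default, so argnums=(0, 1) is passed here to get entries for both v and w. The result is indexed by output first, then by input.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the question's f (two scalar losses).
def f(v, w):
    shared = jnp.tanh(v * w)
    loss_a = jnp.sum(shared ** 2)
    loss_b = jnp.sum(shared * w)
    return loss_a, loss_b

v = jnp.array([0.1, 0.2, 0.3])
w = jnp.array([1.0, -2.0, 0.5])

# jac[i][j] is the derivative of output i with respect to input j,
# with outputs ordered (loss_a, loss_b) and inputs ordered (v, w).
jac = jax.jacrev(f, argnums=(0, 1))(v, w)

v_bar = jac[0][0]   # d loss_a / d v
w_bar = jac[1][1]   # d loss_b / d w

# Reference values from plain grad calls, for comparison only.
v_bar_ref = jax.grad(lambda v, w: f(v, w)[0], argnums=0)(v, w)
w_bar_ref = jax.grad(lambda v, w: f(v, w)[1], argnums=1)(v, w)
```

This also computes the two cross terms (jac[0][1] and jac[1][0]), which are simply discarded here.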

Answer selected by NeilGirdhar