How do functions with multiple outputs compute gradients? #9743
Hello, JAX team. I have a function with two outputs, and I currently compute the gradient of each output with a separate `vjp` call:

```python
import jax
import jax.numpy as jnp
from jax import value_and_grad, grad, jvp, vjp
from numpy import float32


def fun_multi_outputs(xPhys):
    outarray = jnp.zeros(2)
    output_A = jnp.sum(xPhys)        # first output: sum of the elements
    output_B = jnp.sum(xPhys ** 2)   # second output: sum of the squares
    outarray = outarray.at[0].set(output_A)
    outarray = outarray.at[1].set(output_B)
    print('funA is running')
    return outarray


if __name__ == '__main__':
    case = 'A'
    if case == 'A':
        inputs = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=float32)
        fun_outputs, vjp_fun = vjp(fun_multi_outputs, inputs)
        # Cotangent [1, 0] picks out the gradient of output_A
        v_vjp_A = jnp.array([1, 0], dtype=float32)
        output_A_grad = vjp_fun(v_vjp_A)[0]
        print("output_A's grad is: ", output_A_grad)
        # Cotangent [0, 1] picks out the gradient of output_B
        v_vjp_B = jnp.array([0, 1], dtype=float32)
        output_B_grad = vjp_fun(v_vjp_B)[0]
        print("output_B's grad is: ", output_B_grad)
```

Is there a more direct and efficient way to implement this?
You can use `jax.jacrev` and `jax.jacfwd`.
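A minimal sketch of what that looks like for the function in the question. It assumes an equivalent rewrite of `fun_multi_outputs` that builds the output with `jnp.stack` instead of `.at[].set()` (my choice, not from the thread; the original version also works). Both calls return the full Jacobian with shape `(num_outputs,) + inputs.shape`, so each output's gradient is just one slice:

```python
import jax
import jax.numpy as jnp


def fun_multi_outputs(xPhys):
    # Equivalent to the question's function: [sum(x), sum(x**2)]
    return jnp.stack([jnp.sum(xPhys), jnp.sum(xPhys ** 2)])


inputs = jnp.arange(1.0, 10.0).reshape(3, 3)  # [[1,2,3],[4,5,6],[7,8,9]]

# Reverse mode: full Jacobian in one call
jac_rev = jax.jacrev(fun_multi_outputs)(inputs)
# Forward mode: same Jacobian, computed with JVPs instead of VJPs
jac_fwd = jax.jacfwd(fun_multi_outputs)(inputs)

print(jac_rev.shape)          # (2, 3, 3)
output_A_grad = jac_rev[0]    # gradient of sum(x): all ones
output_B_grad = jac_rev[1]    # gradient of sum(x**2): 2 * inputs
print("output_A's grad is: ", output_A_grad)
print("output_B's grad is: ", output_B_grad)
```

Here the function maps 9 inputs to 2 outputs (a "wide" Jacobian), which is the case where reverse mode (`jacrev`) is usually the cheaper of the two; `jacfwd` tends to win when there are more outputs than inputs.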
`jax.jacrev` can be considered a vmapped `vjp`: it does basically the same thing as your code, but for all outputs in parallel. `jax.jacfwd` can be considered a vmapped `jvp`.
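To make the "vmapped `vjp` / vmapped `jvp`" description concrete, here is a hand-rolled sketch that builds the same Jacobian with `jax.vmap`. It is only an illustration of the idea, not JAX's actual implementation (which also handles pytrees, holomorphic functions, and so on). The reverse-mode part vmaps the question's `vjp_fun` over the cotangents `[1, 0]` and `[0, 1]` at once; the forward-mode part vmaps `jax.jvp` over one basis tangent per input element:

```python
import jax
import jax.numpy as jnp


def fun_multi_outputs(xPhys):
    return jnp.stack([jnp.sum(xPhys), jnp.sum(xPhys ** 2)])


inputs = jnp.arange(1.0, 10.0).reshape(3, 3)

# Reverse mode by hand: vmap the pullback over the rows of an identity
# matrix, i.e. over the cotangents [1, 0] and [0, 1] used in the question.
outputs, vjp_fun = jax.vjp(fun_multi_outputs, inputs)
cotangents = jnp.eye(outputs.shape[0])                 # (2, 2)
(jac_rev_manual,) = jax.vmap(vjp_fun)(cotangents)      # (2, 3, 3)

# Forward mode by hand: vmap the JVP over a basis of input tangents,
# one tangent per input element (9 of them for a 3x3 input).
tangents = jnp.eye(inputs.size).reshape((inputs.size,) + inputs.shape)  # (9, 3, 3)


def jvp_out(t):
    # Directional derivative of the outputs along tangent t
    return jax.jvp(fun_multi_outputs, (inputs,), (t,))[1]


jac_fwd_flat = jax.vmap(jvp_out)(tangents)             # (9, 2)
# Put the output axis first and restore the input shape: (2, 3, 3)
jac_fwd_manual = jac_fwd_flat.T.reshape((outputs.shape[0],) + inputs.shape)

print(jnp.allclose(jac_rev_manual, jax.jacrev(fun_multi_outputs)(inputs)))
print(jnp.allclose(jac_fwd_manual, jax.jacfwd(fun_multi_outputs)(inputs)))
```

The reverse-mode batch has one entry per output (2 here) and the forward-mode batch has one entry per input element (9 here), which is why the relative cost of the two modes depends on the shape of the Jacobian.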
I recommend reading #47 (comment), as well as the autodiff cookbook: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html.