Independent vector-Jacobian product for multiple arguments #8571
Can anyone help me? Thanks!
Dear @hawkinsp, thanks for your reply! I apologize for my unclear question. Actually, I have three equations, implemented below as `f1`, `f2` and `f3`, and I am trying to get the derivatives of `f1` with respect to `V`, `f2` with respect to `m`, and `f3` with respect to `h`:
from functools import partial
import jax.numpy as jnp
from jax import random
from jax import vjp
# Vector-valued gradients with VJPs
# https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#vector-valued-gradients-with-vjps
def vgrad(f, xs):
    # Seeds every output with ones, so the result is the gradient of the
    # sum of all outputs with respect to each argument.
    ys, vjp_fn = vjp(f, *xs)
    return vjp_fn(tuple([jnp.ones(y.shape) for y in ys]))

ENa = 55.
gNa = 35.

# Trailing comma: each function returns a 1-tuple, which vgrad iterates over.
def f1(V, m, h):
    return - gNa * m ** 3 * h * (V - ENa),

def f2(m, V):
    a = 1 / (1 - jnp.exp(-(V + 40) / 10))
    b = 4 * jnp.exp(-(V + 65) / 18)
    return a * m + b * (1 - m),

def f3(h, V):
    a = 0.07 * jnp.exp(-(V + 65) / 20)
    b = 1 / (1 + jnp.exp(-(V + 35) / 10))
    return a * h + b * (1 - h),

def F(V, m, h):
    f1v = f1(V, m, h)[0]
    f2v = f2(m, V)[0]
    f3v = f3(h, V)[0]
    return f1v, f2v, f3v

When I evaluate the merged function `F`, which returns all three outputs at once:
key = random.PRNGKey(0)
key, k1, k2, k3 = random.split(key, 4)
ex_V = random.uniform(k1, (3,))
ex_m = random.uniform(k2, (3,))
ex_h = random.uniform(k3, (3,))
D = vgrad(F, (ex_V, ex_m, ex_h))
print(D[0])
# [-0.4800583 -9.6292515 -4.302372 ]
print(D[1])
# [ 249.06578 2387.181 911.52954]
print(D[2])
# [ 59.27976 564.16797 904.1445 ]
However, what I want are the gradients of each function with respect to its own variable:
print(vgrad(partial(f1, m=ex_m, h=ex_h), (ex_V,))[0])
# [-0.47693872 -9.626045 -4.301631 ]
print(vgrad(partial(f2, V=ex_V), (ex_m,))[0])
# [0.9116234 0.9108186 0.91096807]
print(vgrad(partial(f3, V=ex_V), (ex_h,))[0])
# [-0.9687328 -0.96815366 -0.9682618 ]
Obviously, the gradients from the separate functions and from the merged function are significantly different. For many reasons I must code these three functions together. Could you give me a solution? Thank you very much!
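One possible approach, sketched under the assumption that every output of `F` is elementwise in its inputs (true for `f1`, `f2` and `f3` here). With all-ones cotangents, the single `vjp_fn` call in `vgrad` returns, for `V`, the sum of the derivatives of `f1`, `f2` and `f3` with respect to `V` (and likewise for `m` and `h`), which is why `D` differs from the separate results. Seeding one output at a time with ones and the others with zeros isolates each function's own contribution while still reusing a single forward pass. The helper `diagonal_vjp` is a hypothetical name; the functions are restated so the block runs on its own, returning arrays directly instead of 1-tuples.

```python
import jax.numpy as jnp
from jax import random, vjp

ENa = 55.
gNa = 35.

# The three functions from above, returning arrays instead of 1-tuples.
def f1(V, m, h):
    return -gNa * m ** 3 * h * (V - ENa)

def f2(m, V):
    a = 1 / (1 - jnp.exp(-(V + 40) / 10))
    b = 4 * jnp.exp(-(V + 65) / 18)
    return a * m + b * (1 - m)

def f3(h, V):
    a = 0.07 * jnp.exp(-(V + 65) / 20)
    b = 1 / (1 + jnp.exp(-(V + 35) / 10))
    return a * h + b * (1 - h)

def F(V, m, h):
    # Output i is paired with argument i: f1 <-> V, f2 <-> m, f3 <-> h.
    return f1(V, m, h), f2(m, V), f3(h, V)

def diagonal_vjp(fun, xs):
    """Hypothetical helper: d(output i)/d(argument i) for each i.

    Assumes fun is elementwise and returns one output per argument.
    """
    ys, vjp_fn = vjp(fun, *xs)  # one forward pass, reused below
    grads = []
    for i in range(len(ys)):
        # Ones for output i, zeros for every other output.
        cotangent = tuple(
            jnp.ones_like(y) if j == i else jnp.zeros_like(y)
            for j, y in enumerate(ys)
        )
        # vjp_fn returns one cotangent per argument; keep argument i only.
        grads.append(vjp_fn(cotangent)[i])
    return tuple(grads)

key = random.PRNGKey(0)
key, k1, k2, k3 = random.split(key, 4)
ex_V = random.uniform(k1, (3,))
ex_m = random.uniform(k2, (3,))
ex_h = random.uniform(k3, (3,))

dV, dm, dh = diagonal_vjp(F, (ex_V, ex_m, ex_h))
print(dV)  # df1/dV
print(dm)  # df2/dm
print(dh)  # df3/dh
```

Mathematically these agree with the three separate `partial` computations. If the outputs were not elementwise, a one-hot cotangent would return the column sums of each Jacobian block rather than its diagonal, and the full blocks (for example from `jax.jacrev(F, argnums=(0, 1, 2))`) would be needed instead.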
In the official documentation, JAX supports `vjp` for multiple arguments. When the return values depend on both arguments, the cotangent returned for each argument combines the contributions of every output.
However, we want to get the derivative for each argument independently. That is to say, the expected derivatives are those of each return value with respect to its own argument only.
Can anyone tell me how to compute these expected gradients? Thanks!
Note that we cannot use `custom_vjp` here because we do not explicitly know the logic of the provided functions.
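Because the function may be a black box, one possible workaround that needs no `custom_vjp` is forward mode. The sketch below is not a dedicated JAX API for this: it calls `jax.jvp` once per argument with a one-hot tangent (ones for that argument, zeros elsewhere) and keeps only the matching output. It assumes the function is elementwise and has as many outputs as arguments; `f` is a hypothetical stand-in for the documentation example, and `diagonal_jvp` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: both outputs depend on both arguments.
def f(x, y):
    return jnp.sin(x) * y, jnp.cos(y) * x

def diagonal_jvp(fun, primals):
    """Hypothetical helper: d(output i)/d(argument i) for each i.

    Assumes fun is elementwise and returns one output per argument,
    so a one-hot tangent isolates a single argument's effect.
    """
    result = []
    for i in range(len(primals)):
        # Tangent of ones for argument i, zeros for all other arguments.
        tangents = tuple(
            jnp.ones_like(p) if j == i else jnp.zeros_like(p)
            for j, p in enumerate(primals)
        )
        _, tangents_out = jax.jvp(fun, primals, tangents)
        result.append(tangents_out[i])  # keep only output i
    return tuple(result)

x = jnp.array([0.5, 1.0, 2.0])
y = jnp.array([1.0, 0.3, -0.7])

d_out0_dx, d_out1_dy = diagonal_jvp(f, (x, y))  # cos(x) * y, -sin(y) * x

# For comparison: one VJP with all-ones cotangents mixes both outputs,
# giving cos(x) * y + cos(y) for x and sin(x) - sin(y) * x for y.
_, f_vjp = jax.vjp(f, x, y)
xbar, ybar = f_vjp((jnp.ones_like(x), jnp.ones_like(y)))
```

Each argument costs one forward-mode pass, so n arguments need n `jvp` calls. Comparing with the all-ones `vjp` shows where the extra terms come from: `xbar` carries an additional `cos(y)` contributed by the second output.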