Question about vjp and grad #12467
Answered by soraros · yiminghwang asked this question in Q&A
Hi there, I am a beginner with JAX. I am following the tutorial https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#vjps-in-jax-code and have a question about `vjp` and `grad`:

```python
import jax
import jax.numpy as jnp

def func(params):
    return params[0] ** 2 + 2 * params[1]

params = jnp.asarray([1., 2.])
g1 = jax.grad(func)(params)

y, f_vjp = jax.vjp(func, params)
# y is a scalar, so the cotangent is a scalar 1.0 with y's dtype;
# f_vjp returns a tuple with one entry per primal input.
(g2,) = f_vjp(jnp.ones_like(y))
```

The outputs of `g1` and `g2` seem to be the same.
Answered by soraros on Sep 22, 2022 (answer selected by yiminghwang):
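`grad` is implemented using `value_and_grad`, which is in turn implemented in terms of `vjp`. For a scalar-valued function like `func`, the gradient is exactly the VJP evaluated at a cotangent of `1.0`, which is why `g1` and `g2` agree.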
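To make that relationship concrete, here is a minimal sketch (not JAX's actual source) of `value_and_grad` and `grad` rebuilt on top of `jax.vjp` for a single argument. The names `value_and_grad_via_vjp` and `grad_via_vjp` are hypothetical; the real `jax.grad` additionally checks that the output is a real scalar and supports options such as `argnums` and `has_aux`, all of which this sketch omits.

```python
import jax
import jax.numpy as jnp


def value_and_grad_via_vjp(f):
    """Hypothetical helper: value_and_grad for a scalar-valued f, built on jax.vjp."""
    def wrapped(x):
        y, f_vjp = jax.vjp(f, x)
        # Scalar output, so pull back a cotangent of 1.0 with the output's dtype.
        (g,) = f_vjp(jnp.ones_like(y))
        return y, g
    return wrapped


def grad_via_vjp(f):
    """Hypothetical helper: grad is value_and_grad with the value discarded."""
    def wrapped(x):
        _, g = value_and_grad_via_vjp(f)(x)
        return g
    return wrapped


def func(params):
    return params[0] ** 2 + 2 * params[1]


params = jnp.asarray([1.0, 2.0])
print(grad_via_vjp(func)(params))  # [2. 2.]
print(jax.grad(func)(params))      # [2. 2.]
```

The layering mirrors the answer: the VJP does the real work, `value_and_grad` fixes the cotangent to `1.0` for a scalar output, and `grad` simply drops the value.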