```python
import jax
import jax.numpy as jnp

def f(y, x):
    return jnp.sum(y) + jnp.sum(x)  # just an example

def g(y, x, u):
    f_x_func = jax.grad(f, argnums=1)  # autodiff with respect to x only
    # Closure over x so the JVP differentiates with respect to y only;
    # tangents must be passed as a tuple, hence (u,).
    f_x, f_yx_dot_u = jax.jvp(lambda _y: f_x_func(_y, x), (y,), (u,))
    return f_x, f_yx_dot_u  # \frac{\partial f(y,x)}{\partial x} and \frac{\partial^2 f(y,x)}{\partial y \partial x} u
```
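As a quick sanity check of this closure trick, one can compare it against the fully materialized mixed Jacobian. This is a sketch under my own assumptions: the test function `f2` and the array shapes below are hypothetical (the toy `f` above has a zero mixed derivative, so the check would be vacuous with it):

```python
def f2(y, x):  # hypothetical test function with a nonzero mixed derivative
    return jnp.sum(y * x ** 2)

y = jnp.arange(1.0, 4.0)  # assumed shapes: y, x, u all length-3
x = jnp.arange(4.0, 7.0)
u = jnp.ones(3)

# Same closure trick as in g, applied to f2.
f2_x_func = jax.grad(f2, argnums=1)
_, f2_yx_dot_u = jax.jvp(lambda _y: f2_x_func(_y, x), (y,), (u,))

# Reference: materialize the full mixed Jacobian d^2 f2 / (dy dx), then contract with u.
f2_yx = jax.jacfwd(f2_x_func, argnums=0)(y, x)  # shape (len(x), len(y))
assert jnp.allclose(f2_yx_dot_u, f2_yx @ u)
```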
```python
def h(y, x, u):
    f_x_func = jax.grad(f, argnums=1)  # autodiff with respect to x only
    # Closure over x so the linearization is with respect to y only.
    f_x, f_yx_func = jax.linearize(lambda _y: f_x_func(_y, x), y)
    return f_yx_func(f_x - u)  # \frac{\partial^2 f(y,x)}{\partial y \partial x} (f_x - u)
```
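Why `jax.linearize` instead of `jax.jvp` in `h`? `jax.linearize` does the forward evaluation once and returns a reusable linear map, which is the better choice when the same point is pushed through many tangent directions. A minimal sketch, continuing the hypothetical `f2`/`y`/`x`/`u` setup above:

```python
f2_x, f2_yx_func = jax.linearize(lambda _y: jax.grad(f2, argnums=1)(_y, x), y)

# Reuse one linearization for several directions without re-linearizing.
for v in (u, 2.0 * u, f2_x - u):
    print(f2_yx_func(v))  # each call evaluates (d^2 f2 / dy dx) @ v
```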
I have a scalar function $f(y, x)$.

I am interested in the following term: $\frac{\partial^2 f(y, x)}{\partial y \, \partial x} u$.

My questions are as follows: