How to define the jacobian of a vector-valued func with custom_vjp where the jacobian is computed on the forward pass? #10291
Answered by YouJiacheng
antalszava asked this question in Q&A
import jax
import jax.numpy as jnp

params = jnp.array([0.1, 0.2])

@jax.custom_vjp
def wrapped_exec(params):
    y = params ** 2, params ** 3
    # no need to compute the Jacobians here
    return y

def wrapped_exec_fwd(params):
    y = wrapped_exec(params)
    # compute the Jacobians on the forward pass and save them as residuals
    jacs = jnp.diag(2 * params), jnp.diag(3 * params ** 2)
    return y, jacs  # params are not needed as residuals

def wrapped_exec_bwd(res, g):
    jac1, jac2 = res
    g1, g2 = g  # one cotangent per output
    # sum the VJP contributions of both outputs; 1-tuple because there is one input
    return (g1 @ jac1) + (g2 @ jac2),

wrapped_exec.defvjp(wrapped_exec_fwd, wrapped_exec_bwd)

jax.jacobian(wrapped_exec)(params)
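As a sanity check, the result of the custom rule can be compared against ordinary autodiff on the same toy function. The sketch below repeats the definitions so it runs on its own; `plain_exec` is a hypothetical reference function introduced only for the comparison and is not part of the original answer.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def wrapped_exec(params):
    return params ** 2, params ** 3


def wrapped_exec_fwd(params):
    y = wrapped_exec(params)
    # Jacobians are computed on the forward pass and stored as residuals.
    jacs = jnp.diag(2 * params), jnp.diag(3 * params ** 2)
    return y, jacs


def wrapped_exec_bwd(res, g):
    jac1, jac2 = res
    g1, g2 = g
    return (g1 @ jac1 + g2 @ jac2,)


wrapped_exec.defvjp(wrapped_exec_fwd, wrapped_exec_bwd)


def plain_exec(params):
    # Hypothetical reference: same computation, differentiated by JAX itself.
    return params ** 2, params ** 3


params = jnp.array([0.1, 0.2])
custom_jacs = jax.jacobian(wrapped_exec)(params)
reference_jacs = jax.jacobian(plain_exec)(params)

# One Jacobian per output, each of shape (2, 2).
for got, want in zip(custom_jacs, reference_jacs):
    assert jnp.allclose(got, want)
```

If the two disagree for a real function, the likely culprits are the Jacobians stored in the forward pass or the contraction in the backward pass.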
Answer selected by antalszava
Hi everyone, given a function that computes both its value and its Jacobian on the forward pass, what would be a way to register that Jacobian when using custom_vjp?

The following is an attempt at passing the Jacobian to wrapped_exec_bwd; however, it fails due to output issues. The traceback contains the vmap call, although I am not sure how to make the vectorization compatible with the computed Jacobian.
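The failing attempt is not reproduced here, so the following is only a sketch of why vmap tends to appear in such tracebacks. jax.jacobian is an alias for jax.jacrev, which builds the Jacobian by vmapping the VJP function over standard-basis cotangents. The backward rule therefore has to return one cotangent per input and work when its `g` argument is batched. The toy function and the names `rows1` and `rows2` are illustrative assumptions, and the manual loop over outputs is a simplification of what jacrev does internally.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def wrapped_exec(params):
    return params ** 2, params ** 3


def wrapped_exec_fwd(params):
    y = wrapped_exec(params)
    jacs = jnp.diag(2 * params), jnp.diag(3 * params ** 2)
    return y, jacs


def wrapped_exec_bwd(res, g):
    jac1, jac2 = res
    g1, g2 = g  # under vmap these are batched tracers, each of shape (2,) per example
    return (g1 @ jac1 + g2 @ jac2,)


wrapped_exec.defvjp(wrapped_exec_fwd, wrapped_exec_bwd)

params = jnp.array([0.1, 0.2])
_, vjp_fn = jax.vjp(wrapped_exec, params)

eye = jnp.eye(2)
zeros = jnp.zeros((2, 2))

# Each row of `eye` is a basis cotangent for one output; the other output gets zeros.
(rows1,) = jax.vmap(vjp_fn)((eye, zeros))  # rows of the Jacobian of params ** 2
(rows2,) = jax.vmap(vjp_fn)((zeros, eye))  # rows of the Jacobian of params ** 3
```

Because the backward rule only uses jnp operations on `g`, vmap can batch it without any extra work. A rule that converts `g` to NumPy or returns something other than a tuple with one entry per input would be a plausible source of the output error.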