Skip to content
Discussion options

You must be logged in to vote
# Example input: a length-2 vector at which the Jacobian is evaluated below.
params = jnp.array((0.1, 0.2))

@jax.custom_vjp
def wrapped_exec(params):
    """Return the element-wise square and cube of ``params`` as a 2-tuple.

    The reverse-mode derivative of this function comes from a custom VJP
    rule (``jax.custom_vjp``), not from tracing this body. For that reason,
    the Jacobians are deliberately not computed here.
    """
    squared = params ** 2
    cubed = params ** 3
    return squared, cubed

def wrapped_exec_fwd(params):
    """Forward rule for ``wrapped_exec``: primal outputs plus Jacobian residuals.

    Both Jacobians are built eagerly here and saved as residuals. As a
    result, the backward rule needs only those matrices; ``params`` itself
    is not saved.

    NOTE(review): assumes ``params`` is 1-D. ``jnp.diag`` builds a square
    matrix only from a vector; for a 2-D input it extracts the diagonal
    instead. Confirm that callers only pass vectors.
    """
    primal_out = wrapped_exec(params)
    # Both outputs are element-wise maps, so their Jacobians are diagonal.
    jac_square = jnp.diag(2 * params)
    jac_cube = jnp.diag(3 * params ** 2)
    residuals = (jac_square, jac_cube)
    return primal_out, residuals

def wrapped_exec_bwd(res, g):
    """Backward rule: pull the output cotangents back to ``params``.

    ``res`` is the pair of Jacobians saved by the forward rule. ``g`` is the
    pair of cotangents for the two outputs. The result is the
    vector-Jacobian product ``g1 @ J1 + g2 @ J2``, returned as a 1-tuple
    with one entry per positional argument of ``wrapped_exec``.
    """
    jac_square, jac_cube = res
    ct_square, ct_cube = g
    # Row-vector times Jacobian (g^T J) is the VJP; operand order matters
    # for non-symmetric Jacobians.
    grad_params = ct_square @ jac_square + ct_cube @ jac_cube
    return (grad_params,)

# Attach the custom forward/backward rules, then evaluate the Jacobian of
# both outputs at ``params`` in reverse mode (jax.jacobian aliases jax.jacrev).
# The result is not stored.
wrapped_exec.defvjp(fwd=wrapped_exec_fwd, bwd=wrapped_exec_bwd)
jax.jacrev(wrapped_exec)(params)

Replies: 1 comment 1 reply

Comment options

You must be logged in to vote
1 reply
@antalszava
Comment options

Answer selected by antalszava
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Category
Q&A
Labels
None yet
2 participants