Skip to content
Discussion options

You must be logged in to vote

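In this case, you can pass the per-input flags as a non-differentiable argument (via `nondiff_argnums`) and have the backward rule return `None` for any input whose gradient is not needed, like this: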

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(3,))
def f(x, y, z, needs_input_grad=(True, True, True)):
    """Return ``x + y + z``, differentiated through a custom VJP.

    ``needs_input_grad`` (positional index 3) is marked non-differentiable
    via ``nondiff_argnums``. It holds three flags that the backward rule
    ``f_bwd`` consults to decide which of ``x``, ``y``, ``z`` receive a
    cotangent.
    """
    xy = x + y
    return xy + z


def f_fwd(x, y, z, needs_input_grad=(True, True, True)):
    """Forward rule for ``f``: compute the primal output and save residuals.

    Returns ``(x + y + z, (x, y, z))``. The three inputs are kept as
    residuals because ``f_bwd`` unpacks them to build its cotangents.
    ``needs_input_grad`` mirrors ``f``'s signature and is not used here.
    """
    primal_out = x + y + z
    saved = (x, y, z)
    return primal_out, saved

def f_bwd(needs_input_grad, res, g):
    """Backward rule for ``f``: return cotangents for ``(x, y, z)``.

    Args:
        needs_input_grad: the non-differentiable argument of ``f``, with three
            flags selecting which inputs get a cotangent.
        res: residuals ``(x, y, z)`` saved by ``f_fwd``.
        g: cotangent of the output ``x + y + z``.

    Returns:
        A 3-tuple ``(dx, dy, dz)``. Each entry is ``g`` scaled to the matching
        input when its flag is set, or ``None`` otherwise.
    """
    x, y, z = res
    # d(x + y + z)/dx = d/dy = d/dz = 1, so by the chain rule each input's
    # cotangent is the incoming cotangent g times ones. Scaling by g (rather
    # than returning bare ones) keeps the rule correct when f is composed
    # with other functions, e.g. jax.grad(lambda a: 2.0 * f(a, y, z)).
    # NOTE(review): assumes x, y, z have the output's shape (no broadcasting);
    # broadcast inputs would need g summed over the broadcast axes.
    dx = g * jnp.ones_like(x) if needs_input_grad[0] else None
    dy = g * jnp.ones_like(y) if needs_input_grad[1] else None
    dz = g * jnp.ones_like(z) if needs_input_grad[2] else None
    # JAX's custom_vjp accepts None in the bwd output to mean a zero cotangent.
    return dx, dy, dz

# Attach the forward (f_fwd) and backward (f_bwd) rules to the custom_vjp-wrapped f.
f.defvjp(f_fwd, f_bwd)


# Example scalar inputs, all bound to the same array value 1.0.
x = y = z = jnp.array(1.)

print(jax.grad(f, a…

Replies: 1 comment 3 replies

Comment options

You must be logged in to vote
3 replies
@amifalk
Comment options

@anh-tong
Comment options

Answer selected by amifalk
@amifalk
Comment options

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Category
Q&A
Labels
None yet
2 participants