You can use jax.lax.stop_gradient as JAX's detach.
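As a quick toy illustration of the detach-like behaviour (not part of the snippet below, just a standalone sketch):

import jax

def f(x):
    # the x**3 term is "detached": it contributes to the value but not the gradient
    return x ** 2 + jax.lax.stop_gradient(x ** 3)

print(jax.grad(f)(2.0))  # 4.0, i.e. only d(x**2)/dx survives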
Note that you will either need to compile the train_step function twice, with a static argument indicating the selection, or perform an all-zero backward pass.

from functools import partial

import jax

@partial(jax.jit, static_argnums=0)  # the selection flag is static, so each flag value compiles its own version
def train_step(stop_grad_head0, params, x, y0, y1):
    def f(params):
        z = backbone(params['backbone'], x)
        h0 = head0(params['head0'], z)
        h1 = head1(params['head1'], z)
        if stop_grad_head0:
            return loss_fn(y0, y1, jax.lax.stop_gradient(h0), h1)  # stop grad head 0
        return loss_fn(y0, y1, h0, jax.lax.stop_gradient(h1))  # stop grad head 1
    loss, grad = jax.value_and_grad(f)(params)
    return loss, apply_grad(params, grad)
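
With static_argnums, calling train_step(True, ...) and train_step(False, ...) each triggers its own compilation. If you would rather compile once and take the all-zero-backward-pass route instead, one possible sketch (reusing the same assumed helpers backbone, head0, head1, loss_fn, apply_grad, with a hypothetical name train_step_single) keeps the flag as a regular traced value and routes the head outputs through jnp.where: both heads are always differentiated, but the deselected head receives an all-zero gradient while the forward loss value is unchanged.

import jax
import jax.numpy as jnp

@jax.jit  # single compilation; the flag is traced, not static
def train_step_single(stop_grad_head0, params, x, y0, y1):
    def f(params):
        z = backbone(params['backbone'], x)
        h0 = head0(params['head0'], z)
        h1 = head1(params['head1'], z)
        # Both branches of jnp.where are evaluated; the flag only decides
        # which head's gradient is zeroed out on the backward pass.
        h0 = jnp.where(stop_grad_head0, jax.lax.stop_gradient(h0), h0)
        h1 = jnp.where(stop_grad_head0, h1, jax.lax.stop_gradient(h1))
        return loss_fn(y0, y1, h0, h1)
    loss, grad = jax.value_and_grad(f)(params)
    return loss, apply_grad(params, grad)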
