Selective Multi-task Gradient #9989
Say I have a backbone and two heads, and I only want to backpropagate gradient from one of those two heads. A straightforward way to do this in PyTorch is simply to forward the backbone and detach its output in one of the two heads. However, I cannot think of an efficient way to do this in JAX without forwarding the backbone twice, once for each head.
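For reference, a minimal, self-contained PyTorch sketch of the detach approach described above (module sizes, losses, and names are made up for illustration):

import torch
import torch.nn as nn

# Toy stand-ins for the real backbone and heads.
backbone = nn.Linear(8, 16)
head0 = nn.Linear(16, 4)
head1 = nn.Linear(16, 4)
x, y0, y1 = torch.randn(2, 8), torch.randn(2, 4), torch.randn(2, 4)

z = backbone(x)           # single backbone forward pass
h0 = head0(z)             # this branch backpropagates into the backbone
h1 = head1(z.detach())    # head1 still trains, but the backbone gets no gradient from it
loss = nn.functional.mse_loss(h0, y0) + nn.functional.mse_loss(h1, y1)
loss.backward()           # backbone.weight.grad comes only from the head0 branch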
Replies: 1 comment
You can use jax.lax.stop_gradient as JAX's detach. Note that you will either need to compile the train_step function twice, with a static argument indicating the selection, or perform an all-zero backward pass through the head you are not training. The static-argument version:

from functools import partial
import jax

@partial(jax.jit, static_argnums=0)
def train_step(stop_grad_head0, params, x, y0, y1):
    def f(params):
        z = backbone(params['backbone'], x)
        h0 = head0(params['head0'], z)
        h1 = head1(params['head1'], z)
        if stop_grad_head0:
            return loss_fn(y0, y1, jax.lax.stop_gradient(h0), h1)  # stop grad through head 0
        return loss_fn(y0, y1, h0, jax.lax.stop_gradient(h1))  # stop grad through head 1
    loss, grad = jax.value_and_grad(f)(params)
    return loss, apply_grad(params, grad)
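A possible call site (params, x, y0, y1, and apply_grad come from the setup above and are only placeholders here); each boolean value of the static flag compiles its own version of train_step:

loss, params = train_step(True, params, x, y0, y1)   # head 0 output detached: only head 1 drives the gradient
loss, params = train_step(False, params, x, y0, y1)  # head 1 output detached: only head 0 drives the gradient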
The all-zero backward pass version, which avoids the extra compilation:

def f(params, k: float):
    # k in [0, 1]: k = 0 keeps gradients only through h0, k = 1 only through h1
    z = backbone(params['backbone'], x)
    h0 = head0(params['head0'], z)
    h0 = (1 - k) * h0 + k * jax.lax.stop_gradient(h0)  # value unchanged, gradient scaled by (1 - k)
    h1 = head1(params['head1'], z)
    h1 = k * h1 + (1 - k) * jax.lax.stop_gradient(h1)  # value unchanged, gradient scaled by k
    return loss_fn(y0, y1, h0, h1)
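A sketch of how this variant might be wrapped into a single jitted step (train_step_soft is a made-up name; the model functions and apply_grad are the same placeholders as above). Because k is an ordinary traced argument rather than a static one, switching heads does not trigger recompilation:

@jax.jit
def train_step_soft(params, x, y0, y1, k):
    def f(params):
        z = backbone(params['backbone'], x)
        h0 = head0(params['head0'], z)
        h0 = (1 - k) * h0 + k * jax.lax.stop_gradient(h0)
        h1 = head1(params['head1'], z)
        h1 = k * h1 + (1 - k) * jax.lax.stop_gradient(h1)
        return loss_fn(y0, y1, h0, h1)
    loss, grad = jax.value_and_grad(f)(params)
    return loss, apply_grad(params, grad)

# The same compiled function handles both selections:
loss, params = train_step_soft(params, x, y0, y1, 0.0)  # train only the head0 branch
loss, params = train_step_soft(params, x, y0, y1, 1.0)  # train only the head1 branch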
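If you want exactly the PyTorch behaviour from the question, where the deselected head still trains on its own loss but sends no gradient into the backbone, the same primitive can be applied to the backbone output instead of the head output. A sketch with the same placeholder names:

def f(params):
    z = backbone(params['backbone'], x)
    h0 = head0(params['head0'], z)  # this branch backpropagates into the backbone
    h1 = head1(params['head1'], jax.lax.stop_gradient(z))  # head1 params still get gradients; the backbone does not
    return loss_fn(y0, y1, h0, h1)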