What is the correct way to compute per-sample gradients for an RNN? The per-sample-gradient recipe in the JAX documentation requires calling `grad` before `vmap` (i.e., `vmap(grad(f))`), but in many cases recurrent models already contain calls to `vmap`, so this `vmap(grad(f))` wrapping is not possible. See below for a working example: I would like `gb.shape` to be `(5, 3)`, not `(3,)`, without rewriting `linear_rnn`.
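(The working example did not survive the page extraction. As a stand-in, here is a minimal sketch of the kind of setup being described, assuming a toy `linear_rnn` that scans over time and uses an internal `vmap` over the batch; the tanh cell, the squared-error loss, and all shapes are assumptions — only the names `linear_rnn`, `loss`, `params`, `x`, `y`, and `gb` come from the discussion.)

```python
import jax
import jax.numpy as jnp


def linear_rnn(params, x):
    # x: (batch, time, features). The batch dimension is consumed by an
    # internal vmap, which is what prevents wrapping the whole model in
    # vmap(grad(f)) the way the JAX per-sample-gradient recipe assumes.
    W, b = params

    def single_example(seq):
        def step(h, xt):
            h = jnp.tanh(h @ W + xt + b)
            return h, None

        h_last, _ = jax.lax.scan(step, jnp.zeros_like(b), seq)
        return h_last

    return jax.vmap(single_example)(x)


def loss(params, x, y):
    return jnp.mean((linear_rnn(params, x) - y) ** 2)


key = jax.random.PRNGKey(0)
params = (jax.random.normal(key, (3, 3)), jnp.zeros(3))  # (W, b)
x = jax.random.normal(key, (5, 4, 3))  # batch of 5 sequences
y = jnp.zeros((5, 3))

_, gb = jax.grad(loss)(params, x, y)
print(gb.shape)  # (3,) -- summed over the batch; the per-sample version would be (5, 3)
```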
-
From the expected shape, it sounds like you want the gradient with respect to the inputs: `gx = jax.grad(loss, argnums=1)(params, x, y)`. If that's not what you have in mind, I'm unclear on where the batch size of 5 would come from.
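For concreteness, with the (assumed) shapes from the sketch above, that call differentiates with respect to the second positional argument, `x`, so its result naturally carries the batch dimension:

```python
# Gradient with respect to the second positional argument, x, so the
# result has x's shape and keeps the leading batch dimension of 5.
gx = jax.grad(loss, argnums=1)(params, x, y)
print(gx.shape)  # (5, 4, 3) with the sketch above
```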