Compute gradients for a batch of inputs together with the network params in JAX #8481
Unanswered
Behnam-Asadi asked this question in Q&A
As I am new to JAX this might be a naive question, but what is the best way to update a batch of inputs using the value_and_grad function? In my case the inputs to the network are learnable embedding vectors, and since each iteration's input is a batch of these vectors, I need to update the network params together with the input batch (a batch of embedding vectors). I could define all of the vectors as params of the network, but then value_and_grad would compute gradients with respect to all of the vectors instead of just the specific batch.
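To make the concern concrete, here is a minimal sketch (toy shapes, and made-up names like embed_table and loss_fn, none of which come from this thread): when the whole embedding table lives in params, grad returns a gradient for every vector in the table, not just the current batch.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
params = {
    # All learnable embedding vectors live in the param tree.
    "embed_table": jax.random.normal(key, (1000, 32)),
    "w": jax.random.normal(key, (32, 1)),
}

def loss_fn(params, idx, targets):
    batch = params["embed_table"][idx]   # the batch of embedding vectors
    preds = batch @ params["w"]
    return jnp.mean((preds - targets) ** 2)

# Differentiating w.r.t. params yields a gradient for the entire table
# (zero outside the gathered rows), not just for the current batch.
grads = jax.grad(loss_fn)(params, jnp.arange(16), jnp.zeros((16, 1)))
print(grads["embed_table"].shape)  # (1000, 32)
```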
Replies: 2 comments

-
https://jax.readthedocs.io/en/latest/jax.html#jax.value_and_grad
It sounds like you want to take the gradient with respect to only certain inputs; that is what the argnums parameter is for:
argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
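A minimal sketch of that suggestion, using a toy linear model and made-up names (loss_fn, embed_table): pass argnums=(0, 1) so value_and_grad differentiates with respect to both the params and the current batch of embeddings, then scatter the batch gradient back into the full table.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
embed_table = jax.random.normal(key, (1000, 32))   # full learnable table
params = {"w": jax.random.normal(key, (32, 1)), "b": jnp.zeros((1,))}

def loss_fn(params, embed_batch, targets):
    preds = embed_batch @ params["w"] + params["b"]
    return jnp.mean((preds - targets) ** 2)

idx = jnp.arange(16)                    # indices of this batch in the table
embed_batch = embed_table[idx]          # gather only the current batch
targets = jnp.zeros((16, 1))

# argnums=(0, 1): gradients w.r.t. both params and the embedding batch.
loss, (param_grads, batch_grads) = jax.value_and_grad(
    loss_fn, argnums=(0, 1))(params, embed_batch, targets)

lr = 1e-2
params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, param_grads)
# Scatter the batch gradient back into the table: only these rows change.
embed_table = embed_table.at[idx].add(-lr * batch_grads)
```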
-
This helped me train the embedding vectors (probably in an inefficient way), but I am looking for a predefined module, like torch.nn.Embedding in PyTorch, to train the embeddings more efficiently.
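For what it's worth, Flax (a neural-network library built on JAX) ships flax.linen.Embed, which plays roughly the role of torch.nn.Embedding. A minimal sketch; the Model class, shapes, and loss_fn below are illustrative assumptions, not code from this thread.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Model(nn.Module):
    @nn.compact
    def __call__(self, idx):
        # nn.Embed is Flax's counterpart of torch.nn.Embedding.
        x = nn.Embed(num_embeddings=1000, features=32)(idx)
        return nn.Dense(features=1)(x)

model = Model()
idx = jnp.arange(16)                          # indices of the current batch
variables = model.init(jax.random.PRNGKey(0), idx)

def loss_fn(variables, idx, targets):
    preds = model.apply(variables, idx)
    return jnp.mean((preds - targets) ** 2)

# The embedding gradient is dense but zero outside the gathered rows, so a
# plain optimizer step effectively updates only the current batch of vectors.
grads = jax.grad(loss_fn)(variables, idx, jnp.zeros((16, 1)))
```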