Best way to "mask" gradients in varying length sequences #10563
valentinmace asked this question in Q&A (Unanswered)
I have a batch of predictions from my transformer model and a batch of labels.
Say both predictions and labels have the shape (256, 1000, 32).
I am using a standard L2 loss to fit predictions to labels. However, in a given trajectory the agent might die before the end of the episode, say at the 600th timestep, leaving 400 timesteps on that trajectory that are irrelevant.
So I have 600 "good" predictions and labels and 400 "irrelevant" ones that I don't want to affect my model during loss computation and backpropagation. What is the best way to ignore these in JAX so that only the good data is considered?
Am I wrong in thinking that setting both the labels and the data of the last 400 timesteps to 0 will produce an error of 0 and not affect my training?
Thanks in advance.

Replies: 1 comment

IIUC, in addition to setting the labels to 0, you should set the model outputs of the last 400 timesteps to 0 as well.
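To make this concrete, here is a minimal sketch of masking the per-timestep L2 error explicitly rather than relying on zeroed labels alone. The function name, the boolean `mask` array, and the normalization by the number of valid timesteps are illustrative assumptions, not something stated in the discussion above.

```python
import jax
import jax.numpy as jnp

def masked_l2_loss(predictions, labels, mask):
    """L2 loss over valid timesteps only.

    predictions, labels: (batch, time, features)
    mask: (batch, time) boolean, True where the timestep is valid
          (e.g. False after the agent dies).
    """
    # Per-timestep squared error, summed over the feature dimension.
    per_step = jnp.sum((predictions - labels) ** 2, axis=-1)  # (batch, time)
    # Zero out the irrelevant timesteps so they contribute nothing
    # to the loss or to its gradients.
    per_step = jnp.where(mask, per_step, 0.0)
    # Normalize by the number of valid timesteps, not by batch * time,
    # so padded trajectories do not dilute the loss.
    return jnp.sum(per_step) / jnp.maximum(jnp.sum(mask), 1)

# Hypothetical example: 256 trajectories of length 1000 with 32 features,
# the first trajectory ending at timestep 600.
key = jax.random.PRNGKey(0)
preds = jax.random.normal(key, (256, 1000, 32))
labels = jnp.zeros((256, 1000, 32))
lengths = jnp.full((256,), 1000).at[0].set(600)        # per-trajectory episode lengths
mask = jnp.arange(1000)[None, :] < lengths[:, None]    # (256, 1000) bool

loss, grads = jax.value_and_grad(masked_l2_loss)(preds, labels, mask)
# grads is zero at every masked position.
```

Masking the per-step error this way has the same effect as zeroing both labels and model outputs, as suggested in the reply, but it also keeps the loss magnitude comparable across batches with different amounts of padding: with a plain mean over all 256 × 1000 positions, the dead timesteps would still inflate the denominator even though their gradients are zero.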