We read every piece of feedback and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent d54fe0a · commit 79e1063 (Copy full SHA for 79e1063)
fms_fsdp/utils/train_utils.py
@@ -89,6 +89,7 @@ def train(
89
output = output.logits if hasattr(output, "logits") else output
90
ce_loss = torch.nn.CrossEntropyLoss()
91
loss = ce_loss(output.view(-1, output.size(-1)), label.view(-1).long())
92
+ loss = loss + .0001 * torch.logsumexp(output, dim=-1).pow(2).mean()
93
94
loss.backward()
95
ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item()
0 commit comments