Commit 79e1063

Add zl
1 parent d54fe0a commit 79e1063

1 file changed: +1 -0 lines changed

fms_fsdp/utils/train_utils.py

Lines changed: 1 addition & 0 deletions
@@ -89,6 +89,7 @@ def train(
 output = output.logits if hasattr(output, "logits") else output
 ce_loss = torch.nn.CrossEntropyLoss()
 loss = ce_loss(output.view(-1, output.size(-1)), label.view(-1).long())
+loss = loss + .0001 * torch.logsumexp(output, dim=-1).pow(2).mean()

 loss.backward()
 ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item()
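
The added line looks like an auxiliary penalty on the log-partition function, commonly called a z-loss, which is presumably what "zl" in the commit message abbreviates. The sketch below is a dependency-free illustration of what that term computes, not code from the repository: the helper names `logsumexp` and `z_loss` and the toy logits are assumptions made for the example. It mirrors `.0001 * torch.logsumexp(output, dim=-1).pow(2).mean()` using plain Python lists in place of tensors.

```python
import math


def logsumexp(row):
    """Numerically stable log(sum(exp(x))) over one row of logits."""
    m = max(row)
    return m + math.log(sum(math.exp(x - m) for x in row))


def z_loss(logits, coeff=1e-4):
    """Return coeff * mean(logsumexp(row) ** 2) over all rows.

    `logits` is a list of rows, one row of vocabulary scores per token
    position. The default coeff matches the .0001 used in the diff.
    """
    lse = [logsumexp(row) for row in logits]
    return coeff * sum(v * v for v in lse) / len(lse)


# Adding a constant to every logit in a row leaves the softmax, and therefore
# the cross-entropy, unchanged. It does shift logsumexp by that constant, so
# only the auxiliary term penalizes this kind of drift.
small = [[0.0, 0.0], [0.0, 0.0]]
shifted = [[10.0, 10.0], [10.0, 10.0]]

print(z_loss(small))
print(z_loss(shifted))
```

The penalty for `shifted` is much larger than for `small` even though both produce identical softmax distributions. That is the point of the term: it pulls the softmax normalizer towards 1 (log-normalizer towards 0) without changing which predictions the cross-entropy prefers. Note that in the diff the term averages over every position in `output`, with no masking of label positions that `CrossEntropyLoss` might ignore.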
