
Commit a089cd5

Fix nan loss error for LigerFusedLinearJSDLoss (#862)
## Summary

Fixes #769. As described in the issue, I have updated the code to fix the nan error. Could you please review? cc: @shimizust

## Details

- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
Parent: 62a9054
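For context, here is a minimal sketch (not from the repo) of the failure mode the commit fixes: when a token's probability underflows to 0.0 in fp32 under both distributions, the old probability-space mixture becomes exactly zero, its log is `-inf`, and the downstream KL terms turn that into nan via `0 * inf`. The values below are illustrative and not normalized distributions.

```python
import math

import torch

beta = 0.5
# Log-probabilities this negative underflow to 0.0 under .exp() in fp32.
student_log_probs = torch.tensor([-1000.0, -0.5])
teacher_log_probs = torch.tensor([-1200.0, -0.3])

# Old formulation: exp() underflows, the mixture probability is exactly
# 0.0, and log(0.0) = -inf, which later becomes nan in the KL terms.
mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
print(mean_probs.log())  # tensor([   -inf, -0.3950])

# New formulation: the mixture is computed entirely in log-space,
# so no intermediate probability ever underflows.
log_mean_probs = torch.logsumexp(
    torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
)
print(log_mean_probs)  # tensor([-1000.6931, -0.3950]) -- finite
```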

File tree

2 files changed (+10, −4 lines)

src/liger_kernel/chunked_loss/jsd_loss.py

Lines changed: 5 additions & 2 deletions
@@ -1,3 +1,5 @@
+import math
+
 import torch
 import torch.nn.functional as F
 
@@ -25,8+27,9 @@ def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
             jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
         else:
             # Compute probabilities (only required for mean calculation)
-            mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
-            log_mean_probs = mean_probs.log()
+            log_mean_probs = torch.logsumexp(
+                torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
+            )
 
             student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
             teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
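The rewrite relies on the identity log((1 − beta)·e^a + beta·e^b) = logsumexp(a + log(1 − beta), b + log(beta)), so the mixture never leaves log-space. A quick sketch checking that the loss and its gradients stay finite for extreme logits (the beta-weighting of the two KL terms below follows the generalized-JSD definition and is my assumption, not copied from the repo):

```python
import math

import torch
import torch.nn.functional as F

beta = 0.5
# Extreme logits: softmax puts ~all mass on one class; log-probs near -400.
student_logits = torch.tensor([[200.0, -200.0]], requires_grad=True)
teacher_logits = torch.tensor([[-200.0, 200.0]])

student_log_probs = F.log_softmax(student_logits, dim=-1)
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

log_mean_probs = torch.logsumexp(
    torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
)
student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
# Generalized-JSD weighting (assumed here, not taken from the repo).
loss = beta * teacher_kl + (1 - beta) * student_kl
loss.backward()
print(loss)  # ~ln(2), the JSD maximum for disjoint distributions
print(torch.isfinite(student_logits.grad).all())  # tensor(True)
```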

test/chunked_loss/test_jsd_loss.py

Lines changed: 5 additions & 2 deletions
@@ -1,3 +1,5 @@
+import math
+
 import pytest
 import torch
 import torch.nn.functional as F
@@ -54,8 +56,9 @@ def distillation_loss(self, student_logits, teacher_logits, beta=0.5):
             jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
         else:
             # Compute probabilities (only required for mean calculation)
-            mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
-            log_mean_probs = mean_probs.log()
+            log_mean_probs = torch.logsumexp(
+                torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
+            )
 
             student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="batchmean", log_target=True)
             teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="batchmean", log_target=True)
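For completeness, a standalone regression-style check one could run (illustrative only; not part of the commit's test changes) showing the new formulation stays finite exactly where the old one produced nan, namely a wide vocabulary where many classes underflow in both distributions:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
beta = 0.5
# Large-magnitude logits over a wide vocab: most classes underflow to
# probability 0.0 in BOTH distributions -- the case that produced nan.
student_log_probs = F.log_softmax(torch.randn(4, 4096) * 100, dim=-1)
teacher_log_probs = F.log_softmax(torch.randn(4, 4096) * 100, dim=-1)

# New formulation: finite.
log_mean_probs = torch.logsumexp(
    torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
)
student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="batchmean", log_target=True)
teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="batchmean", log_target=True)
assert torch.isfinite(student_kl) and torch.isfinite(teacher_kl)

# Old formulation: nan on the same inputs (log(0) = -inf, then 0 * inf).
mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
old_kl = F.kl_div(mean_probs.log(), student_log_probs, reduction="batchmean", log_target=True)
assert old_kl.isnan()
```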
