Commit 33f7f48

roycho96 and Sunghyun Cho authored
Fix: fix ignore_index not being applied in JSD distillation loss (#974)
## Summary

Fix the `ignore_index` parameter not being applied in `LigerFusedLinearJSDLoss`. The `ignore_index` parameter was accepted but never used in `distillation_loss_fn`, causing all tokens (including padding/prompt tokens) to be included in the loss computation.

### Changes

- Change `reduction='sum'` to `reduction='none'` for per-token masking
- Use `masked_fill` for dtype preservation (prevents bf16 → fp32 promotion)
- Add `clamp_min(1)` to prevent NaN when all tokens are ignored
- Normalize by `num_valid_tokens` instead of `full_target.shape[0]`
- Add comprehensive `ignore_index` tests

## Testing Done

- Hardware Type: H100
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Sunghyun Cho <[email protected]>
1 parent fe1ea95 commit 33f7f48
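For context, a minimal standalone sketch of the masking pattern described in the Changes list above; the tensor names and shapes are illustrative stand-ins, not the kernel's actual internals:

```python
import torch

ignore_index = -100
# Stand-ins: `target` holds labels for a flattened chunk of tokens, and
# `per_token_loss` holds per-token JSD values obtained with reduction="none"
# after summing over the vocab dimension.
target = torch.tensor([3, 7, ignore_index, ignore_index])
per_token_loss = torch.rand(4, dtype=torch.bfloat16)

mask = target != ignore_index
# masked_fill preserves the bf16 dtype; multiplying by mask.float() would promote to fp32
masked_loss = per_token_loss.masked_fill(~mask, 0.0)

# clamp_min(1) avoids 0/0 -> NaN when every token carries ignore_index
num_valid_tokens = mask.sum().clamp_min(1)
loss = masked_loss.sum() / num_valid_tokens
```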

File tree: 4 files changed, +59 -19 lines changed

src/liger_kernel/chunked_loss/fused_linear_distillation.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -132,10 +132,15 @@ def _compute_loss(
             )
             student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)
 
-        hard_loss /= full_target.shape[0]
+        num_valid_tokens = (full_target != ignore_index).sum()
+        num_valid_tokens = num_valid_tokens.clamp_min(1)  # to avoid division by zero
 
-        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs)
-        soft_loss /= full_target.shape[0]
+        hard_loss /= num_valid_tokens
+
+        soft_loss = distillation_loss_fn(
+            student_logits_chunk, teacher_logits_chunk, target=target_chunk, ignore_index=ignore_index, **loss_kwargs
+        )
+        soft_loss /= num_valid_tokens
 
         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
         return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)
```
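To see what the new denominator changes, here is a toy comparison with made-up numbers (not the kernel's real tensors):

```python
import torch

ignore_index = -100
# Assumed toy values: three valid tokens, three ignored tokens.
full_target = torch.tensor([5, 2, ignore_index, ignore_index, 9, ignore_index])
per_token_losses = torch.tensor([0.4, 0.6, 0.0, 0.0, 0.5, 0.0])  # ignored positions already zeroed

old_denominator = full_target.shape[0]                               # 6: counts ignored tokens too
new_denominator = (full_target != ignore_index).sum().clamp_min(1)   # 3: valid tokens only

print(per_token_losses.sum() / old_denominator)  # tensor(0.2500) -- diluted by ignored tokens
print(per_token_losses.sum() / new_denominator)  # tensor(0.5000) -- mean over valid tokens only
```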

src/liger_kernel/chunked_loss/jsd_loss.py

Lines changed: 21 additions & 6 deletions
```diff
@@ -11,35 +11,50 @@
 
 class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
     @staticmethod
-    def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
+    def distillation_loss_fn(student_logits, teacher_logits, beta=0.5, target=None, ignore_index=-100):
         """
         Compute JSD loss (Jensen-Shannon Divergence Loss).
         Args:
             student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
             teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+            target (torch.Tensor): Target labels for masking. Shape: (chunk_size,).
+            ignore_index (int): Index to ignore in loss computation.
         Returns:
             torch.Tensor: Jensen-Shannon Divergence loss
+        Note:
+            - Uses reduction="none" to preserve per-token losses for masking
+            - KL divergence requires summing over vocab dimension (not mean)
+            - Masking excludes padding/prompt tokens from loss computation
         """
         student_log_probs = F.log_softmax(student_logits, dim=-1)
         teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
 
         if beta == 0:
-            jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="sum", log_target=True)
+            jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
         elif beta == 1:
-            jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
+            jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
         else:
             # Compute probabilities (only required for mean calculation)
             log_mean_probs = torch.logsumexp(
                 torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
             )
 
-            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
-            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
+            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="none", log_target=True)
+            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="none", log_target=True)
 
             # JSD is the weighted average of the KL divergences
             jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
-        return jsd_loss
+
+        # Sum over vocab dimension (KL divergence definition)
+        jsd_loss = jsd_loss.sum(dim=-1)  # (chunk_size,)
+
+        # Apply ignore_index mask
+        if target is not None:
+            mask = target != ignore_index
+            jsd_loss = jsd_loss.masked_fill(~mask, 0.0)
+
+        return jsd_loss.sum()
 
     @classmethod
     def forward(
```
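As a sanity check on the per-token formulation (a standalone sketch, not the library API), the masked sum of the reduction="none" KL terms equals summing only over the valid rows:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
beta, ignore_index = 0.5, -100
student_logits = torch.randn(4, 8)   # (chunk_size, vocab_size) -- toy sizes
teacher_logits = torch.randn(4, 8)
target = torch.tensor([1, ignore_index, 3, ignore_index])

student_log_probs = F.log_softmax(student_logits, dim=-1)
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
log_mean_probs = torch.logsumexp(
    torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
)
student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="none", log_target=True)
teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="none", log_target=True)

# Per-token generalized JSD: sum the elementwise KL terms over the vocab dimension.
per_token_jsd = (beta * teacher_kl + (1 - beta) * student_kl).sum(dim=-1)

mask = target != ignore_index
masked_total = per_token_jsd.masked_fill(~mask, 0.0).sum()
valid_only_total = per_token_jsd[mask].sum()
assert torch.allclose(masked_total, valid_only_total)
```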

test/chunked_loss/test_jsd_loss.py

Lines changed: 27 additions & 9 deletions
```diff
@@ -37,12 +37,14 @@ def __init__(
             temperature=temperature,
         )
 
-    def distillation_loss(self, student_logits, teacher_logits, beta=0.5):
+    def distillation_loss(self, student_logits, teacher_logits, target=None, ignore_index=-100, beta=0.5):
         """
         Compute JSD loss (Jensen-Shannon Divergence Loss).
         Args:
-            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
-            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
+            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
+            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+            target (torch.Tensor): Target labels for masking. Shape: (batch_size * seq_len,).
+            ignore_index (int): Index to ignore in loss computation.
             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
         Returns:
             torch.Tensor: Jensen-Shannon Divergence loss
@@ -55,17 +57,24 @@ def distillation_loss(self, student_logits, teacher_logits, beta=0.5):
         elif beta == 1:
             jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
         else:
-            # Compute probabilities (only required for mean calculation)
             log_mean_probs = torch.logsumexp(
                 torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
             )
+            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="none", log_target=True)
+            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="none", log_target=True)
+            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
 
-            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="batchmean", log_target=True)
-            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="batchmean", log_target=True)
+        # Sum over vocab dimension
+        jsd_loss = jsd_loss.sum(dim=-1)
 
-            # JSD is the weighted average of the KL divergences
-            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
-        return jsd_loss
+        # Apply ignore_index mask
+        if target is not None:
+            mask = target != ignore_index
+            jsd_loss = jsd_loss * mask.float()
+            num_valid_tokens = mask.sum().clamp_min(1)
+            return jsd_loss.sum() / num_valid_tokens
+
+        return jsd_loss.sum()
 
 
 class TorchLMHeadJSD(torch.nn.Module):
@@ -182,6 +191,7 @@ def forward(self, student_input, teacher_input, target):
         (0.5, 1.0, 0.0, 0.2),
     ],
 )
+@pytest.mark.parametrize("ignore_index", [-100, 42])
 def test_correctness(
     B,
     T,
@@ -196,6 +206,7 @@ def test_correctness(
     weight_hard_loss,
     weight_soft_loss,
     beta,
+    ignore_index,
 ):
     torch_lm_head_jsd = TorchLMHeadJSD(
         H=H,
@@ -207,6 +218,7 @@ def test_correctness(
         weight_hard_loss=weight_hard_loss,
         weight_soft_loss=weight_soft_loss,
         beta=beta,
+        ignore_index=ignore_index,
     )
     liger_lm_head_jsd = LigerLMHeadJSD(
         H=H,
@@ -218,6 +230,7 @@ def test_correctness(
         weight_hard_loss=weight_hard_loss,
         weight_soft_loss=weight_soft_loss,
        beta=beta,
+        ignore_index=ignore_index,
     )
 
     torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand(
@@ -243,6 +256,11 @@ def test_correctness(
 
     target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
 
+    # Assign some random number of elements as ignore_index
+    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
+    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
+    target[indices_to_assign] = ignore_index
+
     loss1 = torch_lm_head_jsd(student_input1, teacher_input, target)
     loss2 = liger_lm_head_jsd(student_input2, teacher_input, target)
     assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
```
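One edge case worth noting alongside these tests, shown as a hypothetical standalone illustration rather than the test code itself: when a chunk consists entirely of prompt/padding tokens, the valid-token count is zero, which is exactly where the `clamp_min(1)` guard matters:

```python
import torch

ignore_index = -100
target = torch.full((4,), ignore_index)   # every token ignored
per_token_loss = torch.zeros(4)           # masked losses are all zero here

mask = target != ignore_index
print(per_token_loss.sum() / mask.sum())               # tensor(nan): 0 / 0
print(per_token_loss.sum() / mask.sum().clamp_min(1))  # tensor(0.): guarded denominator
```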

test/utils.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -1033,7 +1033,9 @@ def get_batch_loss_metrics(
         student_logits /= self.temperature
         teacher_logits /= self.temperature
 
-        soft_loss = self.distillation_loss(student_logits, teacher_logits, **loss_kwargs)
+        soft_loss = self.distillation_loss(
+            student_logits, teacher_logits, target=target, ignore_index=self.ignore_index, **loss_kwargs
+        )
         # full loss
         loss = self.weight_hard_loss * hard_loss + self.weight_soft_loss * soft_loss
         return loss
```
