
Commit 65b8c62

align teacher and student logit shape

1 parent: 3a5845b

1 file changed (+15, -0)

src/liger_kernel/chunked_loss/fused_linear_distillation.py

Lines changed: 15 additions & 0 deletions
@@ -115,6 +115,21 @@ def _compute_loss(
         student_logits_chunk /= temperature
         teacher_logits_chunk /= temperature
 
+        # If the teacher and student vocab sizes differ, pad the student logits to match the teacher's.
+        # This only applies to cases where the two models share exactly the same vocab and tokenizer,
+        # and the teacher logits are merely padded for training efficiency, such as
+        # https://huggingface.co/Qwen/Qwen1.5-72B-Chat/discussions/1#662883f568adf59b07b176d2
+        teacher_vocab_size = teacher_weight.shape[0]
+        student_vocab_size = student_weight.shape[0]
+        if teacher_vocab_size > student_vocab_size:
+            pad_size = teacher_vocab_size - student_vocab_size
+            pad_tensor = torch.zeros(
+                (*student_logits_chunk.shape[:-1], pad_size),
+                dtype=student_logits_chunk.dtype,
+                device=student_logits_chunk.device
+            )
+            student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)
+
         hard_loss /= full_target.shape[0]
 
         soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs)
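
For reference, a minimal standalone sketch of the same padding step (shapes and tensor names here are illustrative, not taken from the repo; only torch is assumed):

import torch

# Illustrative shapes: a chunk of 4 token positions, a hypothetical student
# vocab of 32000, and a teacher vocab padded to 32768 for efficiency.
student_logits_chunk = torch.randn(4, 32000)
teacher_logits_chunk = torch.randn(4, 32768)

teacher_vocab_size = teacher_logits_chunk.shape[-1]
student_vocab_size = student_logits_chunk.shape[-1]

if teacher_vocab_size > student_vocab_size:
    pad_size = teacher_vocab_size - student_vocab_size
    # Zero-pad the student logits along the vocab dimension, matching the
    # chunk's dtype and device so the concat needs no implicit casts or copies.
    pad_tensor = torch.zeros(
        (*student_logits_chunk.shape[:-1], pad_size),
        dtype=student_logits_chunk.dtype,
        device=student_logits_chunk.device,
    )
    student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)

assert student_logits_chunk.shape == teacher_logits_chunk.shape  # torch.Size([4, 32768])

Note that the commit itself reads the vocab sizes from the projection weights (teacher_weight.shape[0] and student_weight.shape[0]) rather than from the logits, since the logits are computed per chunk from those same weights.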

Comments (0)