
Commit 87187b1

[chunked loss] align teacher and student logit shape (#634)
## Summary

In rare cases the teacher and student models do not have the same vocab size even though their vocabularies are actually identical (for example, Qwen models, where the teacher's logits are padded for training efficiency). In those cases we pad the student logits to match the teacher's.

## Testing Done

make test

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
1 parent 3a5845b commit 87187b1

File tree

1 file changed (+15, −0 lines)


src/liger_kernel/chunked_loss/fused_linear_distillation.py

Lines changed: 15 additions & 0 deletions
@@ -115,6 +115,21 @@ def _compute_loss(
         student_logits_chunk /= temperature
         teacher_logits_chunk /= temperature
 
+        # If the teacher and student vocab sizes differ, pad the student logits to match the teacher's.
+        # This only applies when they share exactly the same vocab and tokenizer, and the teacher
+        # logits are merely padded for training efficiency, e.g.
+        # https://huggingface.co/Qwen/Qwen1.5-72B-Chat/discussions/1#662883f568adf59b07b176d2
+        teacher_vocab_size = teacher_weight.shape[0]
+        student_vocab_size = student_weight.shape[0]
+        if teacher_vocab_size > student_vocab_size:
+            pad_size = teacher_vocab_size - student_vocab_size
+            pad_tensor = torch.zeros(
+                (*student_logits_chunk.shape[:-1], pad_size),
+                dtype=student_logits_chunk.dtype,
+                device=student_logits_chunk.device,
+            )
+            student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)
+
         hard_loss /= full_target.shape[0]
 
         soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs)
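For illustration, here is a minimal, self-contained sketch of the same zero-padding step outside the chunked-loss kernel. The helper name `pad_student_logits_to_teacher` and the vocab sizes are assumptions for the example, not part of this commit.

```python
import torch


def pad_student_logits_to_teacher(student_logits: torch.Tensor, teacher_vocab_size: int) -> torch.Tensor:
    """Right-pad the vocab (last) dimension of the student logits with zeros to match the teacher."""
    student_vocab_size = student_logits.shape[-1]
    if teacher_vocab_size <= student_vocab_size:
        return student_logits
    pad_size = teacher_vocab_size - student_vocab_size
    pad_tensor = torch.zeros(
        (*student_logits.shape[:-1], pad_size),
        dtype=student_logits.dtype,
        device=student_logits.device,
    )
    return torch.cat([student_logits, pad_tensor], dim=-1)


# Illustrative shapes only: a teacher whose output head is padded to a larger vocab.
student_logits = torch.randn(4, 1000)  # (chunk_size, student_vocab_size)
teacher_logits = torch.randn(4, 1024)  # (chunk_size, padded teacher_vocab_size)
student_logits = pad_student_logits_to_teacher(student_logits, teacher_logits.shape[-1])
assert student_logits.shape == teacher_logits.shape  # shapes now align for the distillation loss
```

Padding with zeros leaves the student's existing logits untouched; the distillation loss then receives matching (chunk_size, vocab) shapes for both models.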
