File tree Expand file tree Collapse file tree 1 file changed +15
-0
lines changed
src/liger_kernel/chunked_loss Expand file tree Collapse file tree 1 file changed +15
-0
lines changed Original file line number Diff line number Diff line change @@ -115,6 +115,21 @@ def _compute_loss(
115115 student_logits_chunk /= temperature
116116 teacher_logits_chunk /= temperature
117117
118+ # If the teacher and student vocabulary sizes differ, pad the student logits to match the teacher's.
119+ # This only applies to cases where they share exactly the same vocab and tokenizer; the
120+ # teacher logits are merely padded for training efficiency, e.g.
121+ # https://huggingface.co/Qwen/Qwen1.5-72B-Chat/discussions/1#662883f568adf59b07b176d2
122+ teacher_vocab_size = teacher_weight .shape [0 ]
123+ student_vocab_size = student_weight .shape [0 ]
124+ if teacher_vocab_size > student_vocab_size :
125+ pad_size = teacher_vocab_size - student_vocab_size
126+ pad_tensor = torch .zeros (
127+ (* student_logits_chunk .shape [:- 1 ], pad_size ),
128+ dtype = student_logits_chunk .dtype ,
129+ device = student_logits_chunk .device
130+ )
131+ student_logits_chunk = torch .cat ([student_logits_chunk , pad_tensor ], dim = - 1 )
132+
118133 hard_loss /= full_target .shape [0 ]
119134
120135 soft_loss = distillation_loss_fn (student_logits_chunk , teacher_logits_chunk , ** loss_kwargs )
You can’t perform that action at this time.
0 commit comments