
import torch
import torch.nn as nn
+import torch.nn.functional as F
from transformers import PreTrainedModel
from trl import GKDTrainer as HFGKDTrainer
from trl import SFTTrainer as HFSFTTrainer
@@ -104,6 +105,16 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        shifted_student_logits = outputs_student.logits[mask][None]
        shifted_teacher_logits = outputs_teacher.logits[mask][None]

+        # Fix the vocab_size mismatch between Qwen2.5-VL-3B-Instruct and Qwen2.5-VL-7B-Instruct.
+        stu_dim = shifted_student_logits.shape[-1]
+        tea_dim = shifted_teacher_logits.shape[-1]
+        if stu_dim < tea_dim:
+            shifted_student_logits = F.pad(shifted_student_logits, (0, tea_dim - stu_dim), 'constant', 0)
+            shifted_student_logits[..., stu_dim:] = shifted_teacher_logits[..., stu_dim:]
+        elif stu_dim > tea_dim:
+            shifted_teacher_logits = F.pad(shifted_teacher_logits, (0, stu_dim - tea_dim), 'constant', 0)
+            shifted_teacher_logits[..., tea_dim:] = shifted_student_logits[..., tea_dim:]
+
        # compute loss
        loss = self.generalized_jsd_loss(
            student_logits=shifted_student_logits,
@@ -133,8 +144,10 @@ def training_step(self,
        with unwrap_model_for_generation(
                model, self.accelerator,
                gather_deepspeed3_params=self.args.ds3_gather_for_generation) as unwrapped_model:
+            unwrapped_model.eval()  # Suppress the gradient_checkpointing warning during generation.
            new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
                unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id)
+            unwrapped_model.train()  # Restore training mode once generation is done.
        inputs['input_ids'] = new_input_ids
        inputs['attention_mask'] = new_attention_mask
        inputs['labels'] = new_labels
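
For readers who want to sanity-check the vocab padding step outside the trainer, here is a minimal standalone sketch. It assumes only `torch`; the helper name `pad_logits_to_match` and the toy vocabulary sizes are illustrative and not part of this patch, which applies the same steps directly to the shifted logits inside `compute_loss`.

```python
# Standalone sketch of the padding step above (illustrative, not part of the diff).
import torch
import torch.nn.functional as F


def pad_logits_to_match(student_logits, teacher_logits):
    """Pad the smaller-vocab logits tensor up to the larger vocab size.

    The padded columns are filled with the other model's values, so both tensors
    agree on token ids that the smaller-vocab model cannot actually predict.
    """
    stu_dim = student_logits.shape[-1]
    tea_dim = teacher_logits.shape[-1]
    if stu_dim < tea_dim:
        student_logits = F.pad(student_logits, (0, tea_dim - stu_dim), 'constant', 0)
        student_logits[..., stu_dim:] = teacher_logits[..., stu_dim:]
    elif stu_dim > tea_dim:
        teacher_logits = F.pad(teacher_logits, (0, stu_dim - tea_dim), 'constant', 0)
        teacher_logits[..., tea_dim:] = student_logits[..., tea_dim:]
    return student_logits, teacher_logits


if __name__ == "__main__":
    student = torch.randn(1, 5, 100)   # toy student vocab size
    teacher = torch.randn(1, 5, 120)   # toy teacher vocab size
    student, teacher = pad_logits_to_match(student, teacher)
    assert student.shape == teacher.shape == (1, 5, 120)
    # The padded columns now hold identical logits in both tensors.
    assert torch.equal(student[..., 100:], teacher[..., 100:])
```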