Skip to content

Commit df7535d

Browse files
authored
[gkd] fix qwen2.5-vl 3b/7b vocab_size mismatch (#5335)
1 parent dd89c52 commit df7535d

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

swift/trainers/rlhf_trainer/gkd_trainer.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77

88
import torch
99
import torch.nn as nn
10+
import torch.nn.functional as F
1011
from transformers import PreTrainedModel
1112
from trl import GKDTrainer as HFGKDTrainer
1213
from trl import SFTTrainer as HFSFTTrainer
@@ -104,6 +105,16 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
104105
shifted_student_logits = outputs_student.logits[mask][None]
105106
shifted_teacher_logits = outputs_teacher.logits[mask][None]
106107

108+
# Fix the vocab_size mismatch between Qwen2.5-VL-3B-Instruct and Qwen2.5-VL-7B-Instruct.
109+
stu_dim = shifted_student_logits.shape[-1]
110+
tea_dim = shifted_teacher_logits.shape[-1]
111+
if stu_dim < tea_dim:
112+
shifted_student_logits = F.pad(shifted_student_logits, (0, tea_dim - stu_dim), 'constant', 0)
113+
shifted_student_logits[..., stu_dim:] = shifted_teacher_logits[..., stu_dim:]
114+
elif stu_dim > tea_dim:
115+
shifted_teacher_logits = F.pad(shifted_teacher_logits, (0, stu_dim - tea_dim), 'constant', 0)
116+
shifted_teacher_logits[..., tea_dim:] = shifted_student_logits[..., tea_dim:]
117+
107118
# compute loss
108119
loss = self.generalized_jsd_loss(
109120
student_logits=shifted_student_logits,
@@ -133,8 +144,10 @@ def training_step(self,
133144
with unwrap_model_for_generation(
134145
model, self.accelerator,
135146
gather_deepspeed3_params=self.args.ds3_gather_for_generation) as unwrapped_model:
147+
unwrapped_model.eval() # Remove the gradient_checkpointing warning.
136148
new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
137149
unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id)
150+
unwrapped_model.train()
138151
inputs['input_ids'] = new_input_ids
139152
inputs['attention_mask'] = new_attention_mask
140153
inputs['labels'] = new_labels

0 commit comments

Comments (0)