Commit 9120a17

tastelikefeet authored
fix sft/ulysses eval (#4494)
Co-authored-by: tastelikefeet <[email protected]>
1 parent f41e7ba · commit 9120a17

2 files changed (+21, -6 lines)


swift/trainers/sequence_parallel/ulysses.py

Lines changed: 20 additions & 5 deletions
@@ -118,8 +118,9 @@ def loss_scale_sp_func(outputs, labels, loss_scale=None, num_items_in_batch=None
     else:
         logits = outputs
     device = logits.device
+    if labels.shape[1] > logits.shape[1]:
+        _, _, labels, _, _, loss_scale = ulysses.pad_and_split_inputs(None, None, labels, None, None, loss_scale)
     logits = logits.view(-1, logits.shape[-1])
-    _, _, labels, _, _, loss_scale = ulysses.pad_and_split_inputs(None, None, labels, None, None, loss_scale)

     labels = labels.flatten().to(device)
     sploss_parallel_size = int(os.environ.get('CELOSS_PARALLEL_SIZE', '0'))
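
This is the core of the eval fix: labels are split across sequence-parallel ranks only while they still cover the full sequence, so labels already padded and split upstream (by the new `_prepare_inputs` wrapper below) are not split a second time. The same shape guard recurs in `get_per_token_logps` and `compute_acc` further down. A minimal standalone sketch of the idea, with a toy `pad_and_split` standing in for `ulysses.pad_and_split_inputs`:

import torch
import torch.nn.functional as F

def pad_and_split(labels: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Toy stand-in for ulysses.pad_and_split_inputs: pad the sequence to a
    # multiple of world_size, then keep only this rank's slice.
    seq_len = labels.shape[1]
    padded_len = -(-seq_len // world_size) * world_size
    labels = F.pad(labels, (0, padded_len - seq_len), value=-100)
    chunk = padded_len // world_size
    return labels[:, rank * chunk:(rank + 1) * chunk]

labels = torch.randint(0, 10, (2, 8))   # full-length labels from the batch
logits = torch.randn(2, 4, 10)          # this rank's logits: seq already halved
if labels.shape[1] > logits.shape[1]:   # the guard: split only if not split yet
    labels = pad_and_split(labels, rank=0, world_size=2)
assert labels.shape[1] == logits.shape[1]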
@@ -142,7 +143,7 @@ def loss_scale_sp_func(outputs, labels, loss_scale=None, num_items_in_batch=None


 @profiling_decorator
-def _prepare_inputs(self, generation_batch):
+def _prepare_inputs_grpo(self, generation_batch):
     ulysses = self.ulysses
     mode = 'train' if self.model.training else 'eval'
     if mode == 'train':
@@ -159,6 +160,14 @@ def _prepare_inputs(self, generation_batch):
     return inputs


+def _prepare_inputs(self, inputs, ulysses):
+    if 'labels' in inputs:
+        labels = inputs['labels']
+        _, _, labels, _, _, _ = ulysses.pad_and_split_inputs(None, None, labels, None, None, None)
+        inputs['labels'] = labels
+    return self._origin_prepare_inputs(inputs)
+
+
 def old_policy(self):
     ulysses = self.ulysses
     # changes: `* ulysses.sp_world_size`
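
The new module-level `_prepare_inputs` (bound to the SFT and DPO trainers in `prepare_trainer` below) splits the labels once, ahead of the forward pass, then defers to the trainer's original implementation saved as `_origin_prepare_inputs`. A runnable toy sketch of that wrap-and-delegate shape, with an illustrative `split` in place of `pad_and_split_inputs` (stub names are mine, not swift's API):

import torch

class UlyssesStub:
    # Illustrative splitter: keep this rank's half of the sequence.
    def split(self, labels: torch.Tensor) -> torch.Tensor:
        return labels[:, :labels.shape[1] // 2]

class TrainerStub:
    def _origin_prepare_inputs(self, inputs):
        # Stand-in for the trainer's original _prepare_inputs
        # (device placement etc. in the real trainer).
        return inputs

def _prepare_inputs(self, inputs, ulysses):
    # Split labels once, up front, then fall through to the original method.
    if 'labels' in inputs:
        inputs['labels'] = ulysses.split(inputs['labels'])
    return self._origin_prepare_inputs(inputs)

trainer, ulysses = TrainerStub(), UlyssesStub()
out = _prepare_inputs(trainer, {'labels': torch.zeros(1, 8, dtype=torch.long)}, ulysses)
assert out['labels'].shape == (1, 4)  # labels now match this rank's logits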
@@ -171,7 +180,8 @@ def get_per_token_logps(self,
                         logits: torch.FloatTensor,
                         labels: torch.LongTensor,
                         ulysses=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    _, _, labels, _, _, _ = ulysses.pad_and_split_inputs(None, None, labels, None, None, None)
+    if labels.shape[1] > logits.shape[1]:
+        _, _, labels, _, _, _ = ulysses.pad_and_split_inputs(None, None, labels, None, None, None)
     loss_mask = labels != self.label_pad_token_id
     labels = labels.clone()  # No need to shift, pad and split has shifted the inputs.
     labels[~loss_mask] = 0
@@ -823,9 +833,13 @@ def prepare_trainer(self, trainer):

         trainer.ulysses = self
         if trainer.__class__.__name__ == 'Seq2SeqTrainer':
+            trainer._origin_prepare_inputs = trainer._prepare_inputs
+            trainer._prepare_inputs = MethodType(partial(_prepare_inputs, ulysses=self), trainer)
             trainer.compute_loss_func = partial(loss_scale_sp_func, ulysses=self)

         elif trainer.__class__.__name__ == 'DPOTrainer':
+            trainer._origin_prepare_inputs = trainer._prepare_inputs
+            trainer._prepare_inputs = MethodType(partial(_prepare_inputs, ulysses=self), trainer)
             trainer.get_per_token_logps = MethodType(partial(get_per_token_logps, ulysses=self), trainer)

         def rlhf_loss_scale_sp_func(_, *args, **kwargs):
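
The rebinding uses the standard `types.MethodType` + `functools.partial` idiom: `partial` pins the `ulysses` handle, `MethodType` binds the trainer instance as `self`, and the original method survives as `_origin_prepare_inputs`. The GRPO path below keeps its own wrapper, now renamed `_prepare_inputs_grpo`, so the two no longer shadow each other. A self-contained sketch of the binding mechanics (toy classes, not swift's):

from functools import partial
from types import MethodType

class Trainer:
    def _prepare_inputs(self, inputs):
        # Original behavior to be wrapped (device placement in the real trainer).
        return inputs

def _prepare_inputs(self, inputs, ulysses):
    # The wrapper: tag the inputs, then delegate to the saved original.
    inputs = dict(inputs, split_by=ulysses)
    return self._origin_prepare_inputs(inputs)

trainer = Trainer()
trainer._origin_prepare_inputs = trainer._prepare_inputs
# partial pins the `ulysses` handle; MethodType binds `trainer` as `self`.
trainer._prepare_inputs = MethodType(partial(_prepare_inputs, ulysses='sp-group'), trainer)
print(trainer._prepare_inputs({'x': 1}))  # {'x': 1, 'split_by': 'sp-group'}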
@@ -838,7 +852,7 @@ def rlhf_loss_scale_sp_func(_, *args, **kwargs):
             trainer.ulysses = self
             trainer.args.gradient_accumulation_steps = trainer.args.gradient_accumulation_steps * self.sp_world_size
             trainer.old_policy = MethodType(old_policy, trainer)
-            trainer._prepare_inputs = MethodType(_prepare_inputs, trainer)
+            trainer._prepare_inputs = MethodType(_prepare_inputs_grpo, trainer)
             trainer._get_per_token_logps = MethodType(_get_per_token_logps, trainer)
             trainer.split_by_mini_batches = MethodType(split_by_mini_batches, trainer)

@@ -852,7 +866,8 @@ def compute_acc(preds, labels, *args, **kwargs) -> Dict[str, List[float]]:
         preds = torch.from_numpy(preds).to(get_current_device())
     if isinstance(labels, np.ndarray):
         labels = torch.from_numpy(labels).to(get_current_device())
-    _, _, labels, _, _, _ = self.pad_and_split_inputs(None, None, labels, None, None, None)
+    if labels.shape[1] > preds.shape[1]:
+        _, _, labels, _, _, _ = self.pad_and_split_inputs(None, None, labels, None, None, None)
     shape0 = preds.shape[0]
     preds_output = torch.empty((shape0 * self.sp_world_size, preds.shape[1]),
                                dtype=preds.dtype,

swift/trainers/trainers.py

Lines changed: 1 addition & 1 deletion
@@ -229,7 +229,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None
         if getattr(self.args, 'average_tokens_across_devices', False) and self.model_accepts_loss_kwargs:
             loss *= self.accelerator.num_processes

-        if outputs.logits is not None and labels is not None:
+        if outputs.logits is not None and labels is not None and not return_outputs:
             # Liger does not have logits
             self._compute_acc(outputs, labels)
         return (loss, outputs) if return_outputs else loss
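
`compute_loss` is also invoked with `return_outputs=True` from the trainer's prediction path during evaluation; guarding `_compute_acc` on `not return_outputs` restricts the streaming-accuracy update to the plain training-loss path, which is consistent with the commit's SFT-eval fix. A runnable toy sketch of the two call paths (illustrative names, not transformers' code):

class Trainer:
    def __init__(self):
        self.acc_updates = 0

    def compute_loss(self, outputs, labels, return_outputs=False):
        loss = outputs['loss']
        # Update streaming accuracy only on the loss-only (training) path;
        # the prediction path passes return_outputs=True and handles
        # metrics itself.
        if outputs['logits'] is not None and labels is not None and not return_outputs:
            self.acc_updates += 1
        return (loss, outputs) if return_outputs else loss

t = Trainer()
outs = {'loss': 0.5, 'logits': object()}
t.compute_loss(outs, labels=[1])                       # training path: counted
t.compute_loss(outs, labels=[1], return_outputs=True)  # eval path: skipped
assert t.acc_updates == 1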
