Skip to content

Commit 3478bdb

Browse files
authored
[seq_parallel] fix sp compute_acc (#4456)
1 parent e060ad8 commit 3478bdb

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ def __init__(self,
266266
self.model_accepts_loss_kwargs = False
267267
self.padding_free = self.template.padding_free
268268
self.template.padding_free = False
269+
self.template._packing = False
269270
for i, reward_func in enumerate(self.reward_funcs):
270271
if isinstance(reward_func, PreTrainedModel):
271272
if self.is_deepspeed_enabled:

swift/trainers/sequence_parallel/ulysses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -845,7 +845,7 @@ def rlhf_loss_scale_sp_func(_, *args, **kwargs):
845845
compute_acc_origin = metric.compute_acc
846846

847847
def compute_acc(preds, labels, *args, **kwargs) -> Dict[str, List[float]]:
848-
848+
_, _, labels, _, _, _ = self.pad_and_split_inputs(None, None, labels, None, None, None)
849849
# Gather preds and labels across the sp group
850850
if isinstance(preds, np.ndarray):
851851
preds = torch.from_numpy(preds).to(get_current_device())

0 commit comments

Comments
 (0)