Skip to content

Commit 58abf66

Browse files
authored
fix grpo multi turn tp (#3837)
Co-authored-by: hjh <[email protected]>
1 parent b9ce786 commit 58abf66

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -762,7 +762,7 @@ def _fast_infer(self, inputs: InputsType) -> Tuple[InputsType, OutputsType]:
762762
outputs = []
763763
outputs = gather_object(outputs)
764764
if self.args.tensor_parallel_size > 1:
765-
outputs = [item for output in outputs for item in output]
765+
outputs = [[item] for output in outputs for item in output]
766766
outputs = self.reorder_outputs(outputs, distributed_idx)
767767
if self.args.sleep_level > 0 and self.infer_rank >= 0:
768768
self.engine.engine.sleep(level=self.args.sleep_level)

0 commit comments

Comments
 (0)