Skip to content

Commit 6238a32

Browse files
authored
fix grpo async generate (#3829)
* fix * fix * fix * revert fast infer --------- Co-authored-by: hjh <[email protected]>
1 parent 59a3863 commit 6238a32

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -683,15 +683,15 @@ def _prefetch(self, dataloader):
683683
if self.accelerator.num_processes > 1:
684684
self.accelerator.wait_for_everyone()
685685

686-
def _fast_infer(self, all_inputs, inputs=None):
686+
def _fast_infer(self, inputs):
687687
"""
688688
This function performs fast inference by managing model and optimizer offloading,
689689
loading weights if necessary, distributing inputs among workers, and generating
690690
completions using the vLLM/LMDeploy framework. It supports both synchronous and asynchronous
691691
inference modes.
692-
all_inputs: all gather inputs in distributed
693692
inputs: local inputs
694693
"""
694+
695695
if self.args.sleep_level > 0 and self.infer_rank >= 0:
696696
if self.args.offload_model:
697697
self.offload_model()
@@ -706,6 +706,7 @@ def _fast_infer(self, all_inputs, inputs=None):
706706
if self.state.global_step != self._last_loaded_step:
707707
self._move_model_to_vllm_lmdeploy()
708708
self._last_loaded_step = self.state.global_step
709+
all_inputs = gather_object(inputs)
709710
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
710711
# Distribute inputs to different workers
711712
# for example, 2 workers, 6 inputs, 0/2/4 dispatch to the first worker
@@ -758,9 +759,7 @@ def _fast_infer(self, all_inputs, inputs=None):
758759

759760
def _generate_completions(self, inputs):
760761
if self.use_fast_infer:
761-
all_inputs = gather_object(inputs)
762-
763-
_, outputs = self._fast_infer(all_inputs, inputs)
762+
inputs, outputs = self._fast_infer(inputs)
764763
# Slice to keep only the local part of the data
765764
process_slice = slice(
766765
self.accelerator.process_index * len(inputs),

0 commit comments

Comments
 (0)