@@ -683,15 +683,15 @@ def _prefetch(self, dataloader):
         if self.accelerator.num_processes > 1:
             self.accelerator.wait_for_everyone()
 
-    def _fast_infer(self, all_inputs, inputs=None):
+    def _fast_infer(self, inputs):
         """
         This function performs fast inference by managing model and optimizer offloading,
         loading weights if necessary, distributing inputs among workers, and generating
         completions using the vLLM/LMDeploy framework. It supports both synchronous and asynchronous
         inference modes.
-        all_inputs: all gather inputs in distributed
         inputs: local inputs
         """
+
         if self.args.sleep_level > 0 and self.infer_rank >= 0:
             if self.args.offload_model:
                 self.offload_model()
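
The docstring above summarizes the flow: optionally offload the training model and optimizer, sync the latest weights into the inference engine, gather prompts from every process, and generate completions. As a reading aid, here is a self-contained toy sketch of the new calling contract only (illustrative code, not the trainer's implementation; `fake_gather_object` stands in for `accelerate.utils.gather_object`): `_fast_infer` now takes just the local inputs, gathers them itself, and returns the gathered inputs together with the outputs.

```python
def fake_gather_object(local_inputs):
    # Stand-in for accelerate's gather_object: pretend a second process
    # contributed one more prompt.
    return list(local_inputs) + ["prompt from rank 1"]

def fast_infer(local_inputs):
    # New contract sketched by this diff: the gather happens inside, and both
    # the gathered inputs and the generated outputs are returned to the caller.
    all_inputs = fake_gather_object(local_inputs)
    outputs = [f"completion for: {p}" for p in all_inputs]
    return all_inputs, outputs

all_inputs, outputs = fast_infer(["prompt from rank 0"])
print(all_inputs)  # 2 prompts, gathered inside fast_infer
print(outputs)     # one completion per gathered prompt
```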
@@ -706,6 +706,7 @@ def _fast_infer(self, all_inputs, inputs=None):
         if self.state.global_step != self._last_loaded_step:
             self._move_model_to_vllm_lmdeploy()
             self._last_loaded_step = self.state.global_step
+        all_inputs = gather_object(inputs)
         # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
         # Distribute inputs to different workers
         # for example, 2 workers, 6 inputs, 0/2/4 dispatch to the first worker
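
The comments above describe the dispatch scheme: the gathered prompts are split round-robin across the inference workers, so with 2 workers and 6 inputs the first worker receives indices 0/2/4 and the second receives 1/3/5. A minimal sketch of that interleaved slicing (the variable names are illustrative, not the trainer's actual identifiers):

```python
all_inputs = [f"prompt_{i}" for i in range(6)]  # stand-in for the gathered prompts
num_infer_workers = 2                           # illustrative worker count

for infer_rank in range(num_infer_workers):
    # Each worker takes every num_infer_workers-th prompt starting at its rank.
    worker_batch = all_inputs[infer_rank::num_infer_workers]
    print(infer_rank, worker_batch)
# 0 ['prompt_0', 'prompt_2', 'prompt_4']
# 1 ['prompt_1', 'prompt_3', 'prompt_5']
```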
@@ -758,9 +759,7 @@ def _fast_infer(self, all_inputs, inputs=None):
 
     def _generate_completions(self, inputs):
         if self.use_fast_infer:
-            all_inputs = gather_object(inputs)
-
-            _, outputs = self._fast_infer(all_inputs, inputs)
+            inputs, outputs = self._fast_infer(inputs)
             # Slice to keep only the local part of the data
             process_slice = slice(
                 self.accelerator.process_index * len(inputs),
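
Since `_fast_infer` now returns the globally gathered inputs and outputs, the caller keeps only its own contiguous chunk via `process_slice`. A standalone analogue of that slice-back step, assuming every process contributed the same number of inputs (`local_slice` and the toy data are illustrative, not the trainer's code):

```python
def local_slice(gathered, process_index, per_process_len):
    # Keep only the contiguous block produced from this process's own inputs.
    start = process_index * per_process_len
    return gathered[start:start + per_process_len]

gathered_outputs = [f"out_{i}" for i in range(6)]  # 3 processes x 2 local inputs each
print(local_slice(gathered_outputs, 0, 2))  # ['out_0', 'out_1']
print(local_slice(gathered_outputs, 1, 2))  # ['out_2', 'out_3']
```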