@@ -459,7 +459,7 @@ def _template_context(self, template):
459459 template .max_length = max_length
460460
461461 @profiling_decorator
462- def _move_model_to_vllm (self ):
462+ def _move_model_to_vllm (self , skip_async_check = False ):
463463 deepspeed_plugin = self .accelerator .state .deepspeed_plugin
464464 zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin .zero_stage == 3
465465 if zero_stage_3 :
@@ -468,7 +468,7 @@ def _move_model_to_vllm(self):
468468 else :
469469 gather_if_zero3 = nullcontext
470470
471- if self .args .async_generate :
471+ if self .args .async_generate and not skip_async_check :
472472 # before sync weight, we should wait async generate finish
473473 self ._wait_queue ()
474474
@@ -754,6 +754,9 @@ def done(future):
def _prefetch(self, dataloader: DataLoader):
    """Produce one batch of generations ahead of time and enqueue it.

    Takes the first batch yielded by ``dataloader``, passes it through
    ``gather_object``, refreshes the vLLM weights when the trainer's
    ``global_step`` differs from ``self._last_loaded_step``, runs
    ``_infer_single_or_multi_turn`` on the gathered inputs and finally puts
    a ``DataCache`` holding both inputs and outputs on ``self._queue``.
    """
    local_batch = next(iter(dataloader))
    gathered_batch = gather_object(local_batch)

    weights_are_stale = self.state.global_step != self._last_loaded_step
    if weights_are_stale:
        # skip_async_check=True bypasses the `_wait_queue()` call inside
        # `_move_model_to_vllm`.
        # NOTE(review): presumably because this method is the one that
        # fills the queue, so waiting on it here could block - confirm.
        self._move_model_to_vllm(skip_async_check=True)
        self._last_loaded_step = self.state.global_step

    generations = self._infer_single_or_multi_turn(
        gathered_batch,
        self.request_config,
        is_global_inputs=True,
    )
    self._queue.put(DataCache(gathered_batch, generations))
759762
0 commit comments