Skip to content

Commit c967bed

Browse files
hjh0119Jintao-Huang
authored andcommitted
[grpo] model weight synchronization before first turn rollout with async generation (#4584)
* sync-weight-in-async-generate * fix wait
1 parent 64aa542 commit c967bed

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def _template_context(self, template):
459459
template.max_length = max_length
460460

461461
@profiling_decorator
462-
def _move_model_to_vllm(self):
462+
def _move_model_to_vllm(self, skip_async_check=False):
463463
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
464464
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
465465
if zero_stage_3:
@@ -468,7 +468,7 @@ def _move_model_to_vllm(self):
468468
else:
469469
gather_if_zero3 = nullcontext
470470

471-
if self.args.async_generate:
471+
if self.args.async_generate and not skip_async_check:
472472
# before sync weight, we should wait async generate finish
473473
self._wait_queue()
474474

@@ -754,6 +754,9 @@ def done(future):
754754
def _prefetch(self, dataloader: DataLoader):
755755
inputs = next(iter(dataloader))
756756
all_inputs = gather_object(inputs)
757+
if self.state.global_step != self._last_loaded_step:
758+
self._move_model_to_vllm(skip_async_check=True)
759+
self._last_loaded_step = self.state.global_step
757760
outputs = self._infer_single_or_multi_turn(all_inputs, self.request_config, is_global_inputs=True)
758761
self._queue.put(DataCache(all_inputs, outputs))
759762

0 commit comments

Comments
 (0)