Commit 6656887

vLLM 0.8.3 support for GRPO colocate mode (#3820)
* should work
* tp>1
* seed=0

Co-authored-by: hjh <[email protected]>
1 parent 4c76918 commit 6656887

File tree

1 file changed: 7 additions, 2 deletions

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 7 additions & 2 deletions
@@ -412,13 +412,17 @@ def prepare_vllm(self, model, fast_infer_device):
         from swift.llm import VllmEngine
         from swift.llm.infer.infer_engine import GRPOVllmEngine
         _, _, _, local_world_size = get_dist_setting()
+        if self.args.tensor_parallel_size > 1:
+            vllm_kwargs = {'distributed_executor_backend': 'external_launcher'}
+        else:
+            vllm_kwargs = {}
         if local_world_size == self.args.num_infer_workers == get_device_count() and local_world_size > 1:
             # Compatibility with TP
             cls = GRPOVllmEngine
-            vllm_kwargs = {'distributed_executor_backend': 'external_launcher'}
+            engine_kwargs = {'seed': 0}
         else:
             cls = VllmEngine
-            vllm_kwargs = {}
+            engine_kwargs = {}
         with Swift.grpo_context(model, self.template.processor):
             self.engine = cls(
                 model.model_dir,
@@ -435,6 +439,7 @@ def prepare_vllm(self, model, fast_infer_device):
                 enable_sleep_mode=self.args.sleep_level > 0,
                 use_async_engine=False,
                 max_model_len=self.args.vllm_max_model_len,
+                engine_kwargs=engine_kwargs,
                 **vllm_kwargs)
         self.engine.default_template = self.template

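For context, the vLLM 0.8.x usage pattern that this change wires into prepare_vllm can be sketched standalone. The snippet below is a hypothetical illustration, not part of this commit: the model name, prompt, and launch command are placeholders, and it assumes vLLM >= 0.8.3 with two GPUs launched via torchrun. With tensor_parallel_size > 1, each rank constructs its own engine using the external_launcher backend so vLLM reuses the already-initialized process group, and a fixed seed keeps sampling identical across TP ranks.

# Hypothetical sketch (not from this commit): colocate-style vLLM use with TP > 1.
# Launch on one node, e.g.: torchrun --nproc-per-node=2 sketch.py
from vllm import LLM, SamplingParams

llm = LLM(
    model='Qwen/Qwen2.5-0.5B-Instruct',                 # placeholder model
    tensor_parallel_size=2,                             # TP > 1 across the torchrun ranks
    distributed_executor_backend='external_launcher',   # reuse the launcher-created process group
    seed=0,                                             # identical sampling on every TP rank
)

outputs = llm.generate(['Hello, GRPO!'], SamplingParams(temperature=1.0, max_tokens=16))
print(outputs[0].outputs[0].text)

This mirrors the two kwargs the patch introduces: vllm_kwargs carries distributed_executor_backend='external_launcher' whenever tensor_parallel_size > 1, and engine_kwargs pins seed=0 in the GRPOVllmEngine branch so every TP rank produces the same completions.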