diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index a25d93df83..d17bcda895 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -394,9 +394,13 @@ def get_infer_engine(args: RolloutArguments, template=None, **kwargs): engine_kwargs = kwargs.get('engine_kwargs', {}) # for RL rollout model weight sync engine_kwargs.update({'worker_extension_cls': 'swift.llm.infer.rollout.WeightSyncWorkerExtension'}) - # Use load_format from engine_kwargs if provided, otherwise default to 'dummy' - if 'load_format' not in engine_kwargs: - engine_kwargs['load_format'] = 'dummy' + + # For RL rollout, we use 'dummy' load_format to prevent vLLM from loading weights from disk, + # as they will be synced from the trainer process. + # This will accelerate the rollout speed. + load_format = engine_kwargs.pop('load_format', 'dummy') + kwargs['load_format'] = load_format + if args.vllm_use_async_engine and args.vllm_data_parallel_size > 1: engine_kwargs['data_parallel_size'] = args.vllm_data_parallel_size kwargs['engine_kwargs'] = engine_kwargs