Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions swift/llm/infer/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,13 @@ def get_infer_engine(args: RolloutArguments, template=None, **kwargs):
engine_kwargs = kwargs.get('engine_kwargs', {})
# for RL rollout model weight sync
engine_kwargs.update({'worker_extension_cls': 'swift.llm.infer.rollout.WeightSyncWorkerExtension'})
# Use load_format from engine_kwargs if provided, otherwise default to 'dummy'
if 'load_format' not in engine_kwargs:
engine_kwargs['load_format'] = 'dummy'

# For RL rollout, we use 'dummy' load_format to prevent vLLM from loading weights from disk,
# as they will be synced from the trainer process.
# This will accelerate the rollout speed.
load_format = engine_kwargs.pop('load_format', 'dummy')
kwargs['load_format'] = load_format

if args.vllm_use_async_engine and args.vllm_data_parallel_size > 1:
engine_kwargs['data_parallel_size'] = args.vllm_data_parallel_size
kwargs['engine_kwargs'] = engine_kwargs
Expand Down
Loading