Skip to content

Commit 45df767

Browse files
committed
fix rollout
1 parent 0942292 commit 45df767

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

swift/llm/infer/rollout.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def get_rollout_engine_type(args: RolloutArguments, engine: GRPOVllmEngine):
8686
def llm_worker(args: RolloutArguments, data_parallel_rank: int, master_port: int, connection: Connection) -> None:
8787
# Set required environment variables for DP to work with vLLM
8888
args._import_external_plugins()
89+
args._init_custom_register()
8990
os.environ['VLLM_DP_RANK'] = str(data_parallel_rank)
9091
os.environ['VLLM_DP_RANK_LOCAL'] = str(data_parallel_rank)
9192
os.environ['VLLM_DP_SIZE'] = str(args.vllm_data_parallel_size)
@@ -119,6 +120,7 @@ async def async_llm_worker(args: RolloutArguments, data_parallel_rank: int, mast
119120
connection: Connection) -> None:
120121
# Set required environment variables for DP to work with vLLM
121122
args._import_external_plugins()
123+
args._init_custom_register()
122124
engine = SwiftRolloutDeploy.get_infer_engine(args, template=args.get_template(None))
123125

124126
rollout_engine = get_rollout_engine_type(args, engine)

0 commit comments

Comments
 (0)