Skip to content

Commit 56e84a2

Browse files
authored
[rollout] Fix non-serializable torch.dtype bug in VLLM weight sync (#4825)
* fix server hang * backward compatibility * rm comment
1 parent dc6f124 commit 56e84a2

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

swift/llm/infer/rollout.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,17 @@
88
import time
99
from contextlib import asynccontextmanager, contextmanager
1010
from dataclasses import asdict
11+
from functools import wraps
1112
from itertools import chain
1213
from multiprocessing import Pipe, Process
1314
from multiprocessing.connection import Connection
14-
from typing import Dict, List, Optional, Union
15+
from typing import Dict, List, Optional, Union, get_type_hints
1516

1617
import torch
1718
import uvicorn
1819
from aiohttp import ClientConnectorError
1920
from fastapi import FastAPI
21+
from trl.scripts.vllm_serve import WeightSyncWorkerExtension
2022

2123
from swift.llm import InferArguments, RolloutArguments, SwiftPipeline
2224
from swift.llm.template.template_inputs import RolloutInferRequest
@@ -271,8 +273,7 @@ async def update_named_param(self, request: UpdateWeightsRequest):
271273
# The function update_named_param is called this way: update_named_param("name", torch.float32, (10, 10))
272274
# So with collective_rpc we need to call it this way:
273275
# llm.collective_rpc("update_named_param", args=("name", torch.float32, (10, 10)))
274-
dtype = torch.__getattribute__(request.dtype.split('.')[-1])
275-
kwargs = {'method': 'update_named_param', 'args': (request.name, dtype, tuple(request.shape))}
276+
kwargs = {'method': 'update_named_param', 'args': (request.name, request.dtype, tuple(request.shape))}
276277
for connection in self.connections:
277278
connection.send({'type': 'fire_and_forget', 'method': 'collective_rpc', 'kwargs': kwargs})
278279

@@ -377,3 +378,22 @@ def run_rollout(args: RolloutArguments, return_url: bool = False):
377378
finally:
378379
process.terminate()
379380
logger.info('The deployment process has been terminated.')
381+
382+
383+
# https://github.com/huggingface/trl/pull/3690
# Backward-compatibility shim for the dtype-parameter type change in TRL:
# - TRL <= 0.19 annotates ``dtype`` as ``torch.dtype`` (the shim below applies)
# - TRL >  0.19 annotates ``dtype`` as ``str`` (nothing needs patching)
old_update_named_param = WeightSyncWorkerExtension.update_named_param
dtype_annotation = get_type_hints(old_update_named_param).get('dtype')

# Guard against re-patching: once applied, the original is stashed on the
# class as ``old_update_named_param``.
_already_patched = hasattr(WeightSyncWorkerExtension, 'old_update_named_param')

if dtype_annotation is torch.dtype and not _already_patched:

    @wraps(old_update_named_param)
    def patched_update_named_param(self, name, dtype, shape) -> None:
        # The request transports dtype as a string like 'torch.float32';
        # older TRL expects an actual torch.dtype object, so coerce first.
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype.rsplit('.', 1)[-1])
        return old_update_named_param(self, name, dtype, shape)

    WeightSyncWorkerExtension.update_named_param = patched_update_named_param
    WeightSyncWorkerExtension.old_update_named_param = old_update_named_param

0 commit comments

Comments
 (0)