@@ -8,15 +8,17 @@
 import time
 from contextlib import asynccontextmanager, contextmanager
 from dataclasses import asdict
+from functools import wraps
 from itertools import chain
 from multiprocessing import Pipe, Process
 from multiprocessing.connection import Connection
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, get_type_hints

 import torch
 import uvicorn
 from aiohttp import ClientConnectorError
 from fastapi import FastAPI
+from trl.scripts.vllm_serve import WeightSyncWorkerExtension

 from swift.llm import InferArguments, RolloutArguments, SwiftPipeline
 from swift.llm.template.template_inputs import RolloutInferRequest
@@ -271,8 +273,7 @@ async def update_named_param(self, request: UpdateWeightsRequest):
         # The function update_named_param is called this way: update_named_param("name", torch.float32, (10, 10))
         # So with collective_rpc we need to call it this way:
         # llm.collective_rpc("update_named_param", args=("name", torch.float32, (10, 10)))
-        dtype = torch.__getattribute__(request.dtype.split('.')[-1])
-        kwargs = {'method': 'update_named_param', 'args': (request.name, dtype, tuple(request.shape))}
+        kwargs = {'method': 'update_named_param', 'args': (request.name, request.dtype, tuple(request.shape))}
         for connection in self.connections:
             connection.send({'type': 'fire_and_forget', 'method': 'collective_rpc', 'kwargs': kwargs})

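The new code passes `request.dtype` through to `collective_rpc` as a string rather than converting it server-side, presumably because the update request arrives as JSON over HTTP and a `torch.dtype` is not JSON-serializable, while its string form is. A minimal sketch of that round trip, with a made-up parameter name and shape:

import json

import torch

# A torch.dtype cannot be dumped to JSON, so the client sends its string form.
payload = {
    'name': 'model.embed_tokens.weight',  # hypothetical parameter name
    'dtype': str(torch.float32),          # -> 'torch.float32'
    'shape': [1024, 1024],                # hypothetical shape
}
body = json.dumps(payload)  # works; json.dumps(torch.float32) would raise TypeError

# The worker can later turn the string back into the dtype singleton:
restored = getattr(torch, json.loads(body)['dtype'].split('.')[-1])
assert restored is torch.float32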
@@ -377,3 +378,22 @@ def run_rollout(args: RolloutArguments, return_url: bool = False):
     finally:
         process.terminate()
         logger.info('The deployment process has been terminated.')
+
+
+# https://github.com/huggingface/trl/pull/3690
+# This patch handles backward compatibility for the dtype parameter type change in TRL:
+# - For TRL <= 0.19: dtype is annotated as torch.dtype (patching needed)
+# - For TRL > 0.19: dtype is annotated as str (no patching needed)
+old_update_named_param = WeightSyncWorkerExtension.update_named_param
+dtype_annotation = get_type_hints(old_update_named_param).get('dtype')
+
+if not hasattr(WeightSyncWorkerExtension, 'old_update_named_param') and dtype_annotation == torch.dtype:
+
+    @wraps(old_update_named_param)
+    def patched_update_named_param(self, name, dtype, shape) -> None:
+        # Older TRL expects a torch.dtype, so convert an incoming string first.
+        if isinstance(dtype, str):
+            dtype = getattr(torch, dtype.split('.')[-1])
+        return old_update_named_param(self, name, dtype, shape)
+
+    WeightSyncWorkerExtension.update_named_param = patched_update_named_param
+    WeightSyncWorkerExtension.old_update_named_param = old_update_named_param
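The guard keys off the annotation that `get_type_hints` reports for the `dtype` parameter, and the `hasattr` check keeps the patch from being applied twice on repeated imports. A self-contained sketch of the version check, using two hypothetical functions standing in for the old and new TRL signatures:

from typing import get_type_hints

import torch


# Hypothetical stand-ins for the two update_named_param signatures the guard
# distinguishes; only the dtype annotation differs.
def old_style(self, name: str, dtype: torch.dtype, shape: tuple) -> None:
    ...


def new_style(self, name: str, dtype: str, shape: tuple) -> None:
    ...


assert get_type_hints(old_style)['dtype'] is torch.dtype  # TRL <= 0.19: patch applies
assert get_type_hints(new_style)['dtype'] is str          # TRL > 0.19: left untouched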