
Commit d3b6b56

[grpo] support vllm_server_base_url for vLLMClient (#4449)
1 parent 23df7f3 commit d3b6b56

File tree

6 files changed: +35 -13 lines changed


docs/source/Instruction/GRPO.md

Lines changed: 1 addition & 0 deletions
@@ -209,6 +209,7 @@ A conversation between User and Assistant. The user asks a question, and the Ass
 - use_vllm: Whether to use vLLM as the infer_backend for GRPO generation. Default is False.
 - vllm_mode: vLLM integration mode; options are `server` and `colocate`. Server mode samples from a vLLM server launched with `swift rollout`; colocate mode deploys vLLM inside the training process.
 - vllm_mode server parameters
+- vllm_server_base_url: Base URL of the vLLM server (e.g. http://localhost:8000). Default is None. When set, the host and port settings are ignored.
 - vllm_server_host: Host address of the vLLM server. Default is None. Used when connecting to an external vLLM server.
 - vllm_server_port: Service port of the vLLM server. Default is 8000.
 - vllm_server_timeout: Connection timeout for the vLLM server. Default is 240 seconds.

docs/source/Instruction/命令行参数.md

Lines changed: 1 addition & 0 deletions
@@ -430,6 +430,7 @@ Reward model parameters are used in PPO and GRPO.
 - use_vllm: Whether to use vLLM as the infer_backend for GRPO generation. Default is False.
 - vllm_mode: vLLM integration mode; options are `server` and `colocate`. Server mode samples from a vLLM server launched with `swift rollout`; colocate mode deploys vLLM inside the training process.
 - vllm_mode server parameters
+- vllm_server_base_url: Base URL of the vLLM server (e.g. http://localhost:8000). Default is None. When set, the host and port settings are ignored.
 - vllm_server_host: Host address of the vLLM server. Default is None. Used when connecting to an external vLLM server.
 - vllm_server_port: Service port of the vLLM server. Default is 8000.
 - vllm_server_timeout: Connection timeout for the vLLM server. Default is 240 seconds.

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 0 deletions
@@ -442,6 +442,7 @@ The meanings of the following parameters can be referenced [here](https://huggin
 - use_vllm: Whether to use vLLM as the infer_backend for GRPO generation, default is False.
 - vllm_mode: Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `server` or `colocate`.
 - vllm_mode server parameters
+- vllm_server_base_url: Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` and `vllm_server_port` are ignored. Default is None.
 - vllm_server_host: The host address of the vLLM server. Default is None. This is used when connecting to an external vLLM server.
 - vllm_server_port: The service port of the vLLM server. Default is 8000.
 - vllm_server_timeout: The connection timeout for the vLLM server. Default is 240 seconds.

docs/source_en/Instruction/GRPO.md

Lines changed: 1 addition & 0 deletions
@@ -219,6 +219,7 @@ Arguments
 - use_vllm: Whether to use vLLM as the infer_backend for GRPO generation, default is False.
 - vllm_mode: Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `server` or `colocate`.
 - vllm_mode server parameters
+- vllm_server_base_url: Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` and `vllm_server_port` are ignored. Default is None.
 - vllm_server_host: The host address of the vLLM server. Default is None. This is used when connecting to an external vLLM server.
 - vllm_server_port: The service port of the vLLM server. Default is 8000.
 - vllm_server_timeout: The connection timeout for the vLLM server. Default is 240 seconds.
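
The precedence these parameters describe can be summarized in a short sketch. This is an illustration only; resolve_server_url is a hypothetical helper, not part of this commit:

    from typing import Optional

    def resolve_server_url(base_url: Optional[str], host: Optional[str], port: int = 8000) -> str:
        # Hypothetical helper mirroring the documented precedence:
        # an explicit base URL wins over the host/port pair.
        if base_url is not None:
            return base_url.rstrip('/')
        return f'http://{host or "0.0.0.0"}:{port}'

    assert resolve_server_url('http://localhost:8000', '10.0.0.5', 9000) == 'http://localhost:8000'
    assert resolve_server_url(None, '10.0.0.5', 9000) == 'http://10.0.0.5:9000'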

swift/trainers/arguments.py

Lines changed: 1 addition & 0 deletions
@@ -169,6 +169,7 @@ class GRPOArgumentsMixin:
     vllm_enable_prefix_caching: bool = True
     vllm_tensor_parallel_size: int = 1
     # external vllm (server)
+    vllm_server_base_url: Optional[str] = None
     vllm_server_host: Optional[str] = None
     vllm_server_port: int = 8000
     vllm_server_timeout: float = 240.0
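
A minimal sketch of how these fields sit together, using a stand-in dataclass (ServerConfig is hypothetical; the real fields live on GRPOArgumentsMixin as shown above):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ServerConfig:  # hypothetical stand-in for the GRPOArgumentsMixin fields
        vllm_server_base_url: Optional[str] = None
        vllm_server_host: Optional[str] = None
        vllm_server_port: int = 8000
        vllm_server_timeout: float = 240.0

    cfg = ServerConfig(vllm_server_base_url='http://localhost:8000')
    # With base_url set, the host/port fields are ignored by the client
    # (see the vllm_client.py changes below).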

swift/trainers/rlhf_trainer/vllm_client.py

Lines changed: 30 additions & 13 deletions
@@ -4,8 +4,10 @@

 import atexit
 import logging
+import socket
 import time
 from typing import List, Optional
+from urllib.parse import urlparse

 import requests
 import torch
@@ -36,10 +38,13 @@ class VLLMClient:
     weights in a distributed setting. Before using it, start the vLLM server with `swift rollout`.

     Args:
+        base_url (`str` or `None`, *optional*, defaults to `None`):
+            Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `host` and `server_port` are
+            ignored.
         host (`str`, *optional*, defaults to `"0.0.0.0"`):
-            IP address of the vLLM server.
+            IP address of the vLLM server. Ignored if `base_url` is provided.
         server_port (`int`, *optional*, defaults to `8000`):
-            Port number of the vLLM server.
+            Port number of the vLLM server. Ignored if `base_url` is provided.
         group_port (`int`, *optional*, defaults to `51216`):
             Port number for the weight update group.
         connection_timeout (`float`, *optional*, defaults to `0.0`):
@@ -48,6 +53,7 @@ class VLLMClient:
     """

     def __init__(self,
+                 base_url: Optional[str] = None,
                  host: str = '0.0.0.0',
                  server_port: int = 8000,
                  group_port: int = 51216,
@@ -56,8 +62,17 @@ def __init__(self,
             raise ImportError('vLLM is not installed. Please install it with `pip install vllm`.')

         self.session = requests.Session()
-        self.host = host
-        self.server_port = server_port
+        if base_url is not None:
+            # Parse the base_url to extract host and port
+            parsed_url = urlparse(base_url)
+            self.host = socket.gethostbyname(parsed_url.hostname)
+            scheme = parsed_url.scheme or 'http'
+            self.base_url = f'{scheme}://{parsed_url.netloc}{parsed_url.path}'
+        else:
+            self.host = host
+            self.server_port = server_port
+            self.base_url = f'http://{self.host}:{self.server_port}'
+
         self.group_port = group_port
         self.check_server(connection_timeout)  # check server and fail after timeout

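As a standalone illustration of what the new branch computes, the same urlparse/gethostbyname steps can be run directly (assuming `localhost` resolves to 127.0.0.1 on the local machine):

    import socket
    from urllib.parse import urlparse

    parsed = urlparse('http://localhost:8000')
    print(socket.gethostbyname(parsed.hostname))                        # e.g. 127.0.0.1
    print(f"{parsed.scheme or 'http'}://{parsed.netloc}{parsed.path}")  # http://localhost:8000

Note that when base_url is given, self.server_port is never set; every later request goes through self.base_url instead.
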
@@ -72,7 +87,7 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
             total_timeout (`float`, *optional*, defaults to `0.0`):
                 Total timeout duration in seconds.
         """
-        url = f'http://{self.host}:{self.server_port}/health/'
+        url = f'{self.base_url}/health/'
         start_time = time.time()  # Record the start time

         while True:
@@ -83,10 +98,12 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
                 elapsed_time = time.time() - start_time
                 if elapsed_time >= total_timeout:
                     raise ConnectionError(
-                        f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} "
-                        'seconds. Make sure the server is running by running `swift deploy`.') from exc
+                        f"The vLLM server can't be reached at {self.base_url} after {total_timeout} seconds. Make "
+                        'sure the server is running by running `swift rollout`.') from exc
             else:
                 if response.status_code == 200:
+                    if 'X-Forwarded-For' in response.headers:
+                        self.host = response.headers['X-Forwarded-For']
                     logger.info('Server is up!')
                     return None

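The X-Forwarded-For handling matters when the server sits behind a reverse proxy: the header exposes the backend address, which the client then adopts as self.host for the later weight-update group setup. A minimal sketch of the same lookup, assuming a server is reachable at the URL below:

    import requests

    response = requests.get('http://localhost:8000/health/')  # assumed server location
    backend_host = response.headers.get('X-Forwarded-For')    # None when no proxy is involved
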
@@ -104,7 +121,7 @@ def infer(
         use_tqdm: Optional[bool] = None,
         adapter_request: Optional[AdapterRequest] = None,
     ):
-        url = f'http://{self.host}:{self.server_port}/infer/'
+        url = f'{self.base_url}/infer/'
         response = self.session.post(
             url,
             json={
@@ -126,7 +143,7 @@ def init_communicator(self):
         Initializes the weight update group in a distributed setup for model synchronization.
         """
         # Get the tensor parallel size from the server
-        url = f'http://{self.host}:{self.server_port}/get_world_size/'
+        url = f'{self.base_url}/get_world_size/'
         response = requests.get(url)
         if response.status_code == 200:
             vllm_world_size = response.json()['world_size']
@@ -137,7 +154,7 @@ def init_communicator(self):
         self.rank = vllm_world_size  # the client's rank is the last process

         # Initialize weight update group
-        url = f'http://{self.host}:{self.server_port}/init_communicator/'
+        url = f'{self.base_url}/init_communicator/'
         # On the server side, the host is set to 0.0.0.0
         response = self.session.post(url, json={'host': '0.0.0.0', 'port': self.group_port, 'world_size': world_size})
         if response.status_code != 200:
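
For the group sizing above, the communicator presumably spans every vLLM worker plus this client, with the client taking the last rank; a small worked example under that assumption:

    vllm_world_size = 4               # e.g. reported by /get_world_size/
    world_size = vllm_world_size + 1  # assumption: one extra slot for the client
    client_rank = vllm_world_size     # ranks 0..3 are vLLM workers, 4 is the client
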
@@ -166,7 +183,7 @@ def update_named_param(self, name: str, weights: torch.Tensor):
                 Tensor containing the updated weights.
         """
         dtype, shape = str(weights.dtype), tuple(weights.shape)
-        url = f'http://{self.host}:{self.server_port}/update_named_param/'
+        url = f'{self.base_url}/update_named_param/'
         response = self.session.post(url, json={'name': name, 'dtype': dtype, 'shape': shape})
         if response.status_code != 200:
             raise Exception(f'Request failed: {response.status_code}, {response.text}')
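
Only the tensor's metadata travels over this HTTP call; the weights themselves are synchronized through the weight update group. What the serialized fields look like for a concrete tensor:

    import torch

    weights = torch.zeros(4, 8, dtype=torch.bfloat16)
    dtype, shape = str(weights.dtype), tuple(weights.shape)
    print(dtype, shape)  # torch.bfloat16 (4, 8)
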
@@ -191,7 +208,7 @@ def reset_prefix_cache(self):
         """
         Resets the prefix cache for the model.
         """
-        url = f'http://{self.host}:{self.server_port}/reset_prefix_cache/'
+        url = f'{self.base_url}/reset_prefix_cache/'
         response = self.session.post(url)
         if response.status_code != 200:
             raise Exception(f'Request failed: {response.status_code}, {response.text}')
@@ -200,7 +217,7 @@ def close_communicator(self):
         """
         Closes the weight update group and cleans up the communication group.
         """
-        url = f'http://{self.host}:{self.server_port}/close_communicator/'
+        url = f'{self.base_url}/close_communicator/'

         try:
             response = self.session.post(url)
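
Putting the pieces together, a usage sketch of the new base_url path. It assumes a `swift rollout` vLLM server is reachable at the URL below and that the module imports as its file path suggests; only methods defined in this file are used:

    import torch
    from swift.trainers.rlhf_trainer.vllm_client import VLLMClient

    client = VLLMClient(base_url='http://localhost:8000', connection_timeout=240.0)
    client.init_communicator()                      # join the weight update group
    w = torch.zeros(16, 16)                         # stand-in tensor for illustration
    client.update_named_param('lm_head.weight', w)  # send metadata, then sync weights
    client.reset_prefix_cache()                     # drop cached prefixes after an update
    client.close_communicator()                     # tear down on shutdown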
