Commit 7066b69

hiworldwzj and Weichao Luo authored
pd master support multinode. (#828)
Co-authored-by: Weichao Luo <luoweichao@sensetime.com>
1 parent c2dbc8f commit 7066b69

File tree

20 files changed: +710 −131 lines


docs/CN/source/getting_started/quickstart.rst

Lines changed: 105 additions & 1 deletion

@@ -93,7 +93,7 @@

 .. code-block:: console

-    $ CUDA_VISIBLE_DEVICES=0 python -m lightllm.server.api_server \
+    $ python -m lightllm.server.api_server \
     $ --model_dir /your/model/path \
     $ --run_mode "pd_master" \
     $ --host /your/host/ip \
@@ -165,3 +165,107 @@
     $ cd test
     $ python benchmark_client.py --num_clients 100 --input_num 2000 --tokenizer_path /nvme/DeepSeek-R1/ --url http://127.0.0.1:8000/generate_stream
+
+
+3. Starting the model service in PD-disaggregated mode with multiple pd_master nodes
+------------------------------------------------------------------------------------
+Find the local IP:
+
+.. code-block:: console
+
+    $ hostname -i
+
+Run MPS (optional; MPS support improves performance considerably, but enabling MPS on some GPU and driver environments is error-prone, so upgrading to a recent driver version is recommended, especially on H-series cards):
+
+.. code-block:: console
+
+    $ nvidia-cuda-mps-control -d
+
+
+Run the config_server service:
+
+.. code-block:: console
+
+    $ python -m lightllm.server.api_server \
+    $ --run_mode "config_server" \
+    $ --config_server_host /your/host/ip \
+    $ --config_server_port 60088 \
+
+
+Run the pd_master service. In multi-pd_master mode, several pd_master services can be started to achieve load balancing, since the concurrency of a single pd_master is capped by the Python GIL.
+
+.. code-block:: console
+
+    $ python -m lightllm.server.api_server \
+    $ --model_dir /your/model/path \
+    $ --run_mode "pd_master" \
+    $ --host /your/host/ip \
+    $ --port 60011 \
+    $ --config_server_host <config_server_host> \
+    $ --config_server_port <config_server_port>
+
+In a new terminal, run the prefill service:
+
+.. code-block:: console
+
+    $ CUDA_VISIBLE_DEVICES=0,1 KV_TRANS_USE_P2P=1 LOADWORKER=1 python -m lightllm.server.api_server --model_dir /data/fengdahu/model/Qwen2-7B/ \
+    $ --run_mode "prefill" \
+    $ --host /your/host/ip \
+    $ --port 8017 \
+    $ --tp 2 \
+    $ --nccl_port 2732 \
+    $ --max_total_token_num 400000 \
+    $ --tokenizer_mode fast \
+    $ --use_dynamic_prompt_cache \
+    $ --max_req_total_len 16000 \
+    $ --running_max_req_size 128 \
+    $ --disable_cudagraph \
+    $ --config_server_host <config_server_host> \
+    $ --config_server_port <config_server_port>
+
+In a new terminal, run the decode service:
+
+.. code-block:: console
+
+    $ CUDA_VISIBLE_DEVICES=2,3 KV_TRANS_USE_P2P=1 LOADWORKER=10 python -m lightllm.server.api_server --model_dir /data/fengdahu/model/Qwen2-7B/ \
+    $ --run_mode "decode" \
+    $ --host /your/host/ip \
+    $ --port 8118 \
+    $ --nccl_port 12322 \
+    $ --tp 2 \
+    $ --max_total_token_num 400000 \
+    $ --graph_max_len_in_batch 2048 \
+    $ --graph_max_batch_size 16 \
+    $ --tokenizer_mode fast \
+    $ --use_dynamic_prompt_cache \
+    $ --config_server_host <config_server_host> \
+    $ --config_server_port <config_server_port>
+
+.. note::
+    Keep the tp size identical for the prefill and decode stages. The numbers of prefill and decode nodes may vary, and prefill and decode services can be deployed across machines.
+
+
+4. (Optional) Test the model service
+------------------------------------
+
+In a new terminal, test the model service with the command below. In multi-pd_master mode, every pd_master can serve as an entry point:
+
+.. code-block:: console
+
+    $ curl http://server_ip:server_port/generate \
+    $ -H "Content-Type: application/json" \
+    $ -d '{
+    $ "inputs": "What is AI?",
+    $ "parameters":{
+    $ "max_new_tokens":17,
+    $ "frequency_penalty":1
+    $ }
+    $ }'
+
+
+For the DeepSeek-R1 model, the following script can be used for testing:
+
+.. code-block:: console
+
+    $ cd test
+    $ python benchmark_client.py --num_clients 100 --input_num 2000 --tokenizer_path /nvme/DeepSeek-R1/ --url http://127.0.0.1:8000/generate_stream
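The config_server that this setup registers pd_master nodes with exposes a `/registered_objects` endpoint (added elsewhere in this commit) that returns the registry as base64-encoded pickled bytes. A minimal sketch of decoding that payload on the client side — the real values are `PD_Master_Obj` instances from `lightllm.server.pd_io_struct`, stood in for here by a plain dict, and no live server is contacted:

```python
import base64
import pickle

# Server side of /registered_objects (simulated): the handler pickles the
# registry dict and base64-encodes it into a JSON-safe string.
registry = {12345: {"host": "10.0.0.1", "port": 60011}}  # stand-in for PD_Master_Obj
payload = {"data": base64.b64encode(pickle.dumps(registry)).decode("utf-8")}

# Client side: reverse the encoding to recover the registered objects.
decoded = pickle.loads(base64.b64decode(payload["data"]))
print(decoded[12345]["host"])  # -> 10.0.0.1
```

Against a real deployment, the `payload` dict would come from `requests.get("http://<config_server_host>:<config_server_port>/registered_objects").json()`.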

lightllm/server/api_cli.py

Lines changed: 17 additions & 3 deletions

@@ -7,9 +7,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--run_mode",
         type=str,
-        choices=["normal", "prefill", "decode", "pd_master"],
+        choices=["normal", "prefill", "decode", "pd_master", "config_server"],
         default="normal",
-        help="set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode",
+        help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode,
+        config_server is for pd split mode used to register pd_master node, and get pd_master node list,
+        specifically designed for large-scale, high-concurrency scenarios where `pd_master` encounters
+        significant CPU bottlenecks.""",
     )
     parser.add_argument("--host", type=str, default="127.0.0.1")
     parser.add_argument("--port", type=int, default=8000)
@@ -39,7 +42,18 @@ def make_argument_parser() -> argparse.ArgumentParser:
         default=42000,
         help="p d mode, decode node used for kv move manager rpyc server port",
     )
-
+    parser.add_argument(
+        "--config_server_host",
+        type=str,
+        default=None,
+        help="The host address for the config server in config_server mode.",
+    )
+    parser.add_argument(
+        "--config_server_port",
+        type=int,
+        default=None,
+        help="The port number for the config server in config_server mode.",
+    )
     parser.add_argument(
         "--model_name",
         type=str,

lightllm/server/api_server.py

Lines changed: 3 additions & 1 deletion

@@ -5,9 +5,11 @@
 torch.multiprocessing.set_start_method("spawn")  # this code will not be ok for settings to fork to subprocess
 parser = make_argument_parser()
 args = parser.parse_args()
-from .api_start import pd_master_start, normal_or_p_d_start
+from .api_start import pd_master_start, normal_or_p_d_start, config_server_start

 if args.run_mode == "pd_master":
     pd_master_start(args)
+elif args.run_mode == "config_server":
+    config_server_start(args)
 else:
     normal_or_p_d_start(args)

lightllm/server/api_start.py

Lines changed: 42 additions & 0 deletions

@@ -284,6 +284,14 @@ def pd_master_start(args):
     if args.run_mode != "pd_master":
         return

+    # When using a config_server to support multiple pd_master nodes, we
+    # need to generate a unique node id for each pd_master node.
+    # Otherwise, we use 0 for the single pd_master node.
+    if args.config_server_host and args.config_server_port:
+        args.pd_node_id = uuid.uuid4().int
+    else:
+        args.pd_node_id = 0
+
     logger.info(f"use tgi api: {args.use_tgi_api}")
     logger.info(f"all start args:{args}")

@@ -330,3 +338,37 @@ def pd_master_start(args):

     setup_signal_handlers(http_server_process, process_manager)
     http_server_process.wait()
+
+
+def config_server_start(args):
+    set_unique_server_name(args)
+    if args.run_mode != "config_server":
+        return
+
+    logger.info(f"all start args:{args}")
+
+    set_env_start_args(args)
+
+    command = [
+        "gunicorn",
+        "--workers",
+        "1",
+        "--worker-class",
+        "uvicorn.workers.UvicornWorker",
+        "--bind",
+        f"{args.config_server_host}:{args.config_server_port}",
+        "--log-level",
+        "info",
+        "--access-logfile",
+        "-",
+        "--error-logfile",
+        "-",
+        "--preload",
+        "lightllm.server.config_server.api_http:app",
+        "--timeout",
+        f"{get_lightllm_gunicorn_time_out_seconds()}",
+    ]
+
+    http_server_process = subprocess.Popen(command)
+    setup_signal_handlers(http_server_process, process_manager)
+    http_server_process.wait()
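The node-id branch above can be isolated into a tiny helper to show its behavior — this is a sketch mirroring the assignment in `pd_master_start`, not the function itself; `make_pd_node_id` is a hypothetical name:

```python
import uuid

def make_pd_node_id(config_server_host, config_server_port):
    # Mirror of the id assignment in pd_master_start: a deployment that
    # registers with a config_server gets a random 128-bit node id; a
    # standalone pd_master keeps id 0.
    if config_server_host and config_server_port:
        return uuid.uuid4().int
    return 0

standalone = make_pd_node_id(None, None)
clustered = make_pd_node_id("10.0.0.5", 60088)

assert standalone == 0
assert 0 < clustered < 2 ** 128  # 128 bits: fits two c_uint64 halves
```

The 128-bit width is why the `DecodeNode` struct in this commit stores the id as a high/low pair of `c_uint64` fields rather than a single integer.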
Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+"""
+This module implements a configuration service designed to facilitate the
+registration and retrieval of information in a PD separation mode. It
+allows various nodes to register their own information and query global
+configuration details efficiently.
+
+Key Features:
+- Node registration: Enables nodes to register their specific information.
+- Global configuration query: Provides mechanisms for querying shared
+  configuration data across the system.
+- Designed for distributed systems operating in PD separation mode.
+"""
Lines changed: 96 additions & 0 deletions

@@ -0,0 +1,96 @@
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Query
+from threading import Lock
+from typing import Dict
+from fastapi.responses import JSONResponse
+from lightllm.utils.log_utils import init_logger
+from ..pd_io_struct import PD_Master_Obj
+import base64
+import pickle
+import os
+import requests
+
+logger = init_logger(__name__)
+app = FastAPI()
+
+registered_pd_master_objs: Dict[str, PD_Master_Obj] = {}
+registered_pd_master_obj_lock = Lock()
+
+global_req_id = 0
+global_req_id_lock = Lock()
+
+
+@app.get("/liveness")
+@app.post("/liveness")
+def liveness():
+    return {"status": "ok"}
+
+
+@app.get("/readiness")
+@app.post("/readiness")
+def readiness():
+    return {"status": "ok"}
+
+
+@app.get("/healthz", summary="Check server health")
+@app.get("/health", summary="Check server health")
+@app.head("/health", summary="Check server health")
+async def healthcheck(request: Request):
+    return JSONResponse({"message": "Ok"}, status_code=200)
+
+
+@app.websocket("/pd_master_register")
+async def websocket_endpoint(websocket: WebSocket):
+    await websocket.accept()
+    client_ip, client_port = websocket.client
+    logger.info(f"ws connected from IP: {client_ip}, Port: {client_port}")
+    registered_pd_master_obj: PD_Master_Obj = pickle.loads(await websocket.receive_bytes())
+    logger.info(f"received registered_pd_master_obj {registered_pd_master_obj}")
+    with registered_pd_master_obj_lock:
+        registered_pd_master_objs[registered_pd_master_obj.node_id] = registered_pd_master_obj
+
+    try:
+        while True:
+            data = await websocket.receive_text()
+            assert data == "heartbeat"
+    except (WebSocketDisconnect, Exception, RuntimeError) as e:
+        logger.error(f"registered_pd_master_obj {registered_pd_master_obj} has error {str(e)}")
+        logger.exception(str(e))
+    finally:
+        logger.error(f"registered_pd_master_obj {registered_pd_master_obj} removed")
+        with registered_pd_master_obj_lock:
+            registered_pd_master_objs.pop(registered_pd_master_obj.node_id, None)
+    return
+
+
+@app.get("/registered_objects")
+async def get_registered_objects():
+    with registered_pd_master_obj_lock:
+        serialized_data = pickle.dumps(registered_pd_master_objs)
+        base64_encoded = base64.b64encode(serialized_data).decode("utf-8")
+        return {"data": base64_encoded}
+
+
+@app.get("/allocate_global_unique_id_range")
+async def allocate_global_id_range():
+    """
+    Allocate a global ID range for the requesting client without requiring parameters.
+
+    Returns:
+        dict: A dictionary containing the start and end of the allocated ID range.
+
+    Example HTTP client usage:
+        response = requests.get("http://<server_address>/allocate_global_unique_id_range")
+        print(response.json())
+    """
+    global global_req_id
+    range_size = 800000
+    with global_req_id_lock:
+        if global_req_id + range_size > 2 ** 63 - 1:
+            global_req_id = 0
+        start_id = global_req_id
+        global_req_id += range_size
+        end_id = global_req_id
+
+    return {"start_id": start_id, "end_id": end_id}
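The allocator hands out half-open ranges of 800000 ids and wraps to 0 before the counter would exceed the signed 64-bit maximum, so every allocated id still fits the request fields widened to `ctypes.c_int64` elsewhere in this commit. A standalone sketch of the same logic, outside FastAPI:

```python
from threading import Lock

RANGE_SIZE = 800_000
MAX_ID = 2 ** 63 - 1  # ids must fit a signed 64-bit request_id field

_next_id = 0
_lock = Lock()

def allocate_range():
    """Return (start, end) for a half-open id range [start, end)."""
    global _next_id
    with _lock:
        if _next_id + RANGE_SIZE > MAX_ID:
            _next_id = 0  # wrap around instead of overflowing int64
        start = _next_id
        _next_id += RANGE_SIZE
        return start, _next_id

a = allocate_range()
b = allocate_range()
print(a, b)  # -> (0, 800000) (800000, 1600000)
```

Because the counter lives in the single config_server process (gunicorn is started with `--workers 1`), the lock is enough to keep consecutive ranges disjoint for concurrent callers.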

lightllm/server/core/objs/req.py

Lines changed: 2 additions & 2 deletions

@@ -60,8 +60,8 @@ class Req(ctypes.Structure):
     _fields_ = [
         ("index_in_shm_mem", ctypes.c_int),
         ("ref_count", ctypes.c_int),  # do not manipulate this reference count manually
-        ("request_id", ctypes.c_int),
-        ("group_req_id", ctypes.c_int),
+        ("request_id", ctypes.c_int64),
+        ("group_req_id", ctypes.c_int64),
         ("input_len", ctypes.c_int),
         ("alloc_shm_numpy_len", ctypes.c_int),
         ("shm_infer_released", ctypes.c_bool),  # the inference process marks here that it has released the request; once the router process sees this, it can release the shm req object too
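The widening matters because ids drawn from the new global 64-bit allocator quickly exceed the 32-bit range of `c_int`. A quick check with a stripped-down struct (only the changed fields; `NewReq` is an illustrative name, not the real `Req` class):

```python
import ctypes

big_id = 5_000_000_000  # exceeds 2**31 - 1, so it cannot live in a c_int field

class NewReq(ctypes.Structure):
    # Only the two fields this diff widens, for illustration.
    _fields_ = [
        ("request_id", ctypes.c_int64),
        ("group_req_id", ctypes.c_int64),
    ]

r = NewReq()
r.request_id = big_id
r.group_req_id = big_id + 1

assert ctypes.sizeof(ctypes.c_int64) == 8  # guaranteed 64-bit storage
assert r.request_id == big_id              # round-trips without truncation
```

Since `Req` is a `ctypes.Structure` placed in shared memory, widening these fields also changes the struct's byte layout, which is why both reader and writer processes must pick up this commit together.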

lightllm/server/core/objs/sampling_params.py

Lines changed: 10 additions & 3 deletions

@@ -205,6 +205,9 @@ class DecodeNode(ctypes.Structure):
         ("ip", ctypes.c_int32 * 4),
         ("rpyc_port", ctypes.c_int),
         ("max_new_tokens", ctypes.c_int),
+        # record the id of the pd_master node handling the current request
+        ("pd_master_node_id_high", ctypes.c_uint64),
+        ("pd_master_node_id_low", ctypes.c_uint64),
     ]

     def initialize(self, data_dict):
@@ -224,15 +227,19 @@ def initialize(self, data_dict):
         self.rpyc_port = data_dict["rpyc_port"]
         self.max_new_tokens = data_dict["max_new_tokens"]

+        pd_master_node_id = data_dict["pd_master_node_id"]
+        self.pd_master_node_id_high = (pd_master_node_id >> 64) & 0xFFFFFFFFFFFFFFFF
+        self.pd_master_node_id_low = pd_master_node_id & 0xFFFFFFFFFFFFFFFF
+
     def to_dict(self):
         if not self.exists:
             return None
-        uuid_int = (self.node_id_high << 64) | self.node_id_low
         return {
-            "node_id": uuid_int,
+            "node_id": ((self.node_id_high << 64) | self.node_id_low),
             "ip": ".".join(str(self.ip[i]) for i in range(4)),
             "rpyc_port": self.rpyc_port,
             "max_new_tokens": self.max_new_tokens,
+            "pd_master_node_id": ((self.pd_master_node_id_high << 64) | self.pd_master_node_id_low),
         }


@@ -264,7 +271,7 @@ class SamplingParams(ctypes.Structure):
         ("allowed_token_ids", AllowedTokenIds),
         ("stop_sequences", StopSequenceGroups),
         ("exponential_decay_length_penalty", ExponentialDecayLengthPenalty),
-        ("group_request_id", ctypes.c_int),  # p d mode used params
+        ("group_request_id", ctypes.c_int64),  # p d mode used params
         ("suggested_dp_index", ctypes.c_int),  # suggested dp index; in deepseekv2 dp mode, used to suggest the dp_index to use
         ("move_kv_to_decode_node", DecodeNode),  # move kv to decode node, only used in pd mode
         ("skip_special_tokens", ctypes.c_bool),  # whether to skip special tokens when decoding
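ctypes has no 128-bit integer type, so the commit stores the uuid-sized `pd_master_node_id` as two `c_uint64` halves. The split in `initialize` and the reassembly in `to_dict` can be checked in isolation (`split_u128`/`join_u128` are illustrative helper names):

```python
import uuid

MASK64 = 0xFFFFFFFFFFFFFFFF

def split_u128(value):
    # Same scheme as DecodeNode.initialize: high and low 64-bit halves.
    return (value >> 64) & MASK64, value & MASK64

def join_u128(high, low):
    # Same scheme as DecodeNode.to_dict.
    return (high << 64) | low

node_id = uuid.uuid4().int  # 128-bit id, as assigned in pd_master_start
high, low = split_u128(node_id)

assert high <= MASK64 and low <= MASK64  # each half fits a ctypes.c_uint64
assert join_u128(high, low) == node_id   # lossless round-trip
```

The pre-existing `node_id_high`/`node_id_low` pair in the same struct uses the identical encoding, which is why `to_dict` was rewritten to inline the same `(high << 64) | low` expression for both ids.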

lightllm/server/core/objs/start_args_type.py

Lines changed: 2 additions & 0 deletions

@@ -15,6 +15,8 @@ class StartArgs:
     )
     pd_master_ip: str = field(default="127.0.0.1")
     pd_master_port: int = field(default=1212)
+    config_server_host: str = field(default=None)
+    config_server_port: int = field(default=None)
     pd_decode_rpyc_port: int = field(default=42000)
     model_name: str = field(default="default_model_name")
     model_dir: Optional[str] = field(default=None)

0 commit comments