Merged

Commits (39)
7785e55
pd handle loop refactor.
hiworldwzj Apr 3, 2025
27faff6
fix
hiworldwzj Apr 3, 2025
452a0ac
single kv transfer process for pd.
Mar 7, 2025
9d6e4fd
fix name.
Mar 10, 2025
083b9d6
fix style.
Mar 10, 2025
ac076f3
one kv trans process per tp.
Mar 11, 2025
a7367e6
fix.
Mar 19, 2025
efb6c50
fixup.
Mar 19, 2025
2092ea6
new pd code.
hiworldwzj Apr 9, 2025
a3831ed
format
hiworldwzj Apr 10, 2025
2189ffe
fix
hiworldwzj Apr 10, 2025
89e4068
fix.
hiworldwzj Apr 10, 2025
a01bd90
fix
hiworldwzj Apr 10, 2025
e41f365
fix
hiworldwzj Apr 10, 2025
7453390
fix
hiworldwzj Apr 10, 2025
1373a07
format.
hiworldwzj Apr 10, 2025
d9d3f4f
reformat.
hiworldwzj Apr 10, 2025
18437d2
add async nccl connect.
hiworldwzj Apr 10, 2025
5dcc4b8
fix
hiworldwzj Apr 11, 2025
923d6e3
fix
hiworldwzj Apr 11, 2025
66c5d28
fix
hiworldwzj Apr 11, 2025
65b04b7
fix
hiworldwzj Apr 11, 2025
20cf30d
add config server module.
hiworldwzj Apr 14, 2025
cf3bbb5
Merge branch 'wzj_pd' into pd_master
hiworldwzj Apr 14, 2025
58d5d8d
Merge branch 'wzj' into pd_master
hiworldwzj Apr 14, 2025
2223da1
add config_server first impl.
hiworldwzj Apr 14, 2025
e175363
add pd master regist to config_server.
hiworldwzj Apr 15, 2025
35c9dbd
Merge remote-tracking branch 'origin/main' into pd_master
hiworldwzj Apr 15, 2025
9e59146
httpserver query config_server to connect all pd_master.
hiworldwzj Apr 15, 2025
ee07802
fix register bug.
hiworldwzj Apr 15, 2025
e9d8758
add pd_master_node_id and run ok first.
hiworldwzj Apr 15, 2025
f3f29e6
Merge remote-tracking branch 'origin/main' into pd_master
hiworldwzj Apr 16, 2025
2990df8
add global req_id alloc.
hiworldwzj Apr 16, 2025
671351d
fix
hiworldwzj Apr 16, 2025
2e03c78
fix.
hiworldwzj Apr 16, 2025
9560b20
add docs.
hiworldwzj Apr 16, 2025
db1f8c1
reformat.
hiworldwzj Apr 16, 2025
619fcc8
fix
hiworldwzj Apr 16, 2025
7969f91
fix req id type.
hiworldwzj Apr 16, 2025
106 changes: 105 additions & 1 deletion docs/CN/source/getting_started/quickstart.rst
@@ -93,7 +93,7 @@

.. code-block:: console

- $ CUDA_VISIBLE_DEVICES=0 python -m lightllm.server.api_server \
+ $ python -m lightllm.server.api_server \
$ --model_dir /your/model/path \
$ --run_mode "pd_master" \
$ --host /your/host/ip \
@@ -165,3 +165,107 @@
$ cd test
$ python benchmark_client.py --num_clients 100 --input_num 2000 --tokenizer_path /nvme/DeepSeek-R1/ --url http://127.0.0.1:8000/generate_stream


3. PD-Disaggregated Serving with Multiple PD_Master Nodes
-----------------------------------------------------------

Find the local IP address:

.. code-block:: console

$ hostname -i

Start MPS (optional). MPS support improves performance considerably, but enabling MPS on some GPU and driver environments is error-prone; upgrading to a recent driver version is recommended, especially for H-series cards.

.. code-block:: console

$ nvidia-cuda-mps-control -d


Start the config_server service:

.. code-block:: console

$ python -m lightllm.server.api_server \
$ --run_mode "config_server" \
$ --config_server_host /your/host/ip \
$ --config_server_port 60088


Start the pd_master service. In multi-pd_master mode, several pd_master instances can be launched for load balancing, since a single pd_master has an upper bound on concurrency due to the Python GIL.

.. code-block:: console

$ python -m lightllm.server.api_server \
$ --model_dir /your/model/path \
$ --run_mode "pd_master" \
$ --host /your/host/ip \
$ --port 60011 \
$ --config_server_host <config_server_host> \
$ --config_server_port <config_server_port>

Open a new terminal and start the prefill service:

.. code-block:: console

$ CUDA_VISIBLE_DEVICES=0,1 KV_TRANS_USE_P2P=1 LOADWORKER=1 python -m lightllm.server.api_server --model_dir /data/fengdahu/model/Qwen2-7B/ \
$ --run_mode "prefill" \
$ --host /your/host/ip \
$ --port 8017 \
$ --tp 2 \
$ --nccl_port 2732 \
$ --max_total_token_num 400000 \
$ --tokenizer_mode fast \
$ --use_dynamic_prompt_cache \
$ --max_req_total_len 16000 \
$ --running_max_req_size 128 \
$ --disable_cudagraph \
$ --config_server_host <config_server_host> \
$ --config_server_port <config_server_port>

Open a new terminal and start the decode service:

.. code-block:: console

$ CUDA_VISIBLE_DEVICES=2,3 KV_TRANS_USE_P2P=1 LOADWORKER=10 python -m lightllm.server.api_server --model_dir /data/fengdahu/model/Qwen2-7B/ \
$ --run_mode "decode" \
$ --host /your/host/ip \
$ --port 8118 \
$ --nccl_port 12322 \
$ --tp 2 \
$ --max_total_token_num 400000 \
$ --graph_max_len_in_batch 2048 \
$ --graph_max_batch_size 16 \
$ --tokenizer_mode fast \
$ --use_dynamic_prompt_cache \
$ --config_server_host <config_server_host> \
$ --config_server_port <config_server_port>

.. note::
    The tp size of the prefill and decode stages must match. The number of prefill and decode nodes may vary, and prefill and decode nodes can be deployed across machines.


4. (Optional) Test the Model Service
-------------------------------------

In a new terminal, test the model service with the command below. In multi-pd_master mode, every pd_master instance can be used as an entry point:

.. code-block:: console

$ curl http://server_ip:server_port/generate \
$ -H "Content-Type: application/json" \
$ -d '{
$ "inputs": "What is AI?",
$ "parameters":{
$ "max_new_tokens":17,
$ "frequency_penalty":1
$ }
$ }'


For the DeepSeek-R1 model, the following script can be used for testing:

.. code-block:: console

$ cd test
$ python benchmark_client.py --num_clients 100 --input_num 2000 --tokenizer_path /nvme/DeepSeek-R1/ --url http://127.0.0.1:8000/generate_stream

20 changes: 17 additions & 3 deletions lightllm/server/api_cli.py
@@ -7,9 +7,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--run_mode",
type=str,
- choices=["normal", "prefill", "decode", "pd_master"],
+ choices=["normal", "prefill", "decode", "pd_master", "config_server"],
default="normal",
- help="set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode",
+ help="""set run mode: normal starts a single server; prefill, decode and pd_master are for the pd split run mode;
+ config_server is used in pd split mode to register pd_master nodes and serve the pd_master node list,
+ specifically designed for large-scale, high-concurrency scenarios where `pd_master` encounters
+ significant CPU bottlenecks.""",
)
parser.add_argument("--host", type=str, default="127.0.0.1")
parser.add_argument("--port", type=int, default=8000)
@@ -39,7 +42,18 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=42000,
help="p d mode, decode node used for kv move manager rpyc server port",
)

parser.add_argument(
"--config_server_host",
type=str,
default=None,
help="The host address for the config server in config_server mode.",
)
parser.add_argument(
"--config_server_port",
type=int,
default=None,
help="The port number for the config server in config_server mode.",
)
parser.add_argument(
"--model_name",
type=str,
4 changes: 3 additions & 1 deletion lightllm/server/api_server.py
@@ -5,9 +5,11 @@
torch.multiprocessing.set_start_method("spawn")  # this code will not work if subprocesses are created via fork
parser = make_argument_parser()
args = parser.parse_args()
- from .api_start import pd_master_start, normal_or_p_d_start
+ from .api_start import pd_master_start, normal_or_p_d_start, config_server_start

if args.run_mode == "pd_master":
pd_master_start(args)
elif args.run_mode == "config_server":
config_server_start(args)
else:
normal_or_p_d_start(args)
42 changes: 42 additions & 0 deletions lightllm/server/api_start.py
@@ -284,6 +284,14 @@ def pd_master_start(args):
if args.run_mode != "pd_master":
return

# When config_server is used to support multiple pd_master nodes, we
# need to generate a unique node id for each pd_master node.
# Otherwise, we use 0 for the single pd_master node.
if args.config_server_host and args.config_server_port:
args.pd_node_id = uuid.uuid4().int
else:
args.pd_node_id = 0

logger.info(f"use tgi api: {args.use_tgi_api}")
logger.info(f"all start args:{args}")

@@ -330,3 +338,37 @@ def pd_master_start(args):

setup_signal_handlers(http_server_process, process_manager)
http_server_process.wait()


def config_server_start(args):
set_unique_server_name(args)
if args.run_mode != "config_server":
return

logger.info(f"all start args:{args}")

set_env_start_args(args)

command = [
"gunicorn",
"--workers",
"1",
"--worker-class",
"uvicorn.workers.UvicornWorker",
"--bind",
f"{args.config_server_host}:{args.config_server_port}",
"--log-level",
"info",
"--access-logfile",
"-",
"--error-logfile",
"-",
"--preload",
"lightllm.server.config_server.api_http:app",
"--timeout",
f"{get_lightllm_gunicorn_time_out_seconds()}",
]

http_server_process = subprocess.Popen(command)
setup_signal_handlers(http_server_process, process_manager)
http_server_process.wait()
12 changes: 12 additions & 0 deletions lightllm/server/config_server/__init__.py
@@ -0,0 +1,12 @@
"""
This module implements a configuration service designed to facilitate the
registration and retrieval of information in a PD separation mode. It
allows various nodes to register their own information and query global
configuration details efficiently.

Key Features:
- Node registration: Enables nodes to register their specific information.
- Global configuration query: Provides mechanisms for querying shared
configuration data across the system.
- Designed for distributed systems operating in PD separation mode.
"""
96 changes: 96 additions & 0 deletions lightllm/server/config_server/api_http.py
@@ -0,0 +1,96 @@
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Query
from threading import Lock
from typing import Dict
from fastapi.responses import JSONResponse
from lightllm.utils.log_utils import init_logger
from ..pd_io_struct import PD_Master_Obj
import base64
import pickle
import os
import requests

logger = init_logger(__name__)
app = FastAPI()

registered_pd_master_objs: Dict[str, PD_Master_Obj] = {}
registered_pd_master_obj_lock = Lock()

global_req_id = 0
global_req_id_lock = Lock()


@app.get("/liveness")
@app.post("/liveness")
def liveness():
return {"status": "ok"}


@app.get("/readiness")
@app.post("/readiness")
def readiness():
return {"status": "ok"}


@app.get("/healthz", summary="Check server health")
@app.get("/health", summary="Check server health")
@app.head("/health", summary="Check server health")
async def healthcheck(request: Request):
return JSONResponse({"message": "Ok"}, status_code=200)


@app.websocket("/pd_master_register")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
client_ip, client_port = websocket.client
logger.info(f"ws connected from IP: {client_ip}, Port: {client_port}")
registered_pd_master_obj: PD_Master_Obj = pickle.loads(await websocket.receive_bytes())
    logger.info(f"received registered_pd_master_obj {registered_pd_master_obj}")
with registered_pd_master_obj_lock:
registered_pd_master_objs[registered_pd_master_obj.node_id] = registered_pd_master_obj

try:
while True:
data = await websocket.receive_text()
assert data == "heartbeat"
except (WebSocketDisconnect, Exception, RuntimeError) as e:
logger.error(f"registered_pd_master_obj {registered_pd_master_obj} has error {str(e)}")
logger.exception(str(e))
finally:
logger.error(f"registered_pd_master_obj {registered_pd_master_obj} removed")
with registered_pd_master_obj_lock:
registered_pd_master_objs.pop(registered_pd_master_obj.node_id, None)
return
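For reference, a client for this registration endpoint might look roughly like the sketch below. This is an illustrative assumption, not lightllm's actual registration code: `PDMasterObj` is a stand-in dataclass for `PD_Master_Obj`, the URL is hypothetical, and the third-party `websockets` package is assumed. The server side only expects a pickled object as the first frame, followed by periodic `"heartbeat"` text frames.

```python
import asyncio
import pickle
from dataclasses import dataclass


@dataclass
class PDMasterObj:
    # Stand-in for lightllm's PD_Master_Obj (fields are assumptions).
    node_id: int
    host: str
    http_port: int


def make_register_payload(obj: PDMasterObj) -> bytes:
    # The server does pickle.loads(await websocket.receive_bytes()),
    # so the first frame is simply the pickled object.
    return pickle.dumps(obj)


async def register_and_heartbeat(obj: PDMasterObj, url: str, interval: float = 10.0):
    import websockets  # third-party package, assumed available

    async with websockets.connect(url) as ws:
        await ws.send(make_register_payload(obj))
        while True:  # keep the registration alive; disconnect removes the node
            await ws.send("heartbeat")
            await asyncio.sleep(interval)


# Hypothetical usage:
# asyncio.run(register_and_heartbeat(
#     PDMasterObj(node_id=1, host="10.0.0.5", http_port=60011),
#     "ws://config-server-host:60088/pd_master_register"))
```

When the websocket drops, the server's `finally` block unregisters the node, so a real client would also need reconnect logic.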


@app.get("/registered_objects")
async def get_registered_objects():
with registered_pd_master_obj_lock:
serialized_data = pickle.dumps(registered_pd_master_objs)
base64_encoded = base64.b64encode(serialized_data).decode("utf-8")
return {"data": base64_encoded}
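A client consuming `/registered_objects` has to reverse the base64-plus-pickle encoding. A minimal sketch of the decode step, assuming the same payload shape as above (the `requests` usage and server address are hypothetical and commented out):

```python
import base64
import pickle


def decode_registered_objects(payload: dict) -> dict:
    # Reverses the server side: pickle.dumps(...) then base64 encoding.
    return pickle.loads(base64.b64decode(payload["data"]))


# Hypothetical usage against a running config_server:
# import requests
# resp = requests.get("http://config-server-host:60088/registered_objects")
# pd_masters = decode_registered_objects(resp.json())

# Local round-trip mirroring what the server does:
payload = {"data": base64.b64encode(pickle.dumps({1: "pd_master_1"})).decode("utf-8")}
print(decode_registered_objects(payload))  # {1: 'pd_master_1'}
```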


@app.get("/allocate_global_unique_id_range")
async def allocate_global_id_range():
"""
Allocate a global ID range for the requesting client without requiring parameters.

Returns:
dict: A dictionary containing the start and end of the allocated ID range.

Example HTTP client usage:
```python
response = requests.get("http://<server_address>/allocate_global_unique_id_range")
print(response.json())
```
"""
global global_req_id
range_size = 800000
with global_req_id_lock:
if global_req_id + range_size > 2 ** 63 - 1:
global_req_id = 0
start_id = global_req_id
global_req_id += range_size
end_id = global_req_id

return {"start_id": start_id, "end_id": end_id}
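The allocator hands out non-overlapping half-open id ranges and wraps to 0 before exceeding the signed 64-bit limit, matching the `c_int64` request-id fields introduced elsewhere in this PR. A standalone sketch of that logic (note that after a wraparound, previously issued ids can repeat):

```python
from threading import Lock

RANGE_SIZE = 800000
MAX_ID = 2 ** 63 - 1  # ids must fit in a signed 64-bit field

_next_id = 0
_lock = Lock()


def allocate_range(size: int = RANGE_SIZE):
    # Returns a half-open [start_id, end_id) range; wraps to 0 near overflow.
    global _next_id
    with _lock:
        if _next_id + size > MAX_ID:
            _next_id = 0
        start_id = _next_id
        _next_id += size
        return start_id, _next_id


a = allocate_range()
b = allocate_range()
print(a, b)  # (0, 800000) (800000, 1600000)
```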
4 changes: 2 additions & 2 deletions lightllm/server/core/objs/req.py
@@ -60,8 +60,8 @@ class Req(ctypes.Structure):
_fields_ = [
("index_in_shm_mem", ctypes.c_int),
("ref_count", ctypes.c_int),  # do not modify this reference count manually
- ("request_id", ctypes.c_int),
- ("group_req_id", ctypes.c_int),
+ ("request_id", ctypes.c_int64),
+ ("group_req_id", ctypes.c_int64),
("input_len", ctypes.c_int),
("alloc_shm_numpy_len", ctypes.c_int),
("shm_infer_released", ctypes.c_bool),  # set by the inference process to mark that it has released the request object; once the router process sees this flag it can release the shm req object as well
13 changes: 10 additions & 3 deletions lightllm/server/core/objs/sampling_params.py
@@ -205,6 +205,9 @@ class DecodeNode(ctypes.Structure):
("ip", ctypes.c_int32 * 4),
("rpyc_port", ctypes.c_int),
("max_new_tokens", ctypes.c_int),
# record the id of the pd_master node used by the current request
("pd_master_node_id_high", ctypes.c_uint64),
("pd_master_node_id_low", ctypes.c_uint64),
]

def initialize(self, data_dict):
@@ -224,15 +227,19 @@ def initialize(self, data_dict):
self.rpyc_port = data_dict["rpyc_port"]
self.max_new_tokens = data_dict["max_new_tokens"]

pd_master_node_id = data_dict["pd_master_node_id"]
self.pd_master_node_id_high = (pd_master_node_id >> 64) & 0xFFFFFFFFFFFFFFFF
self.pd_master_node_id_low = pd_master_node_id & 0xFFFFFFFFFFFFFFFF

def to_dict(self):
if not self.exists:
return None
- uuid_int = (self.node_id_high << 64) | self.node_id_low
  return {
- "node_id": uuid_int,
+ "node_id": ((self.node_id_high << 64) | self.node_id_low),
"ip": ".".join(str(self.ip[i]) for i in range(4)),
"rpyc_port": self.rpyc_port,
"max_new_tokens": self.max_new_tokens,
"pd_master_node_id": ((self.pd_master_node_id_high << 64) | self.pd_master_node_id_low),
}
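Because ctypes has no 128-bit integer type, the `uuid4`-based node id is stored as two `c_uint64` halves and recombined on read. The split and join are lossless, as a small standalone check (independent of the ctypes struct) illustrates:

```python
import uuid

MASK64 = 0xFFFFFFFFFFFFFFFF


def split_u128(n: int):
    # Split a 128-bit int into (high, low) 64-bit halves.
    return (n >> 64) & MASK64, n & MASK64


def join_u128(high: int, low: int) -> int:
    # Recombine the halves, as DecodeNode.to_dict does.
    return (high << 64) | low


node_id = uuid.uuid4().int  # 128-bit value, as assigned in pd_master_start
high, low = split_u128(node_id)
assert join_u128(high, low) == node_id
```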


@@ -264,7 +271,7 @@ class SamplingParams(ctypes.Structure):
("allowed_token_ids", AllowedTokenIds),
("stop_sequences", StopSequenceGroups),
("exponential_decay_length_penalty", ExponentialDecayLengthPenalty),
- ("group_request_id", ctypes.c_int),  # p d mode used params
+ ("group_request_id", ctypes.c_int64),  # p d mode used params
("suggested_dp_index", ctypes.c_int), # suggest dp index, deepseekv2 dp mode, use to suggest used dp_index
("move_kv_to_decode_node", DecodeNode), # move kv to deocde node, only used in pd mode
("skip_special_tokens", ctypes.c_bool), # whether to skip special tokens when decoding
2 changes: 2 additions & 0 deletions lightllm/server/core/objs/start_args_type.py
@@ -15,6 +15,8 @@ class StartArgs:
)
pd_master_ip: str = field(default="127.0.0.1")
pd_master_port: int = field(default=1212)
config_server_host: str = field(default=None)
config_server_port: int = field(default=None)
pd_decode_rpyc_port: int = field(default=42000)
model_name: str = field(default="default_model_name")
model_dir: Optional[str] = field(default=None)