Skip to content

Commit 2990df8

Browse files
committed
add global req_id alloc.
1 parent f3f29e6 commit 2990df8

File tree

5 files changed

+81
-11
lines changed

5 files changed

+81
-11
lines changed

lightllm/server/config_server/api_http.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
1+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Query
22
from threading import Lock
33
from typing import Dict
44
from fastapi.responses import JSONResponse
@@ -7,13 +7,17 @@
77
import base64
88
import pickle
99
import os
10+
import requests
1011

1112
logger = init_logger(__name__)
app = FastAPI()

# Registry of connected PD master nodes, keyed by each object's node_id.
# All access goes through registered_pd_master_obj_lock.
registered_pd_master_objs:Dict[str, PD_Master_Obj] = {}
registered_pd_master_obj_lock = Lock()

# Start of the next request-id range to hand out from
# /allocate_global_unique_id_range; only touched while holding
# global_req_id_lock so two clients never receive overlapping ranges.
global_req_id = 0
global_req_id_lock = Lock()
20+
1721

1822
@app.get("/liveness")
1923
@app.post("/liveness")
@@ -41,7 +45,6 @@ async def websocket_endpoint(websocket: WebSocket):
4145
logger.info(f"ws connected from IP: {client_ip}, Port: {client_port}")
4246
registered_pd_master_obj: PD_Master_Obj = pickle.loads(await websocket.receive_bytes())
4347
logger.info(f"recieved registered_pd_master_obj {registered_pd_master_obj}")
44-
4548
with registered_pd_master_obj_lock:
4649
registered_pd_master_objs[registered_pd_master_obj.node_id] = registered_pd_master_obj
4750

@@ -64,3 +67,28 @@ async def get_registered_objects():
6467
serialized_data = pickle.dumps(registered_pd_master_objs)
6568
base64_encoded = base64.b64encode(serialized_data).decode('utf-8')
6669
return {"data": base64_encoded}
70+
71+
@app.get("/allocate_global_unique_id_range")
async def allocate_global_id_range():
    """
    Allocate a global ID range for the requesting client without requiring parameters.

    Each call hands out the next ``range_size`` ids from the module-level
    ``global_req_id`` counter. The next allocation starts at the ``end_id``
    returned here, so ``end_id`` is an exclusive upper bound.

    Returns:
        dict: A dictionary containing the start and end of the allocated ID range.

    Example HTTP client usage:
    ```python
    response = requests.get("http://<server_address>/allocate_global_unique_id_range")
    print(response.json())
    ```
    """
    global global_req_id
    range_size = 800000
    with global_req_id_lock:
        # Wrap around before the counter would exceed the signed 64-bit
        # range; after a wrap, previously issued ids can be handed out again.
        if global_req_id + range_size > 2 ** 63 - 1:
            global_req_id = 0
        start_id = global_req_id
        global_req_id += range_size
        end_id = global_req_id

    return {"start_id": start_id, "end_id": end_id}

lightllm/server/core/objs/start_args_type.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ class StartArgs:
1515
)
1616
pd_master_ip: str = field(default="127.0.0.1")
1717
pd_master_port: int = field(default=1212)
18+
config_server_host: str = field(default=None)
19+
config_server_port: int = field(default=None)
1820
pd_decode_rpyc_port: int = field(default=42000)
1921
model_name: str = field(default="default_model_name")
2022
model_dir: Optional[str] = field(default=None)

lightllm/server/req_id_generator.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1-
import threading
1+
import time
2+
import requests
23
import numpy as np
4+
from lightllm.utils.log_utils import init_logger
5+
6+
logger = init_logger(__name__)
37

48
# 可以支持的最大 beam 参数上限,为了让生成的请求的group_req_id 和 sub_req_id 可以有直接的计算映射关系
59
# id 生成器,只会以 MAX_BEST_OF 的间隔生成id 作为 group_req_id, (sub_req_id // MAX_BEST_OF * MAX_BEST_OF) 即可
@@ -12,15 +16,44 @@ class ReqIDGenerator:
1216
def __init__(self):
    """Create the shared id counters and the lock that guards them.

    The shared array holds two int64 slots, both starting at 0: slot 0 is
    the next id to hand out and slot 1 is the end of the currently
    allocated id range.
    """
    from lightllm.server.core.objs.atomic_lock import AtomicShmLock
    from lightllm.server.core.objs.shm_array import ShmArray
    from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args

    start_args = get_env_start_args()
    self.args = start_args
    # Truthy only when both the config server host and port are configured.
    self.use_config_server = start_args.config_server_host and start_args.config_server_port

    id_state = ShmArray(f"{get_unique_server_name()}_req_id_gen", (2,), dtype=np.int64)
    id_state.create_shm()
    for slot in (0, 1):
        id_state.arr[slot] = 0
    self.current_id = id_state

    self.lock = AtomicShmLock(f"{get_unique_server_name()}_req_id_gen_lock")
2128

29+
def _check_and_set_new_id_range(self):
    """Make sure the current id range can serve at least one more id.

    ``self.current_id.arr[0]`` is the next id to hand out and
    ``self.current_id.arr[1]`` is the end of the allocated range. When the
    next id plus ``MAX_BEST_OF`` reaches the end, a new range is installed:

    * without a config server, the range restarts at ``MAX_BEST_OF`` and
      runs to the int64 maximum;
    * with a config server, a fresh range is fetched over HTTP, retrying
      every 3 seconds until one is obtained.

    generate_id() calls this while holding ``self.lock``.
    """
    if self.current_id.arr[0] + MAX_BEST_OF < self.current_id.arr[1]:
        return

    if not self.use_config_server:
        self.current_id.arr[0] = MAX_BEST_OF
        self.current_id.arr[1] = np.iinfo(np.int64).max
        return

    url = (
        f"http://{self.args.config_server_host}:{self.args.config_server_port}"
        "/allocate_global_unique_id_range"
    )
    while True:
        try:
            # Bound the wait: without a timeout a hung config server would
            # block this call (and the lock held by the caller) forever.
            response = requests.get(url, timeout=10)
            if response.status_code != 200:
                raise RuntimeError(f"Failed to fetch ID range from config server: {response.status_code}")

            id_range = response.json()
            logger.info(f"get new id range {id_range}")
            # Round the start up to the next multiple of MAX_BEST_OF so ids
            # keep the group_req_id / sub_req_id multiple relationship.
            start_id = (id_range["start_id"] // MAX_BEST_OF + 1) * MAX_BEST_OF
            end_id = id_range["end_id"]
            # Validate before touching the shared counters, and with a real
            # exception: asserts are stripped when Python runs with -O.
            if start_id + MAX_BEST_OF >= end_id:
                raise RuntimeError(f"get id range error {start_id} {end_id}")

            self.current_id.arr[0] = start_id
            self.current_id.arr[1] = end_id
            return
        except Exception as e:
            # Exception (not BaseException) so KeyboardInterrupt/SystemExit
            # can still stop the process instead of being retried forever.
            logger.exception(str(e))
            time.sleep(3)
53+
2254
def generate_id(self):
    """Return the next id and advance the shared counter by MAX_BEST_OF.

    The whole operation runs under ``self.lock``; the id range is checked
    (and replaced if exhausted) before the id is taken.
    """
    with self.lock:
        self._check_and_set_new_id_range()
        next_id = self.current_id.arr[0]
        self.current_id.arr[0] = next_id + MAX_BEST_OF
    return next_id

lightllm/utils/envs_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010

1111

1212
def set_unique_server_name(args):
    """Publish this server's unique name in LIGHTLLM_UNIQUE_SERVICE_NAME_ID.

    Args:
        args: Start arguments. In ``pd_master`` run mode the name is
            ``"<port>_pd_master"``; in every other mode it is
            ``"<nccl_port>_<node_rank>"``.

    Returns:
        None. The result is written to ``os.environ``.
    """
    # Args objects without a run_mode attribute fall back to the
    # nccl_port/node_rank naming instead of raising AttributeError.
    if getattr(args, "run_mode", None) == "pd_master":
        unique_name = f"{args.port}_pd_master"
    else:
        unique_name = f"{args.nccl_port}_{args.node_rank}"
    os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = unique_name
1518

1619

lightllm/utils/health_check.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,12 @@
1010
from fastapi import Request
1111
from lightllm.server.req_id_generator import ReqIDGenerator
1212
from lightllm.utils.log_utils import init_logger
13-
from lightllm.utils.envs_utils import get_unique_server_name
13+
from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
1414

1515
logger = init_logger(__name__)
1616

1717

1818
_g_health_req_id_gen = ReqIDGenerator()
19-
_g_health_req_id_gen.generate_id()
20-
2119

2220
@dataclass
2321
class HealthObj:
@@ -78,7 +76,13 @@ async def health_check(args, httpserver_manager: HttpServerManager, request: Req
7876
sampling_params = SamplingParams()
7977
sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict)
8078
sampling_params.verify()
81-
sampling_params.group_request_id = -_g_health_req_id_gen.generate_id() # health monitor 的 id 是负的
79+
80+
if get_env_start_args().run_mode == "pd_master":
81+
# Since the id assigned by pd master needs to be passed to prefill and decode nodes for inference,
82+
# a normal request id is required instead of a negative id.
83+
sampling_params.group_request_id = _g_health_req_id_gen.generate_id()
84+
else:
85+
sampling_params.group_request_id = -_g_health_req_id_gen.generate_id() # health monitor 的 id 是负的
8286
multimodal_params_dict = request_dict.get("multimodal_params", {})
8387
multimodal_params = MultimodalParams(**multimodal_params_dict)
8488
results_generator = httpserver_manager.generate(

0 commit comments

Comments
 (0)