Skip to content

Commit db83cce

Browse files
authored
support more PD node select func. such as random or roundrobin. (#1018)
1 parent 41b9193 commit db83cce

File tree

6 files changed

+196
-34
lines changed

6 files changed

+196
-34
lines changed

lightllm/server/api_cli.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ def make_argument_parser() -> argparse.ArgumentParser:
4242
default=42000,
4343
help="p d mode, decode node used for kv move manager rpyc server port",
4444
)
45+
parser.add_argument(
46+
"--select_p_d_node_strategy",
47+
type=str,
48+
default="round_robin",
49+
choices=["random", "round_robin", "adaptive_load"],
50+
help="pd master use this strategy to select p d node, can be round_robin, random or adaptive_load",
51+
)
4552
parser.add_argument(
4653
"--config_server_host",
4754
type=str,

lightllm/server/httpserver/pd_loop.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,5 +180,29 @@ async def _pd_process_generate(
180180
async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket):
181181
while True:
182182
handle_list = await forwarding_queue.wait_to_get_all_data()
183+
183184
if handle_list:
184-
await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list)))
185+
load_info: dict = _get_load_info()
186+
await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, load_info)))
187+
188+
189+
# 获取节点负载信息
190+
def _get_load_info() -> dict:
191+
192+
from lightllm.server.api_http import g_objs
193+
194+
assert g_objs.shared_token_load is not None, "shared_token_load is not initialized"
195+
args = g_objs.args
196+
dp_size_in_node = max(1, args.dp // args.nnodes)
197+
198+
# 获取当前每个 dp 的负载,数值含义为当前的 token 总容量使用率, 上报给 PD_Master 用于做
199+
# 调度决策。
200+
current_load = [
201+
float(g_objs.shared_token_load.get_dynamic_max_load(dp_index)) for dp_index in range(dp_size_in_node)
202+
]
203+
mean_node_load = sum(current_load) / len(current_load)
204+
load_info = {
205+
"total_token_usage_rate": mean_node_load,
206+
"client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{g_objs.args.port}",
207+
}
208+
return load_info

lightllm/server/httpserver_for_pd_master/manager.py

Lines changed: 78 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pickle
1313

1414
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
15-
from typing import Union, List, Tuple, Dict
15+
from typing import Union, List, Tuple, Dict, Optional
1616
from lightllm.server.core.objs import FinishStatus
1717
from ..pd_io_struct import PD_Client_Obj, UpKVStatus, ObjType
1818
from lightllm.server.core.objs import SamplingParams
@@ -25,6 +25,7 @@
2525
from lightllm.utils.statics_utils import MovingAverage
2626
from lightllm.server.httpserver.manager import AsyncQueue
2727
from lightllm.utils.error_utils import ServerBusyError
28+
from .pd_selector import create_selector
2829

2930
logger = init_logger(__name__)
3031

@@ -38,9 +39,8 @@ def __init__(
3839
self.args = args
3940
self.metric_client = MetricClient(metric_port)
4041
self.id_gen = ReqIDGenerator()
41-
self.prefill_nodes: List[PD_Client_Obj] = []
42-
self.decode_nodes: List[PD_Client_Obj] = []
43-
self.url_to_pd_nodes: Dict[str, PD_Client_Obj] = {}
42+
43+
self.pd_manager = PDManager(args)
4444

4545
self.req_id_to_out_inf: Dict[int, ReqStatus] = {}
4646
self.infos_queues = None # 这个需要延迟初始化,否则使用的loop不对
@@ -52,30 +52,11 @@ def __init__(
5252
return
5353

5454
async def register_pd(self, pd_info_json, websocket):
55-
pd_client = PD_Client_Obj(**pd_info_json)
56-
pd_client.websocket = websocket
57-
self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client
58-
if pd_client.mode == "prefill":
59-
self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port]
60-
self.prefill_nodes.append(pd_client)
61-
elif pd_client.mode == "decode":
62-
self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port]
63-
self.decode_nodes.append(pd_client)
64-
else:
65-
assert False
66-
67-
logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} registed")
55+
self.pd_manager.register_pd(pd_info_json, websocket)
6856
return
6957

7058
async def remove_pd(self, pd_info_json):
71-
pd_client = PD_Client_Obj(**pd_info_json)
72-
try:
73-
del self.url_to_pd_nodes[pd_client.client_ip_port]
74-
except:
75-
pass
76-
self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port]
77-
self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port]
78-
logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} removed")
59+
self.pd_manager.remove_pd(pd_info_json)
7960
return
8061

8162
async def update_req_status(self, upkv_status: UpKVStatus):
@@ -108,11 +89,7 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar
10889
async def select_p_d_node(
10990
self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams
11091
) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
111-
import random
112-
113-
p_node = random.choice(self.prefill_nodes)
114-
d_node = random.choice(self.decode_nodes)
115-
return p_node, d_node
92+
return self.pd_manager.select_p_d_node(prompt, sampling_params, multimodal_params)
11693

11794
async def generate(
11895
self,
@@ -264,7 +241,7 @@ async def _wait_to_token_package(
264241
request: Request,
265242
):
266243
out_token_counter = 0
267-
first_token_cost_ms = sys.float_info.max
244+
first_token_cost_ms = float("inf")
268245
group_request_id = sampling_params.group_request_id
269246
unfinished_count = sampling_params.best_of
270247
is_first_token = True
@@ -368,7 +345,10 @@ async def handle_loop(self):
368345
try:
369346
for obj in objs:
370347
if obj[0] == ObjType.TOKEN_PACKS:
371-
for sub_req_id, text, metadata, finish_status in obj[1]:
348+
token_list, node_load_info = obj[1], obj[2]
349+
self.pd_manager.update_node_load_info(node_load_info)
350+
351+
for sub_req_id, text, metadata, finish_status in token_list:
372352
finish_status: FinishStatus = finish_status
373353
group_req_id = convert_sub_id_to_group_id(sub_req_id)
374354
try:
@@ -415,3 +395,69 @@ async def pop_all_tokens(self):
415395
ans = self.out_token_info_list.copy()
416396
self.out_token_info_list.clear()
417397
return ans
398+
399+
400+
class PDManager:
401+
def __init__(self, args):
402+
self.args = args
403+
self.prefill_nodes: List[PD_Client_Obj] = []
404+
self.decode_nodes: List[PD_Client_Obj] = []
405+
self.url_to_pd_nodes: Dict[str, PD_Client_Obj] = {}
406+
self.selector = create_selector(args.select_p_d_node_strategy, self)
407+
return
408+
409+
def register_pd(self, pd_info_json, websocket):
410+
pd_client = PD_Client_Obj(**pd_info_json)
411+
pd_client.websocket = websocket
412+
self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client
413+
414+
if pd_client.mode == "prefill":
415+
self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port]
416+
self.prefill_nodes.append(pd_client)
417+
elif pd_client.mode == "decode":
418+
self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port]
419+
self.decode_nodes.append(pd_client)
420+
else:
421+
assert False, f"mode must in ['prefill', 'decode'], but get {pd_client.mode}"
422+
423+
self.selector.update_nodes(self.prefill_nodes, self.decode_nodes)
424+
425+
logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} registed")
426+
return
427+
428+
def remove_pd(self, pd_info_json):
429+
pd_client = PD_Client_Obj(**pd_info_json)
430+
431+
self.url_to_pd_nodes.pop(pd_client.client_ip_port, None)
432+
self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port]
433+
self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port]
434+
435+
self.selector.update_nodes(self.prefill_nodes, self.decode_nodes)
436+
437+
logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} removed")
438+
return
439+
440+
def update_node_load_info(self, load_info: Optional[dict]):
441+
"""更新节点负载信息
442+
load_info: 节点负载信息字典,内容格式如下,可以为 None
443+
{
444+
"total_token_usage_rate": xxxx,
445+
"client_ip_port": xxxx,
446+
}
447+
"""
448+
try:
449+
if load_info is None:
450+
return
451+
client_ip_port = load_info["client_ip_port"]
452+
total_token_usage_rate = load_info["total_token_usage_rate"]
453+
pd_client = self.url_to_pd_nodes.get(client_ip_port)
454+
pd_client.run_status.total_token_usage_rate = total_token_usage_rate
455+
except BaseException as e:
456+
logger.warning(f"udpate node load info failed, load_info: {load_info} error: {str(e)}")
457+
return
458+
459+
def select_p_d_node(
460+
self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams
461+
) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
462+
p_node, d_node = self.selector.select_p_d_node(prompt, sampling_params, multimodal_params)
463+
return p_node, d_node
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from .pd_selector import PDSelector, RandomSelector, RoundRobinSelector, AdaptiveLoadSelector
2+
3+
4+
def create_selector(selector_type: str, pd_manager) -> PDSelector:
5+
if selector_type == "random":
6+
return RandomSelector(pd_manager)
7+
elif selector_type == "round_robin":
8+
return RoundRobinSelector(pd_manager)
9+
elif selector_type == "adaptive_load":
10+
return AdaptiveLoadSelector(pd_manager)
11+
else:
12+
raise ValueError(f"Invalid selector type: {selector_type}")
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import random
2+
from typing import Union, List, Tuple, Dict
3+
from lightllm.server.pd_io_struct import PD_Client_Obj
4+
from lightllm.server.core.objs import SamplingParams
5+
from lightllm.server.multimodal_params import MultimodalParams
6+
7+
8+
class PDSelector:
9+
def __init__(self, pd_manager):
10+
self.prefill_nodes: List[PD_Client_Obj] = []
11+
self.decode_nodes: List[PD_Client_Obj] = []
12+
self.pd_manager = pd_manager
13+
14+
def update_nodes(self, prefill_nodes, decode_nodes):
15+
self.prefill_nodes = prefill_nodes
16+
self.decode_nodes = decode_nodes
17+
18+
def select_p_d_node(
19+
self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams
20+
) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
21+
raise NotImplementedError("Subclass must implement this method")
22+
23+
24+
class RandomSelector(PDSelector):
25+
"""随机选择器"""
26+
27+
def select_p_d_node(
28+
self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams
29+
) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
30+
p_node = random.choice(self.prefill_nodes)
31+
d_node = random.choice(self.decode_nodes)
32+
return p_node, d_node
33+
34+
35+
class RoundRobinSelector(PDSelector):
36+
"""轮询选择器"""
37+
38+
def __init__(self, pd_manager):
39+
super().__init__(pd_manager)
40+
self.prefill_node_index: int = 0
41+
self.decode_node_index: int = 0
42+
43+
def select_p_d_node(
44+
self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams
45+
) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
46+
self.prefill_node_index = self.prefill_node_index % len(self.prefill_nodes)
47+
self.decode_node_index = self.decode_node_index % len(self.decode_nodes)
48+
p_node = self.prefill_nodes[self.prefill_node_index]
49+
d_node = self.decode_nodes[self.decode_node_index]
50+
self.prefill_node_index += 1
51+
self.decode_node_index += 1
52+
return p_node, d_node
53+
54+
55+
class AdaptiveLoadSelector(PDSelector):
56+
"""基于负载使用情况的选择器"""
57+
58+
def select_p_d_node(
59+
self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams
60+
) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
61+
p_node = self._importance_sampling(self.prefill_nodes)
62+
d_node = self._importance_sampling(self.decode_nodes)
63+
64+
return p_node, d_node
65+
66+
def _importance_sampling(self, nodes: List[PD_Client_Obj]):
67+
return random.choices(nodes, weights=[max(1.0 - e.run_status.total_token_usage_rate, 0.02) for e in nodes])

lightllm/server/pd_io_struct.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import enum
22
import time
3-
from dataclasses import dataclass
3+
from dataclasses import dataclass, field
44
from typing import Dict, List, Optional, Tuple, Union
55
from lightllm.server.req_id_generator import convert_sub_id_to_group_id
66
from fastapi import WebSocket
@@ -38,13 +38,19 @@ class ObjType(enum.Enum):
3838
TOKEN_PACKS = 3
3939

4040

41+
@dataclass
42+
class _PD_Client_RunStatus:
43+
total_token_usage_rate: float = 0.0 # pd 节点上的 token 使用率
44+
45+
4146
@dataclass
4247
class PD_Client_Obj:
4348
node_id: int
4449
client_ip_port: str
4550
mode: str # 只能是 prefill 或者 decode 节点
4651
start_args: object # 节点的启动参数信息,用于做匹配性的校验,防止运行过程中出现问题。
4752
websocket: WebSocket = None # 用于通信的 websocket 连接对象
53+
run_status: _PD_Client_RunStatus = field(default_factory=_PD_Client_RunStatus)
4854

4955
def __post_init__(self):
5056
if self.mode not in ["prefill", "decode"]:

0 commit comments

Comments
 (0)