Skip to content

Commit 7130a8e

Browse files
committed
fix
1 parent cb23ae0 commit 7130a8e

File tree

2 files changed

+22
-19
lines changed

2 files changed

+22
-19
lines changed

lightllm/server/httpserver_for_pd_master/manager.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ def __init__(
5555
return
5656

5757
async def register_pd(self, pd_info_json, websocket):
58-
await self.pd_manager.register_pd(pd_info_json, websocket)
58+
self.pd_manager.register_pd(pd_info_json, websocket)
5959
return
6060

6161
async def remove_pd(self, pd_info_json):
62-
await self.pd_manager.remove_pd(pd_info_json)
62+
self.pd_manager.remove_pd(pd_info_json)
6363
return
6464

6565
async def update_req_status(self, upkv_status: UpKVStatus):
@@ -92,7 +92,7 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar
9292
async def select_p_d_node(
9393
self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams
9494
) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
95-
return await self.pd_manager.select_p_d_node(prompt, sampling_params, multimodal_params)
95+
return self.pd_manager.select_p_d_node(prompt, sampling_params, multimodal_params)
9696

9797
async def generate(
9898
self,
@@ -411,7 +411,7 @@ def __init__(self, args):
411411
)
412412
return
413413

414-
async def register_pd(self, pd_info_json, websocket):
414+
def register_pd(self, pd_info_json, websocket):
415415
pd_client = PD_Client_Obj(**pd_info_json)
416416
pd_client.websocket = websocket
417417
self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client
@@ -425,19 +425,19 @@ async def register_pd(self, pd_info_json, websocket):
425425
else:
426426
assert False, f"mode must in ['prefill', 'decode'], but get {pd_client.mode}"
427427

428-
await self.selector.update_nodes(self.prefill_nodes, self.decode_nodes)
428+
self.selector.update_nodes(self.prefill_nodes, self.decode_nodes)
429429

430430
logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} registed")
431431
return
432432

433-
async def remove_pd(self, pd_info_json):
433+
def remove_pd(self, pd_info_json):
434434
pd_client = PD_Client_Obj(**pd_info_json)
435435

436436
self.url_to_pd_nodes.pop(pd_client.client_ip_port, None)
437437
self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port]
438438
self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port]
439439

440-
await self.selector.update_nodes(self.prefill_nodes, self.decode_nodes)
440+
self.selector.update_nodes(self.prefill_nodes, self.decode_nodes)
441441

442442
logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} removed")
443443
return
@@ -461,13 +461,8 @@ def update_node_load_info(self, load_info: Optional[dict]):
461461
logger.warning(f"udpate node load info failed, load_info: {load_info} error: {str(e)}")
462462
return
463463

464-
def get_predict_node_infos(self):
465-
"""获取所有节点的预测负载信息"""
466-
return self.node_info_recorder.get_predict_node_infos()
467-
468-
async def select_p_d_node(
464+
def select_p_d_node(
469465
self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams
470466
) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
471-
p_node, d_node = await self.selector.select_p_d_node(prompt, sampling_params, multimodal_params)
472-
self.node_info_recorder.update_predict_node_info(p_node, d_node, prompt, sampling_params, multimodal_params)
467+
p_node, d_node = self.selector.select_p_d_node(prompt, sampling_params, multimodal_params)
473468
return p_node, d_node

lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,22 @@ def __init__(self, prefill_nodes: List[PD_Client_Obj], decode_nodes: List[PD_Cli
1212
self.decode_nodes: List[PD_Client_Obj] = decode_nodes
1313
self.pd_manager = pd_manager
1414

15-
async def update_nodes(self, prefill_nodes, decode_nodes):
15+
def update_nodes(self, prefill_nodes, decode_nodes):
1616
self.prefill_nodes = prefill_nodes
1717
self.decode_nodes = decode_nodes
1818

19-
async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
19+
def select_p_d_node(
20+
self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams
21+
) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
2022
raise NotImplementedError("Subclass must implement this method")
2123

2224

2325
class RandomSelector(PDSelector):
2426
"""随机选择器"""
2527

26-
async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
28+
def select_p_d_node(
29+
self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams
30+
) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
2731
p_node = random.choice(self.prefill_nodes)
2832
d_node = random.choice(self.decode_nodes)
2933
return p_node, d_node
@@ -37,7 +41,9 @@ def __init__(self, prefill_nodes: List[PD_Client_Obj], decode_nodes: List[PD_Cli
3741
self.prefill_node_index: int = 0
3842
self.decode_node_index: int = 0
3943

40-
async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
44+
def select_p_d_node(
45+
self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams
46+
) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
4147
p_node = self.prefill_nodes[self.prefill_node_index]
4248
d_node = self.decode_nodes[self.decode_node_index]
4349
self.prefill_node_index = (self.prefill_node_index + 1) % len(self.prefill_nodes)
@@ -48,7 +54,9 @@ async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params:
4854
class MemorySelector(PDSelector):
4955
"""基于内存使用情况的选择器"""
5056

51-
async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
57+
def select_p_d_node(
58+
self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams
59+
) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
5260
def _get_min_node(nodes: List[PD_Client_Obj], node_infos: Dict[str, dict], key: str) -> PD_Client_Obj:
5361
min_node, min_node_value = None, float("inf")
5462
for node in nodes:

0 commit comments

Comments
 (0)