Skip to content

Commit e55cbb5

Browse files
committed
fix: memory node select
1 parent cedbb50 commit e55cbb5

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import random
22

3-
from typing import Union, List, Tuple
3+
from typing import Union, List, Tuple, Dict
44
from lightllm.server.pd_io_struct import PD_Client_Obj
55
from lightllm.server.core.objs import SamplingParams
66
from lightllm.server.multimodal_params import MultimodalParams
@@ -49,13 +49,14 @@ class MemorySelector(PDSelector):
4949
"""基于内存使用情况的选择器"""
5050

5151
async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
52-
def _get_min_node(node_infos: dict, key: str):
53-
min_node, min_node_len = None, float("inf")
54-
for node_ip, node_info in node_infos.items():
55-
if node_info[key] < min_node_len:
56-
min_node_len = node_info[key]
57-
min_node = node_ip
58-
return min_node
52+
def _get_min_node(nodes: List[PD_Client_Obj], node_infos: Dict[str, dict], key: str) -> PD_Client_Obj:
53+
min_node, min_node_value = None, float("inf")
54+
for node in nodes:
55+
if node.client_ip_port in node_infos:
56+
if node_infos[node.client_ip_port][key] < min_node_value:
57+
min_node_value = node_infos[node.client_ip_port][key]
58+
min_node = node
59+
return min_node if min_node is not None else random.choice(nodes)
5960

6061
if self.pd_manager is None:
6162
# 如果没有 PDManager 引用,回退到随机选择
@@ -68,7 +69,7 @@ def _get_min_node(node_infos: dict, key: str):
6869
# 获取负载最小的节点
6970
p_node_infos = node_infos["prefill"]
7071
d_node_infos = node_infos["decode"]
71-
p_node = _get_min_node(p_node_infos, "mem_len") or random.choice(self.prefill_nodes)
72-
d_node = _get_min_node(d_node_infos, "mem_len") or random.choice(self.decode_nodes)
72+
p_node = _get_min_node(self.prefill_nodes, p_node_infos, "mem_len")
73+
d_node = _get_min_node(self.decode_nodes, d_node_infos, "mem_len")
7374

7475
return p_node, d_node

0 commit comments

Comments
 (0)