Skip to content

Commit f0bf210

Browse files
committed
fix
1 parent 2ac5d72 commit f0bf210

File tree

1 file changed

+5
-23
lines changed

1 file changed

+5
-23
lines changed

lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import random
2-
32
from typing import Union, List, Tuple, Dict
43
from lightllm.server.pd_io_struct import PD_Client_Obj
54
from lightllm.server.core.objs import SamplingParams
@@ -59,27 +58,10 @@ class MemorySelector(PDSelector):
5958
def select_p_d_node(
6059
self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams
6160
) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
62-
def _get_min_node(nodes: List[PD_Client_Obj], node_infos: Dict[str, dict], key: str) -> PD_Client_Obj:
63-
min_node, min_node_value = None, float("inf")
64-
for node in nodes:
65-
if node.client_ip_port in node_infos:
66-
if node_infos[node.client_ip_port][key] < min_node_value:
67-
min_node_value = node_infos[node.client_ip_port][key]
68-
min_node = node
69-
return min_node if min_node is not None else random.choice(nodes)
70-
71-
if self.pd_manager is None:
72-
# 如果没有 PDManager 引用,回退到随机选择
73-
p_node = random.choice(self.prefill_nodes) if self.prefill_nodes else None
74-
d_node = random.choice(self.decode_nodes) if self.decode_nodes else None
75-
return p_node, d_node
76-
77-
node_infos = self.pd_manager.get_predict_node_infos()
78-
79-
# 获取负载最小的节点
80-
p_node_infos = node_infos["prefill"]
81-
d_node_infos = node_infos["decode"]
82-
p_node = _get_min_node(self.prefill_nodes, p_node_infos, "mem_len")
83-
d_node = _get_min_node(self.decode_nodes, d_node_infos, "mem_len")
61+
p_node = self._importance_sampling(self.prefill_nodes)
62+
d_node = self._importance_sampling(self.decode_nodes)
8463

8564
return p_node, d_node
65+
66+
def _importance_sampling(self, nodes: List[PD_Client_Obj]):
67+
return random.choices(nodes, weights=[max(1.0 - e.run_status.total_token_usage_rate, 0.02) for e in nodes])

0 commit comments

Comments
 (0)