|
1 | 1 | import random |
2 | | - |
3 | 2 | from typing import Union, List, Tuple, Dict |
4 | 3 | from lightllm.server.pd_io_struct import PD_Client_Obj |
5 | 4 | from lightllm.server.core.objs import SamplingParams |
@@ -59,27 +58,10 @@ class MemorySelector(PDSelector): |
59 | 58 | def select_p_d_node( |
60 | 59 | self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams |
61 | 60 | ) -> Tuple[PD_Client_Obj, PD_Client_Obj]: |
62 | | - def _get_min_node(nodes: List[PD_Client_Obj], node_infos: Dict[str, dict], key: str) -> PD_Client_Obj: |
63 | | - min_node, min_node_value = None, float("inf") |
64 | | - for node in nodes: |
65 | | - if node.client_ip_port in node_infos: |
66 | | - if node_infos[node.client_ip_port][key] < min_node_value: |
67 | | - min_node_value = node_infos[node.client_ip_port][key] |
68 | | - min_node = node |
69 | | - return min_node if min_node is not None else random.choice(nodes) |
70 | | - |
71 | | - if self.pd_manager is None: |
72 | | - # 如果没有 PDManager 引用,回退到随机选择 |
73 | | - p_node = random.choice(self.prefill_nodes) if self.prefill_nodes else None |
74 | | - d_node = random.choice(self.decode_nodes) if self.decode_nodes else None |
75 | | - return p_node, d_node |
76 | | - |
77 | | - node_infos = self.pd_manager.get_predict_node_infos() |
78 | | - |
79 | | - # 获取负载最小的节点 |
80 | | - p_node_infos = node_infos["prefill"] |
81 | | - d_node_infos = node_infos["decode"] |
82 | | - p_node = _get_min_node(self.prefill_nodes, p_node_infos, "mem_len") |
83 | | - d_node = _get_min_node(self.decode_nodes, d_node_infos, "mem_len") |
| 61 | + p_node = self._importance_sampling(self.prefill_nodes) |
| 62 | + d_node = self._importance_sampling(self.decode_nodes) |
84 | 63 |
|
85 | 64 | return p_node, d_node |
| 65 | + |
| 66 | + def _importance_sampling(self, nodes: List[PD_Client_Obj]): |
| 67 | + return random.choices(nodes, weights=[max(1.0 - e.run_status.total_token_usage_rate, 0.02) for e in nodes]) |
0 commit comments