11import random
22
3- from typing import Union , List , Tuple
3+ from typing import Union , List , Tuple , Dict
44from lightllm .server .pd_io_struct import PD_Client_Obj
55from lightllm .server .core .objs import SamplingParams
66from lightllm .server .multimodal_params import MultimodalParams
@@ -49,13 +49,14 @@ class MemorySelector(PDSelector):
4949 """基于内存使用情况的选择器"""
5050
5151 async def select_p_d_node (self , prompt : Union [str , List [int ]], sampling_params : SamplingParams , multimodal_params : MultimodalParams ) -> Tuple [PD_Client_Obj , PD_Client_Obj ]:
52- def _get_min_node (node_infos : dict , key : str ):
53- min_node , min_node_len = None , float ("inf" )
54- for node_ip , node_info in node_infos .items ():
55- if node_info [key ] < min_node_len :
56- min_node_len = node_info [key ]
57- min_node = node_ip
58- return min_node
52+ def _get_min_node (nodes : List [PD_Client_Obj ], node_infos : Dict [str , dict ], key : str ) -> PD_Client_Obj :
53+ min_node , min_node_value = None , float ("inf" )
54+ for node in nodes :
55+ if node .client_ip_port in node_infos :
56+ if node_infos [node .client_ip_port ][key ] < min_node_value :
57+ min_node_value = node_infos [node .client_ip_port ][key ]
58+ min_node = node
59+ return min_node if min_node is not None else random .choice (nodes )
5960
6061 if self .pd_manager is None :
6162 # 如果没有 PDManager 引用,回退到随机选择
@@ -68,7 +69,7 @@ def _get_min_node(node_infos: dict, key: str):
6869 # 获取负载最小的节点
6970 p_node_infos = node_infos ["prefill" ]
7071 d_node_infos = node_infos ["decode" ]
71- p_node = _get_min_node (p_node_infos , "mem_len" ) or random . choice ( self . prefill_nodes )
72- d_node = _get_min_node (d_node_infos , "mem_len" ) or random . choice ( self . decode_nodes )
72+ p_node = _get_min_node (self . prefill_nodes , p_node_infos , "mem_len" )
73+ d_node = _get_min_node (self . decode_nodes , d_node_infos , "mem_len" )
7374
7475 return p_node , d_node
0 commit comments