88from lightllm .server .router .dynamic_prompt .shared_arr import SharedInt
99from lightllm .utils .profile_max_tokens import get_available_gpu_memory , get_total_gpu_memory
1010from lightllm .common .kv_trans_kernel .kv_trans import kv_trans
11- from lightllm .utils .dist_utils import get_global_rank
11+ from lightllm .utils .dist_utils import get_current_rank_in_node
1212from lightllm .utils .envs_utils import get_unique_server_name , get_env_start_args
1313
1414
@@ -37,8 +37,10 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
3737 # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。
3838 from lightllm .utils .envs_utils import get_unique_server_name
3939
40- rank_id = get_global_rank ()
41- self .shared_can_use_token_num = SharedInt (f"{ get_unique_server_name ()} _mem_manger_can_use_token_num_{ rank_id } " )
40+ rank_in_node = get_current_rank_in_node ()
41+ self .shared_can_use_token_num = SharedInt (
42+ f"{ get_unique_server_name ()} _mem_manger_can_use_token_num_{ rank_in_node } "
43+ )
4244
4345 self .shared_can_use_token_num .set_value (self .can_use_mem_size )
4446 self ._init_buffers (
@@ -83,13 +85,10 @@ def alloc_kv_move_buffer(self, max_req_total_len):
8385 self .kv_move_buf_indexes = torch .arange (0 , max_req_total_len + 8 , dtype = torch .int64 , device = "cuda" )
8486 return
8587
86- def send_to_decode_node (self , move_tasks : List [KVMoveTask ], mem_managers : List ["MemoryManager" ], dp_size : int ):
87- """
88- dp_size 是为 deepseekv2 类型,可以 dp 和 tp 混合模式运行的模型定制的参数,
89- 普通tp模式下, dp_size 一定等于 1, dp_index 一定等于 0, 同时普通模式下, 这两个参数并不会
90- 被真正使用
91- """
92- assert dp_size == 1
88+ def send_to_decode_node (
89+ self , move_tasks : List [KVMoveTask ], mem_managers : List ["MemoryManager" ], dp_size_in_node : int
90+ ):
91+ assert dp_size_in_node == 1
9392
9493 # 先将数据发送到指定的一张卡上的buffer,再发送。
9594
@@ -123,14 +122,9 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
123122 return move_buffer
124123
125124 def receive_from_prefill_node (
126- self , move_tasks : List [KVMoveTask ], mem_managers : List ["MemoryManager" ], dp_size : int
125+ self , move_tasks : List [KVMoveTask ], mem_managers : List ["MemoryManager" ], dp_size_in_node : int
127126 ):
128- """
129- dp_size 是为 deepseekv2 类型,可以 dp 和 tp 混合模式运行的模型定制的参数,
130- 普通tp模式下, dp_size 一定等于 1, 同时普通模式下, 这两个参数并不会
131- 被真正使用
132- """
133- assert dp_size == 1
127+ assert dp_size_in_node == 1
134128
135129 # 先将数据接受到指定的一张卡上的buffer,再复制到其他的卡上。
136130
@@ -160,11 +154,13 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
160154 self .kv_buffer [layer_index : layer_index + 1 , token_indexes , :, :] = buffer_tensor
161155 return
162156
163- def send_to_decode_node_p2p (self , move_tasks : List [KVMoveTask ], mem_managers : List ["MemoryManager" ], dp_size : int ):
157+ def send_to_decode_node_p2p (
158+ self , move_tasks : List [KVMoveTask ], mem_managers : List ["MemoryManager" ], dp_size_in_node : int
159+ ):
164160 """
165161 使用 p2p triton kernel 进行数据复制和传输的实现方式。
166162 """
167- assert dp_size == 1
163+ assert dp_size_in_node == 1
168164
169165 # 先将数据发送到指定的一张卡上的buffer,再发送。
170166
@@ -190,9 +186,9 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k
190186 return move_buffer
191187
192188 def receive_from_prefill_node_p2p (
193- self , move_tasks : List [KVMoveTask ], mem_managers : List ["MemoryManager" ], dp_size : int
189+ self , move_tasks : List [KVMoveTask ], mem_managers : List ["MemoryManager" ], dp_size_in_node : int
194190 ):
195- assert dp_size == 1
191+ assert dp_size_in_node == 1
196192
197193 # 先将数据接受到指定的一张卡上的buffer,再复制到其他的卡上。
198194
@@ -303,20 +299,16 @@ class ReadOnlyStaticsMemoryManager:
303299 def __init__ (self ) -> None :
304300 args = get_env_start_args ()
305301 self .global_world_size = args .tp
306- node_world_size = args .tp // args .nnodes
307- rank_start = args .node_rank * node_world_size
308- rank_end = (args .node_rank + 1 ) * node_world_size
309- self .shared_tp_infos = {
310- rank : SharedInt (f"{ get_unique_server_name ()} _mem_manger_can_use_token_num_{ rank } " )
311- for rank in range (rank_start , rank_end )
312- }
313-
314- def get_unrefed_token_num (self , dp_rank : int ):
315- args = get_env_start_args ()
316- if args .dp == 1 and args .nnodes > 1 :
317- # 兼容多机 dp size=1 的情况
318- rank_id = args .tp // args .nnodes * args .node_rank
319- return self .shared_tp_infos [rank_id ].get_value ()
320- dp_size = args .dp
321- dp_world_size = self .global_world_size // dp_size
322- return self .shared_tp_infos [dp_rank * dp_world_size ].get_value ()
302+ self .node_world_size = args .tp // args .nnodes
303+ self .dp_world_size = self .global_world_size // args .dp
304+ # 兼容多机 dp size=1 纯 tp 模式的情况
305+ self .is_multinode_tp = args .dp == 1 and args .nnodes > 1
306+ self .shared_tp_infos = [
307+ SharedInt (f"{ get_unique_server_name ()} _mem_manger_can_use_token_num_{ rank_in_node } " )
308+ for rank_in_node in range (0 , self .node_world_size , self .dp_world_size )
309+ ]
310+
311+ def get_unrefed_token_num (self , dp_rank_in_node : int ):
312+ if self .is_multinode_tp :
313+ return self .shared_tp_infos [0 ].get_value ()
314+ return self .shared_tp_infos [dp_rank_in_node ].get_value ()
0 commit comments