 # This file provides helper functions for sharing kv cache transfer state across dp ranks in dp-parallel inference mode.
+import time
 import numpy as np
 import dataclasses
 import torch
 from lightllm.server.core.objs.shm_array import ShmArray
 from ...infer_batch import InferReq
 from lightllm.utils.dist_utils import get_current_device_id
+from lightllm.server.router.model_infer.infer_batch import g_infer_context
+import torch.distributed as dist
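+# dist is used for the node-group barriers below; g_infer_context is hoisted to
+# module level, replacing the function-local import in build_shared_kv_trans_tasks.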


 class DPKVSharedMoudle:
@@ -36,7 +39,7 @@ def fill_reqs_info(self, reqs: List[InferReq]):
3639 """
3740 填充请求的 kv 信息到共享内存中
3841 """
39- self .backend .node_nccl_group . barrier ( )
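+        # Barrier through torch.distributed on the node process group instead of
+        # calling barrier() on the group object directly.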
+        dist.barrier(group=self.backend.node_nccl_group)
         if self.backend.is_master_in_dp:
             self.shared_req_infos.arr[0 : len(reqs), self.dp_rank_in_node, self._KV_LEN_INDEX] = [
                 req.cur_kv_len for req in reqs
@@ -54,9 +57,7 @@ def build_shared_kv_trans_tasks(
5457 """
5558 构建共享kv交换信息
5659 """
57- from lightllm .server .router .model_infer .infer_batch import g_infer_context
58-
59- self .backend .node_nccl_group .barrier ()
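+        # Same node-group barrier as in fill_reqs_info.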
+        dist.barrier(group=self.backend.node_nccl_group)

         trans_tasks: List[TransTask] = []
         rank_max_radix_cache_lens = np.max(
@@ -96,22 +97,20 @@ def build_shared_kv_trans_tasks(
     def kv_trans(self, trans_tasks: List["TransTask"]):
         from lightllm.server.router.model_infer.infer_batch import g_infer_context

-        self.backend.node_nccl_group.barrier()
         # kv transfer
         if len(trans_tasks) > 0:
             max_kv_len_mem_indexes = []
             max_kv_len_dp_ranks = []
             mem_indexes = []

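+            # Collect each task's index tensors; they are concatenated in one
+            # shot below instead of being flattened element by element.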
             for i, trans_task in enumerate(trans_tasks):
-                max_kv_len_mem_indexes.extend(trans_task.max_kv_len_mem_indexes)
+                max_kv_len_mem_indexes.append(trans_task.max_kv_len_mem_indexes)
                 max_kv_len_dp_ranks.extend([trans_task.max_kv_len_dp_rank] * len(trans_task.max_kv_len_mem_indexes))
-                mem_indexes.extend(trans_task.mem_indexes)
+                mem_indexes.append(trans_task.mem_indexes)
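+            # Concatenate the collected tensors once, then a single .to() moves
+            # the merged indexes to CUDA, avoiding torch.tensor() over long lists.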
-            max_kv_len_mem_indexes_tensor = torch.tensor(max_kv_len_mem_indexes, dtype=torch.int64, device="cuda")
+            max_kv_len_mem_indexes_tensor = torch.cat(max_kv_len_mem_indexes).to(dtype=torch.int64, device="cuda")
             max_kv_len_dp_ranks_tensor = torch.tensor(max_kv_len_dp_ranks, dtype=torch.int32, device="cuda")
-            mem_indexes_tensor = torch.tensor(mem_indexes, dtype=torch.int64, device="cuda")
-
+            mem_indexes_tensor = torch.cat(mem_indexes).to(dtype=torch.int64, device="cuda")
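+            # Copy the kv entries for these tokens from the source dp ranks'
+            # memory managers into this rank's kv cache.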
             self.backend.model.mem_manager.copy_kv_from_other_dp_ranks(
                 mem_managers=self.backend.mem_managers,
                 move_token_indexes=max_kv_len_mem_indexes_tensor,
@@ -122,7 +121,6 @@ def kv_trans(self, trans_tasks: List["TransTask"]):
             )
             self.backend.logger.info(f"dp_i {self.dp_rank_in_node} transfer kv tokens num: {len(mem_indexes_tensor)}")

-        self.backend.node_nccl_group.barrier()
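+        # Record the transferred token slots and advance each request's kv length.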
         for trans_task in trans_tasks:
             g_infer_context.req_manager.req_to_token_indexs[
                 trans_task.req.req_idx,
@@ -131,7 +129,6 @@ def kv_trans(self, trans_tasks: List["TransTask"]):
             trans_task.req.cur_kv_len += len(trans_task.mem_indexes)
             if self.backend.is_master_in_dp:
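+                # Only the dp-master process publishes the new kv length to the shm request.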
                 trans_task.req.shm_req.shm_cur_kv_len = trans_task.req.cur_kv_len
-        self.backend.node_nccl_group.barrier()


 @dataclasses.dataclass
@@ -141,24 +138,3 @@ class TransTask:
     max_kv_len_dp_rank: int
     max_kv_len_mem_manager_index: int
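+    # token mem indexes held by the max-kv-len dp rank (the copy source)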
     max_kv_len_mem_indexes: torch.Tensor
-
-
-def init_dp_kv_shared(backend):
-    torch.cuda.set_device(get_current_device_id())
-
-    backend.dp_kv_shared_moudle = DPKVSharedMoudle(
-        max_req_num=backend.args.running_max_req_size,
-        max_req_seq_len=backend.args.max_req_total_len + 8,
-        dp_size_in_node=backend.dp_size_in_node,
-        backend=backend,
-    )
-    backend.node_nccl_group.barrier()
-
-    # Collect mem_managers from all ranks
-    backend.mem_managers = []
-    for rank_idx in range(backend.node_world_size):
-        if rank_idx != backend.rank_in_node:
-            backend.mem_managers.append(MemoryManager.loads_from_shm(rank_idx, backend.rank_in_node))
-        else:
-            backend.mem_managers.append(backend.model.mem_manager)
-    return