 from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
 from lightllm.utils.dist_utils import get_current_rank_in_node
 from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
+from lightllm.distributed.pynccl import PyNcclCommunicator


 logger = init_logger(__name__)
@@ -86,7 +87,8 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         return

     def send_to_decode_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator
     ):
         assert dp_size_in_node == 1

@@ -103,14 +105,14 @@ def send_to_decode_node(
             for layer_index in range(mem.layer_num):
                 move_buffer = mem._get_kv_move_data(move_token_indexes, layer_index)
                 if i == cur_device_index:
-                    dist.send(move_buffer, dst=1)
+                    nccl_comm.send(move_buffer, dst=1)
                 else:
                     move_size = move_buffer.numel()
                     new_move_buffer = cur_mem.kv_move_buffer.view(-1)[0:move_size].view(move_buffer.shape)
                     from torch.cuda import comm

                     comm.broadcast(move_buffer, out=[new_move_buffer])
-                    dist.send(new_move_buffer, dst=1)
+                    nccl_comm.send(new_move_buffer, dst=1)
         return

     def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
@@ -122,7 +124,8 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
         return

     def receive_from_prefill_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         assert dp_size_in_node == 1

@@ -139,7 +142,7 @@ def receive_from_prefill_node(
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, 2 * self.head_num, self.head_dim)
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
-                dist.recv(recive_buffer, src=0)
+                nccl_comm.recv(recive_buffer, src=0)
                 if i == cur_device_index:
                     mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index)
                 else:
@@ -155,7 +158,8 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
         return

     def send_to_decode_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator
     ):
         """
         Implementation that copies and transfers the KV data using the p2p triton kernel.
@@ -173,7 +177,7 @@ def send_to_decode_node_p2p(
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
                 move_buffer = mem._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer)
-                dist.send(move_buffer, dst=1)
+                nccl_comm.send(move_buffer, dst=1)
         return

     def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor):
@@ -186,7 +190,8 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k
         return

     def receive_from_prefill_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         assert dp_size_in_node == 1

@@ -204,7 +209,7 @@ def receive_from_prefill_node_p2p(
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, 2 * self.head_num, self.head_dim)
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
-                dist.recv(recive_buffer, src=0)
+                nccl_comm.recv(recive_buffer, src=0)
                 mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index)
         return

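For context, every caller of these transfer methods now has to pass a PyNcclCommunicator explicitly rather than relying on the default torch.distributed process group. A minimal sketch of the new calling convention follows; the transfer_kv wrapper and its arguments are illustrative only, and how the communicator itself is constructed is outside this diff:

from typing import List

from lightllm.distributed.pynccl import PyNcclCommunicator


def transfer_kv(move_tasks, mem_managers, cur_mem_manager, nccl_comm: PyNcclCommunicator, is_prefill: bool):
    # cur_mem_manager is the MemoryManager bound to the current device;
    # mem_managers holds the managers of all devices in the node.
    if is_prefill:
        # Prefill rank: push each layer's KV slice to the decode rank (dst=1)
        # over the dedicated communicator instead of dist.send().
        cur_mem_manager.send_to_decode_node_p2p(move_tasks, mem_managers, dp_size_in_node=1, nccl_comm=nccl_comm)
    else:
        # Decode rank: receive the slices (src=0) and write them into the local KV buffers.
        cur_mem_manager.receive_from_prefill_node_p2p(move_tasks, mem_managers, dp_size_in_node=1, nccl_comm=nccl_comm)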