 from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
 from lightllm.utils.dist_utils import get_current_rank_in_node
 from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
+from lightllm.distributed.pynccl import PyNcclCommunicator
 
 
 logger = init_logger(__name__)
@@ -91,7 +92,11 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         return
 
     def send_to_decode_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         assert dp_size_in_node == 1
 
@@ -108,14 +113,14 @@ def send_to_decode_node(
             for layer_index in range(mem.layer_num):
                 move_buffer = mem._get_kv_move_data(move_token_indexes, layer_index)
                 if i == cur_device_index:
-                    dist.send(move_buffer, dst=1)
+                    nccl_comm.send(move_buffer, dst=1)
                 else:
                     move_size = move_buffer.numel()
                     new_move_buffer = cur_mem.kv_move_buffer.view(-1)[0:move_size].view(move_buffer.shape)
                     from torch.cuda import comm
 
                     comm.broadcast(move_buffer, out=[new_move_buffer])
-                    dist.send(new_move_buffer, dst=1)
+                    nccl_comm.send(new_move_buffer, dst=1)
         return
 
     def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
@@ -127,7 +132,11 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
         return move_buffer
 
     def receive_from_prefill_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         assert dp_size_in_node == 1
 
@@ -144,7 +153,7 @@ def receive_from_prefill_node(
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, 2 * self.head_num, self.head_dim)
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
-                dist.recv(recive_buffer, src=0)
+                nccl_comm.recv(recive_buffer, src=0)
                 if i == cur_device_index:
                     mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index)
                 else:
@@ -160,7 +169,11 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
         return
 
     def send_to_decode_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         """
         Implementation that copies and transfers KV data with a p2p triton kernel.
@@ -178,7 +191,7 @@ def send_to_decode_node_p2p(
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
                 move_buffer = mem._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer)
-                dist.send(move_buffer, dst=1)
+                nccl_comm.send(move_buffer, dst=1)
         return
 
     def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor):
@@ -191,7 +204,11 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k
         return move_buffer
 
     def receive_from_prefill_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         assert dp_size_in_node == 1
 
@@ -209,7 +226,7 @@ def receive_from_prefill_node_p2p(
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, 2 * self.head_num, self.head_dim)
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
-                dist.recv(recive_buffer, src=0)
+                nccl_comm.recv(recive_buffer, src=0)
                 mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index)
         return
 
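The non-p2p send path stages data through the sending GPU before the NCCL call: when layer `i`'s KV cache lives on a different device than the one driving the transfer, `torch.cuda.comm.broadcast(..., out=[...])` copies it into a slice of the local preallocated `kv_move_buffer` first, and only that local copy is handed to `nccl_comm.send`. A minimal sketch of this staging step, assuming two visible GPUs and illustrative shapes (the sizes are not taken from lightllm):

```python
# Sketch of the staging step in send_to_decode_node; shapes are illustrative.
import torch
from torch.cuda import comm

assert torch.cuda.device_count() >= 2, "sketch assumes two visible GPUs"

# KV slice produced on GPU 1 (another device's memory manager).
move_buffer = torch.randn(1, 16, 2 * 8, 64, device="cuda:1")

# Flat, preallocated staging buffer on GPU 0 (the device doing the send).
kv_move_buffer = torch.empty(2 * 1024 * 1024, device="cuda:0")

move_size = move_buffer.numel()
staging = kv_move_buffer.view(-1)[0:move_size].view(move_buffer.shape)

# comm.broadcast with out= copies the source tensor into the given output
# tensors, which may live on other devices; here the copy lands on GPU 0.
comm.broadcast(move_buffer, out=[staging])

# staging now holds a GPU-0 copy, ready for nccl_comm.send(staging, dst=1).
```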
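The diff threads an explicit `PyNcclCommunicator` through every transfer method instead of relying on the default process group behind `dist.send`/`dist.recv`, so KV moves can run on a communicator dedicated to prefill-to-decode traffic. How that communicator is constructed is not shown here; the sketch below assumes a vLLM-style `PyNcclCommunicator(group, device)` constructor, and `kv_trans_group`, `move_tasks`, and `mem_managers` are hypothetical caller-side objects:

```python
# Hypothetical call site for the new signature; the PyNcclCommunicator
# constructor arguments below are assumptions, not shown in this diff.
import torch
import torch.distributed as dist
from lightllm.distributed.pynccl import PyNcclCommunicator

# Assumed: a process group created just for KV transfers, with the prefill
# node as rank 0 and the decode node as rank 1 (matching src=0 / dst=1 above).
kv_trans_group = dist.new_group(ranks=[0, 1], backend="gloo")
nccl_comm = PyNcclCommunicator(kv_trans_group, torch.cuda.current_device())  # assumed signature

mem_manager.send_to_decode_node(
    move_tasks,            # pending KVMoveTask list (caller-provided)
    mem_managers,          # one MemoryManager per local GPU
    dp_size_in_node=1,     # the new code path asserts this
    nccl_comm=nccl_comm,   # explicit communicator replaces the implicit default group
)
```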