@@ -105,9 +105,7 @@ def send_to_decode_node(
105105 nccl_comm : PyNcclCommunicator ,
106106 ):
107107 if dp_size_in_node > 1 :
108- return self .send_to_decode_node_p2p (
109- move_tasks , mem_managers , dp_size_in_node , nccl_comm
110- )
108+ return self .send_to_decode_node_p2p (move_tasks , mem_managers , dp_size_in_node , nccl_comm )
111109
112110 # 先将数据发送到指定的一张卡上的buffer,再发送。
113111
@@ -148,9 +146,7 @@ def receive_from_prefill_node(
148146 nccl_comm : PyNcclCommunicator ,
149147 ):
150148 if dp_size_in_node > 1 :
151- return self .receive_from_prefill_node_p2p (
152- move_tasks , mem_managers , dp_size_in_node , nccl_comm
153- )
149+ return self .receive_from_prefill_node_p2p (move_tasks , mem_managers , dp_size_in_node , nccl_comm )
154150 # 先将数据接受到指定的一张卡上的buffer,再复制到其他的卡上。
155151
156152 move_token_indexes = []
@@ -222,7 +218,7 @@ def send_to_decode_node_p2p(
222218 self .kv_move_buffer ,
223219 token_dp_indexes = token_dp_tensor ,
224220 dp_size_in_node = dp_size_in_node ,
225- mem_ptrs_dict = mem_ptrs_dict
221+ mem_ptrs_dict = mem_ptrs_dict ,
226222 )
227223 nccl_comm .send (move_buffer , dst = 1 )
228224 return
@@ -234,7 +230,7 @@ def _get_kv_move_data_p2p(
234230 kv_move_buffer : torch .Tensor ,
235231 token_dp_indexes : Optional [torch .Tensor ] = None ,
236232 dp_size_in_node : int = 1 ,
237- mem_ptrs_dict : Optional [dict ] = None
233+ mem_ptrs_dict : Optional [dict ] = None ,
238234 ):
239235 move_token_num = len (token_indexes )
240236 move_size = self .token_dim_size * move_token_num
@@ -301,7 +297,7 @@ def receive_from_prefill_node_p2p(
301297 layer_index ,
302298 token_dp_indexes = token_dp_tensor ,
303299 dp_size_in_node = dp_size_in_node ,
304- mem_ptrs_dict = mem_ptrs_dict
300+ mem_ptrs_dict = mem_ptrs_dict ,
305301 )
306302 return
307303
@@ -312,7 +308,7 @@ def _write_kv_move_data_p2p(
312308 layer_index : int ,
313309 token_dp_indexes : Optional [torch .Tensor ] = None ,
314310 dp_size_in_node : int = 1 ,
315- mem_ptrs_dict : Optional [dict ] = None
311+ mem_ptrs_dict : Optional [dict ] = None ,
316312 ):
317313 move_token_num = len (token_indexes )
318314 if dp_size_in_node == 1 or token_dp_indexes is None :
0 commit comments