 import torch
 import torch.distributed as dist
 from typing import List
+from lightllm.server.pd_io_struct import KVMoveTask
 from lightllm.utils.log_utils import init_logger
 from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
 from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
@@ -79,25 +80,27 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         )
         return
 
-    def send_to_decode_node(
-        self, token_indexes: List[int], mem_managers: List["MemoryManager"], dp_size: int, dp_index: int
-    ):
+    def send_to_decode_node(self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int):
         """
-        dp_size and dp_index are parameters customized for deepseekv2-style models, which can run in a mixed dp and tp mode;
+        dp_size is a parameter customized for deepseekv2-style models, which can run in a mixed dp and tp mode;
         in plain tp mode, dp_size must equal 1 and dp_index must equal 0, and in plain mode these parameters are not
         actually used.
         """
         assert dp_size == 1
-        assert dp_index == 0
 
         # Stage the data in the buffer of one designated GPU first, then send it.
         import torch.distributed as dist
 
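+        # Flatten the tasks into one index list: each task contributes only the
+        # KV slots of its last move_kv_len prefill tokens; tasks with nothing to
+        # move are skipped.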
+        move_token_indexes = []
+        for task in move_tasks:
+            if task.move_kv_len != 0:
+                move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :])
+
         cur_device_index = self.kv_buffer.get_device()
         cur_mem = mem_managers[cur_device_index]
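         # Walk every peer memory manager layer by layer, staging each KV slice in
         # the move buffer before sending it to the decode side (dist rank 1).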
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
-                move_buffer = mem._get_kv_move_data(token_indexes, layer_index)
+                move_buffer = mem._get_kv_move_data(move_token_indexes, layer_index)
                 if i == cur_device_index:
                     dist.send(move_buffer, dst=1)
                 else:
@@ -118,34 +121,38 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
         return move_buffer
 
     def receive_from_prefill_node(
-        self, token_indexes: List[int], mem_managers: List["MemoryManager"], dp_size: int, dp_index: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int
     ):
         """
-        dp_size and dp_index are parameters customized for deepseekv2-style models, which can run in a mixed dp and tp mode;
-        in plain tp mode, dp_size must equal 1 and dp_index must equal 0, and in plain mode these parameters are not
+        dp_size is a parameter customized for deepseekv2-style models, which can run in a mixed dp and tp mode;
+        in plain tp mode, dp_size must equal 1, and in plain mode the parameter is not
         actually used.
         """
         assert dp_size == 1
-        assert dp_index == 0
 
         # Receive the data into the buffer of one designated GPU first, then copy it to the other GPUs.
         import torch.distributed as dist
 
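+        # Mirror the sender: rebuild the same flat index list, but in decode-side
+        # coordinates, so each received slice lands in the matching KV slots.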
+        move_token_indexes = []
+        for task in move_tasks:
+            if task.move_kv_len != 0:
+                move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :])
+
         cur_device_index = self.kv_buffer.get_device()
-        token_num = len(token_indexes)
+        token_num = len(move_token_indexes)
         move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, 2 * self.head_num, self.head_dim)
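         # Receive each layer into this GPU's staging buffer, then fan it out to the
         # other GPUs' staging buffers before scattering into their KV caches.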
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
                 dist.recv(recive_buffer, src=0)
                 if i == cur_device_index:
-                    mem._write_kv_move_data(token_indexes, recive_buffer, layer_index)
+                    mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index)
                 else:
                     new_recive_buffer = mem.kv_move_buffer.view(-1)[0:move_size].view(recive_buffer.shape)
                     from torch.cuda import comm
 
                     comm.broadcast(recive_buffer, out=[new_recive_buffer])
-                    mem._write_kv_move_data(token_indexes, new_recive_buffer, layer_index)
+                    mem._write_kv_move_data(move_token_indexes, new_recive_buffer, layer_index)
         return
 
     def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index):
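
For orientation, here is a minimal, hypothetical sketch of how the two sides of this change line up. It assumes only what the diff shows: a KVMoveTask carries prefill_token_indexes, decode_token_indexes, and move_kv_len, the prefill process sends to dist rank 1, and the decode process receives from rank 0. FakeKVMoveTask is a stand-in for illustration, not the real lightllm class.

from dataclasses import dataclass
from typing import List


@dataclass
class FakeKVMoveTask:
    # Stand-in for lightllm.server.pd_io_struct.KVMoveTask; only the
    # three fields this diff touches are modeled.
    prefill_token_indexes: List[int]
    decode_token_indexes: List[int]
    move_kv_len: int


tasks = [
    # Only the last 2 tokens' KV must move; the rest is already cached.
    FakeKVMoveTask([10, 11, 12, 13], [40, 41, 42, 43], move_kv_len=2),
    # Nothing to move for this request: it is skipped entirely.
    FakeKVMoveTask([20, 21], [50, 51], move_kv_len=0),
]

# Both sides flatten the tasks the same way, so the tensors exchanged by
# dist.send / dist.recv line up element for element across the two nodes.
prefill_indexes, decode_indexes = [], []
for task in tasks:
    if task.move_kv_len != 0:
        prefill_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :])
        decode_indexes.extend(task.decode_token_indexes[-task.move_kv_len :])

assert prefill_indexes == [12, 13]  # slots read on the prefill node
assert decode_indexes == [42, 43]  # slots written on the decode node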