@@ -89,6 +89,7 @@ def alloc_kv_move_buffer(self, max_req_total_len):
             (1, max_req_total_len + 8, 2 * self.head_num, self.head_dim), dtype=self.dtype, device="cuda"
         )
         self.kv_move_buf_indexes = torch.arange(0, max_req_total_len + 8, dtype=torch.int64, device="cuda")
+        self.token_dim_size = self.kv_move_buffer.shape[-2] * self.kv_move_buffer.shape[-1]
         return
 
     def send_to_decode_node(
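
The new `token_dim_size` caches the number of elements one token occupies in the move buffer (2 * head_num * head_dim), replacing the repeated `numel() // layer_num // size` arithmetic in the hunks below. A minimal sketch of the equivalence, assuming the KV buffer is laid out as (layer_num, size, 2 * head_num, head_dim) as that arithmetic implies; the concrete sizes are illustrative, not the real configuration:

import torch

# Illustrative sizes only (assumptions, not the actual model config).
layer_num, size, head_num, head_dim = 2, 16, 8, 64
kv_buffer = torch.empty(layer_num, size, 2 * head_num, head_dim)
kv_move_buffer = torch.empty(1, size + 8, 2 * head_num, head_dim)

# Old expression: per-token element count derived from the full KV buffer.
old_per_token = kv_buffer.numel() // layer_num // size
# New cached value: read directly off the move buffer's last two dims.
token_dim_size = kv_move_buffer.shape[-2] * kv_move_buffer.shape[-1]
assert old_per_token == token_dim_size == 2 * head_num * head_dim

Both forms count 2 * head_num * head_dim elements per token, so every `move_size` computed below is unchanged; the cached value simply avoids redoing the division on every transfer.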
@@ -124,7 +125,7 @@ def send_to_decode_node(
         return
 
     def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
-        move_size = self.kv_buffer.numel() // self.layer_num // self.size * len(token_indexes)
+        move_size = self.token_dim_size * len(token_indexes)
         move_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(
             1, len(token_indexes), 2 * self.head_num, self.head_dim
         )
@@ -149,7 +150,7 @@ def receive_from_prefill_node(
 
         cur_device_index = self.kv_buffer.get_device()
         token_num = len(move_token_indexes)
-        move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
+        move_size = self.token_dim_size * token_num
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, 2 * self.head_num, self.head_dim)
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
@@ -196,7 +197,7 @@ def send_to_decode_node_p2p(
 
     def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor):
         move_token_num = len(token_indexes)
-        move_size = self.kv_buffer.numel() // self.layer_num // self.size * move_token_num
+        move_size = self.token_dim_size * move_token_num
         move_buffer = kv_move_buffer.view(-1)[0:move_size].view(move_token_num, 2 * self.head_num, self.head_dim)
         kv_trans(
             self.kv_buffer[layer_index, :, :, :], token_indexes, move_buffer, self.kv_move_buf_indexes[0:move_token_num]
@@ -222,7 +223,7 @@ def receive_from_prefill_node_p2p(
         move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
 
         token_num = len(move_token_indexes)
-        move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
+        move_size = self.token_dim_size * token_num
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, 2 * self.head_num, self.head_dim)
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):