@@ -241,20 +241,22 @@ def send_kv_caches_and_hidden_states_cpu(
241241 hidden_states_list : List [torch .Tensor ],
242242 ) -> None :
243243 start_time = time .time ()
244- if self .rank != 0 :
244+ if self .rank % self . tp_size != 0 :
245245 # only the first rank will send kv cache
246246 return
247247 assert len (input_tokens_list ) == len (kv_caches_send_list )
248248 assert len (input_tokens_list ) == len (hidden_states_list )
249249 for idx , input_tokens in enumerate (input_tokens_list ):
250250 store_key_prefix = self .tensor_hash (input_tokens )
251- store_kvcache_key = f"{ store_key_prefix } _{ self .rank } "
252- store_hidden_key = f"{ store_key_prefix } _hidden_{ self .rank } "
251+ store_kvcache_key = f"{ store_key_prefix } _{ self .rank % self . tp_size } "
252+ store_hidden_key = f"{ store_key_prefix } _hidden_{ self .rank % self . tp_size } "
253253
254254 self .kv_store .put_tensor (store_kvcache_key ,
255255 kv_caches_send_list [idx ])
256256 self .kv_store .put_tensor (store_hidden_key , hidden_states_list [idx ])
257- logger .info ("[rank %d]: KV send DONE. send %d, takes %f s" , self .rank ,
257+ logger .info ("[rank %d][tp size %d]:KV send DONE. send %d, takes %f s" ,
258+ self .rank ,
259+ self .tp_size ,
258260 len (input_tokens_list ),
259261 time .time () - start_time )
260262
@@ -266,7 +268,7 @@ def send_kv_caches_and_hidden_states_hpu(
266268 hidden_or_intermediate_states : Union [torch .Tensor ,
267269 IntermediateTensors ],
268270 ) -> None :
269- if self .rank != 0 :
271+ if self .rank % self . tp_size != 0 :
270272 # only the first rank will send kv cache
271273 return
272274 start_time = time .time ()
@@ -312,12 +314,12 @@ def send_kv_caches_and_hidden_states_hpu(
312314 keys = torch .cat (keys , dim = 0 )
313315 kvcache_to_sent = keys .cpu ()
314316 logger .debug ("kv cache reshape time: %s" , time .time () - start_time )
315- store_kvcache_key = f"{ store_key_prefix } _{ self .rank } "
317+ store_kvcache_key = f"{ store_key_prefix } _{ self .rank % self . tp_size } "
316318 self .kv_store .put_tensor (store_kvcache_key , kvcache_to_sent )
317319
318320 logger .debug ("put kv cache key: %s" , store_kvcache_key )
319321
320- hidden_key = f"{ store_key_prefix } _hidden_{ self .rank } "
322+ hidden_key = f"{ store_key_prefix } _hidden_{ self .rank % self . tp_size } "
321323 self .kv_store .put_tensor (
322324 hidden_key ,
323325 hidden_or_intermediate_states [idx ].unsqueeze (0 ).cpu ())
0 commit comments