@@ -475,15 +475,18 @@ def initialize_kv_hash_cache_tensors(self, kv_caches, device):
             kv_caches[layer_name] = (kv_cache, khash_cache)
 
     def initialize_kv_hash_cache_tensors_npu(self, kv_caches, device):
-        print(f"initialize_kv_hash_cache_tensors_npu: allocating hashk cache for KVComp in NPU")
+        print(f"[NPU KVComp Debug] initialize_kv_hash_cache_tensors_npu: allocating hashk cache for KVComp in NPU")
         for layer_name, kv_cache in kv_caches.items():
             is_rollback_layer, is_skip_hash_layer = self.get_layer_state(layer_name)
             k_cache_shape = kv_cache[0].shape
+            print(f"[NPU KVComp Debug] layer_name: {layer_name}, is_rollback_layer={is_rollback_layer}, is_skip_hash_layer={is_skip_hash_layer}, k_cache_shape: {k_cache_shape}")
             khash_cache_shape = (k_cache_shape[0], k_cache_shape[2], k_cache_shape[1], self.hash_encoder.hash_bits // 8)
             if not is_rollback_layer and not is_skip_hash_layer:
                 khash_cache = torch.empty(khash_cache_shape, dtype=torch.uint8, device=device)
+                print(f"[NPU KVComp Debug] layer_name: {layer_name}, khash_cache_shape: {khash_cache_shape}")
             else:
                 khash_cache = None
+                print(f"[NPU KVComp Debug] layer_name: {layer_name}, khash_cache is None")
             kv_caches[layer_name] = (kv_cache, khash_cache)
 
     def build_decode_hash(self, seq_lens):
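
For context, a minimal standalone sketch of the shape derivation logged above: the hash cache swaps dims 1 and 2 of the key cache and stores one packed hash of hash_bits // 8 bytes per entry. The concrete k_cache_shape and hash_bits values below are illustrative assumptions, not values taken from this change.

# Sketch only (not part of the diff above). hash_bits and k_cache_shape are
# assumed; the PR reads hash_bits from self.hash_encoder.hash_bits and the
# shape from the allocated key cache.
import torch

hash_bits = 128                       # assumed: hash packed 8 bits per uint8 byte
k_cache_shape = (64, 16, 8, 128)      # assumed example key-cache shape

# Dims 1 and 2 are swapped; last dim holds the packed hash bytes.
khash_cache_shape = (k_cache_shape[0], k_cache_shape[2], k_cache_shape[1], hash_bits // 8)
khash_cache = torch.empty(khash_cache_shape, dtype=torch.uint8, device="cpu")
print(khash_cache.shape)              # torch.Size([64, 8, 16, 16])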