@@ -1064,22 +1064,19 @@ def _attention_with_mask_hpu(
         # Skip writing kv-cache for the initial profiling run.
         if kv_cache is not None and isinstance(kv_cache, tuple):
             assert self.attn.backend == _Backend.HPU_ATTN
-            # During cross-attention decode, key & value will be None,
-            # we don't need to cache them.
-            if (k is not None) and (v is not None):
-                from vllm.attention.ops.hpu_paged_attn import HPUPagedAttention
-                key_cache, value_cache = HPUPagedAttention.split_kv_cache(
-                    kv_cache, self.num_local_key_value_heads, self.head_dim)
-                cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
-                cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
-                slot_mapping = torch.cat([
-                    attn_metadata.cross_slot_mapping[s:e]
-                    for s, e in kv_range_for_decode
-                ])
-                key_cache = self.attn.impl.k_cache(cached_k, key_cache,
-                                                   slot_mapping)
-                value_cache = self.attn.impl.v_cache(cached_v, value_cache,
-                                                     slot_mapping)
+            from vllm.attention.ops.hpu_paged_attn import HPUPagedAttention
+            key_cache, value_cache = HPUPagedAttention.split_kv_cache(
+                kv_cache, self.num_local_key_value_heads, self.head_dim)
+            cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
+            cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
+            slot_mapping = torch.cat([
+                attn_metadata.cross_slot_mapping[s:e]
+                for s, e in kv_range_for_decode
+            ])
+            key_cache = self.attn.impl.k_cache(cached_k, key_cache,
+                                               slot_mapping)
+            value_cache = self.attn.impl.v_cache(cached_v, value_cache,
+                                                 slot_mapping)

         q_len = q.shape[0]
         kv_len = k.shape[0]
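For context, a minimal sketch (not part of this commit) of how the `kv_range_for_decode` ranges gather per-sequence cross-attention K/V and their slot mappings before the paged-cache write. Shapes are toy values, and `cross_slot_mapping` here is a stand-in for `attn_metadata.cross_slot_mapping`:

```python
# Illustrative only: gather K/V rows and slot indices per decode range,
# mirroring the torch.cat calls in the diff above.
import torch

num_kv_heads, head_dim = 2, 4
# Two decode sequences whose encoder K/V occupy rows [0:3) and [3:7).
kv_range_for_decode = [(0, 3), (3, 7)]
k = torch.randn(7, num_kv_heads, head_dim)
v = torch.randn(7, num_kv_heads, head_dim)
cross_slot_mapping = torch.arange(7)  # stand-in for the real slot mapping

cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
slot_mapping = torch.cat(
    [cross_slot_mapping[s:e] for s, e in kv_range_for_decode])

# Each gathered K/V row lines up with exactly one cache slot.
assert cached_k.shape[0] == cached_v.shape[0] == slot_mapping.shape[0] == 7
```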