@@ -250,18 +250,35 @@ def forward(
         if kv_cache is not None and len(kv_cache) == 2:
             self.latent_cache_k(latent_vec_k, kv_cache[0], slot_mapping)
             k_cache = kv_cache[0]
+        else:
+            k_cache = None
 
         if is_prefill:
-            return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata,
-                                         batch_size)
+            return self._forward_prefill(q, latent_vec_k, k_cache,
+                                         attn_metadata, batch_size)
         else:
             return self._forward_decode(decode_ql_nope, q_pe, k_cache,
                                         attn_metadata, batch_size)
 
     def _forward_prefill(  # type: ignore
-            self, q: torch.Tensor, k_c_normed: torch.Tensor,
-            k_pe: torch.Tensor, attn_metadata: HPUAttentionMetadata,
+            self, q: torch.Tensor, latent_vec_k: torch.Tensor,
+            k_cache: torch.Tensor, attn_metadata: HPUAttentionMetadata,
             batch_size: int) -> torch.Tensor:
+        ##### get prefix cache #####
+        if attn_metadata.block_list is not None:
+            current = latent_vec_k
+            past = self.latent_cache_k.fetch_from_cache(
+                k_cache.unflatten(0, (-1, attn_metadata.block_size)),
+                attn_metadata.block_list)
+            past = past.view(-1, past.shape[-1])
+            current = torch.concat((past, current), dim=0)
+            latent_vec_k = current
+        # =========================== #
+
+        k_c_normed, k_pe = latent_vec_k.split(
+            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
+
         kv_nope = self.kv_b_proj(k_c_normed)[0]\
             .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
         k_nope, v = kv_nope\
@@ -290,11 +307,14 @@ def _forward_prefill( # type: ignore
             value=v_padded,
             is_causal=True,
             attn_bias=attn_metadata.attn_bias,
+            position_bias=None,
             valid_seq_lengths=attn_metadata.seq_lens_tensor,
             scale=self.scale,
             matmul_qk_op=self.matmul_qk,
             softmax_op=self.softmax,
             matmul_av_op=self.matmul_av,
+            keys_fetch_func=self.latent_cache_k.fetch_from_cache,
+            values_fetch_func=None,
             fsdpa_op=self.fused_scaled_dot_product_attention.apply \
             if self.fused_scaled_dot_product_attention is not None else None)
         attn_output = out.view(batch_size, -1, self.num_heads, q.shape[-1])
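
For context, a minimal standalone sketch of the prefix-caching path this hunk adds: gather the cached latent blocks for the prefix, prepend them to the new tokens' latent vectors, then split the result into the compressed-KV part and the decoupled RoPE part. The sizes for `kv_lora_rank`, `qk_rope_head_dim`, and `block_size` are made-up illustration values, and the `fetch_from_cache` call is replaced by a plain block gather, so this is an approximation of the data flow, not the vLLM/HPU implementation itself.

```python
import torch

kv_lora_rank, qk_rope_head_dim = 512, 64   # illustrative sizes, not the real config
block_size, num_blocks = 128, 4

# Flat latent cache laid out as (num_blocks * block_size, kv_lora_rank + rope_dim)
k_cache = torch.randn(num_blocks * block_size, kv_lora_rank + qk_rope_head_dim)
block_list = torch.tensor([0, 2])          # cached prefix blocks for this request
latent_vec_k = torch.randn(16, kv_lora_rank + qk_rope_head_dim)  # new tokens

# Stand-in for fetch_from_cache: view the cache as blocks, gather the listed
# blocks, and flatten them back to (prefix_tokens, latent_dim)
past = k_cache.unflatten(0, (-1, block_size))[block_list]
past = past.view(-1, past.shape[-1])
latent_vec_k = torch.concat((past, latent_vec_k), dim=0)

# Split into the compressed-KV part and the decoupled RoPE part, as in the diff
k_c_normed, k_pe = latent_vec_k.split([kv_lora_rank, qk_rope_head_dim], dim=-1)
k_pe = k_pe.view(-1, 1, qk_rope_head_dim)
print(k_c_normed.shape, k_pe.shape)  # (prefix + new tokens, 512) and (tokens, 1, 64)
```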