@@ -395,8 +395,10 @@ def forward_inference(
 
         # block causal diagonal
 
+        rotated_q, rotated_k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
         fine_sliding_window = (seq_len % self.selection_block_size) + 1
-        fk = k[..., -fine_sliding_window:, :]
+        fk = rotated_k[..., -fine_sliding_window:, :]
         fv = v[..., -fine_sliding_window:, :]
 
         # select out the sparse kv segments as defined by compressed attention map as importance score
@@ -410,7 +412,7 @@ def forward_inference(
         fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
         remainder = fine_divisible_seq_len - k.shape[-2]
 
-        sel_fk = pad_at_dim(k, (0, remainder), dim = -2)
+        sel_fk = pad_at_dim(rotated_k, (0, remainder), dim = -2)
         sel_fv = pad_at_dim(v, (0, remainder), dim = -2)
 
         sel_fk = rearrange(sel_fk, 'b h (w j) d -> b h w j d', j = self.selection_block_size)
@@ -430,7 +432,7 @@ def forward_inference(
 
         # remove later
 
-        fq = rearrange(q, 'b (h gh) ... -> b h gh ...', gh = self.num_grouped_queries)
+        fq = rearrange(rotated_q, 'b (h gh) ... -> b h gh ...', gh = self.num_grouped_queries)
 
         fsim = einsum(fq, fk, 'b h gh i d, b h j d -> b h gh i j') * scale
 
@@ -457,7 +459,7 @@ def forward_inference(
 
         strategy_weighted_combine = self.to_strategy_combine(inp)
 
-        out = einsum(strategy_weighted_combine, stack([compressed_attn_out, compressed_attn_out, sliding_window_attn_out]), 'b h n s, s b h n d -> b h n d')
+        out = einsum(strategy_weighted_combine, stack([compressed_attn_out, fine_attn_out, sliding_window_attn_out]), 'b h n s, s b h n d -> b h n d')
 
         # merge heads and combine them
 
0 commit comments