@@ -324,12 +324,19 @@ def forward_inference(
     ):
         # destruct cache

-        (cache_k, cache_v), (cache_ck, cache_cv) = cache
+        (
+            (cache_k, cache_v),
+            (
+                (cache_ck, cache_cv),
+                (run_k, run_v)
+            )
+        ) = cache

         # variables

         batch, scale, heads, device = inp.shape[0], self.scale, self.heads, inp.device
-        seq_len = cache_k.shape[-2] + 1
+        cache_len = cache_k.shape[-2]
+        seq_len = cache_len + 1

         sliding_window = self.sliding_window_size
         compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
@@ -347,7 +354,17 @@ def forward_inference(

         q, k, v = map(self.split_heads, (q, k, v))

-        # handle cache
+        # take care of running k and v for compression, which should NOT be rotated https://arxiv.org/abs/2501.18795
+
+        run_k = cat((run_k, k), dim = -2)
+        run_v = cat((run_v, v), dim = -2)
+
+        # rotate after updating the compression running k/v
+
+        q = self.rotary_emb.rotate_queries_or_keys(q, offset = cache_len)
+        k = self.rotary_emb.rotate_queries_or_keys(k, offset = cache_len)
+
+        # handle cache, which stores the rotated keys

         k = cat((cache_k, k), dim = -2)
         v = cat((cache_v, v), dim = -2)
@@ -369,18 +386,24 @@ def forward_inference(

         compressed_attn_out = einsum(cattn, repeated_cv, 'b h i j, b h j d -> b h i d')

-        if divisible_by(seq_len, self.compress_block_size):
-            k_compress_input = self.split_compress_window(k[..., -self.compress_block_size:, :] + self.k_intrablock_positions)
-            v_compress_input = self.split_compress_window(v[..., -self.compress_block_size:, :] + self.v_intrablock_positions)
+        running_compress_seq_len = run_k.shape[-2]
+
+        if divisible_by(running_compress_seq_len, self.compress_block_size):
+
+            k_compress_input = self.split_compress_window(run_k + self.k_intrablock_positions)
+            v_compress_input = self.split_compress_window(run_v + self.v_intrablock_positions)

             next_ck = self.k_compress(k_compress_input)
             next_cv = self.v_compress(v_compress_input)

+            run_k = run_k[..., 0:0, :]
+            run_v = run_v[..., 0:0, :]
+
             ck = cat((ck, next_ck), dim = -2)
             cv = cat((cv, next_cv), dim = -2)

         if return_cache:
-            cache_compressed_kv = (ck, cv)
+            cache_compressed_kv = ((ck, cv), (run_k, run_v))

         # 2. fine attention inference (todo - compress and fine diff block sizes)

@@ -395,10 +418,8 @@ def forward_inference(

         # block causal diagonal

-        rotated_q, rotated_k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
-
         fine_sliding_window = (seq_len % self.selection_block_size) + 1
-        fk = rotated_k[..., -fine_sliding_window:, :]
+        fk = k[..., -fine_sliding_window:, :]
         fv = v[..., -fine_sliding_window:, :]

         # select out the sparse kv segments as defined by compressed attention map as importance score
@@ -412,7 +433,7 @@ def forward_inference(
         fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
         remainder = fine_divisible_seq_len - k.shape[-2]

-        sel_fk = pad_at_dim(rotated_k, (0, remainder), dim = -2)
+        sel_fk = pad_at_dim(k, (0, remainder), dim = -2)
         sel_fv = pad_at_dim(v, (0, remainder), dim = -2)

         sel_fk = rearrange(sel_fk, 'b h (w j) d -> b h w j d', j = self.selection_block_size)
@@ -438,7 +459,7 @@ def forward_inference(

         # remove later

-        fq = rearrange(rotated_q, 'b (h gh) ... -> b h gh ...', gh = self.num_grouped_queries)
+        fq = rearrange(q, 'b (h gh) ... -> b h gh ...', gh = self.num_grouped_queries)

         fsim = einsum(fq, fk, 'b h gh i d, b h j d -> b h gh i j') * scale

@@ -524,12 +545,15 @@ def forward(
         k_compress_input = self.split_compress_window(k[..., :compress_divisible_seq_len, :] + k_pos)
         v_compress_input = self.split_compress_window(v[..., :compress_divisible_seq_len, :] + v_pos)

+        run_k = k[..., compress_divisible_seq_len:, :]
+        run_v = v[..., compress_divisible_seq_len:, :]
+
         cq = q
         ck = self.k_compress(k_compress_input) # Equation (7) of the Native Sparse Attention paper
         cv = self.v_compress(v_compress_input)

         if return_cache:
-            cache_compressed_kv = (ck, cv)
+            cache_compressed_kv = ((ck, cv), (run_k, run_v))

         # 1. coarse attention over compressed

@@ -549,7 +573,6 @@ def forward(
         compressed_attn_out, csim = attend(cq, ck, cv, mask = cmask, return_sim = True)

         # for 2. and 3., will give them relative positions with rotary - compressed needs to be handled separately (even if they already have intra block absolute positions)
-
         rotated_q, rotated_k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)

         # 2. fine attention over selected based on compressed attention logits - variables prepended with `f` stand for the fine attention pathway
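
For readers following the cache change above, here is a minimal sketch of how the new nested cache ((cache_k, cache_v), ((ck, cv), (run_k, run_v))) is meant to evolve per decoded token. This is not the repository's code: decode_step, token_kv, and compress are hypothetical stand-ins, and rotary application is only indicated by a comment. The idea is that un-rotated k/v accumulate in the running buffers, and once a full compression block has been collected it is compressed into ck/cv and the buffers are flushed.

# illustrative sketch only, assuming (batch, heads, seq, dim) tensors
import torch

def decode_step(token_kv, cache, compress_block_size, compress):
    (cache_k, cache_v), ((ck, cv), (run_k, run_v)) = cache

    k, v = token_kv  # un-rotated key / value for the single new token

    # running buffers collect UN-rotated k/v for the compression branch
    run_k = torch.cat((run_k, k), dim = -2)
    run_v = torch.cat((run_v, v), dim = -2)

    # rotary embeddings would be applied to q/k after this point, and the
    # rotated k (plus v) appended to cache_k / cache_v for the other branches

    # once a full compression block has accumulated, compress it and flush the buffer
    if run_k.shape[-2] == compress_block_size:
        ck = torch.cat((ck, compress(run_k)), dim = -2)
        cv = torch.cat((cv, compress(run_v)), dim = -2)
        run_k = run_k[..., 0:0, :]
        run_v = run_v[..., 0:0, :]

    return (cache_k, cache_v), ((ck, cv), (run_k, run_v))

The forward (prefill) hunks mirror this: whatever tail of k/v does not fill a complete compression block is handed to the cache as the initial run_k / run_v, so decoding can continue accumulating from there.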