@@ -290,6 +290,9 @@ def __init__(
290290 self .split_compress_window = split_compress_window_fn
291291 self .compress_window_size = compress_window_size
292292
293+ assert compress_block_overlap_len < compress_block_size
294+ self .compress_block_overlap_len = compress_block_overlap_len
295+
293296 # compression attention related parameters
294297
295298 self .num_mem_compress_kv = num_compressed_mem_kv
@@ -382,6 +385,7 @@ def forward_inference(
382385
383386 sliding_window = self .sliding_window_size
384387 compress_divisible_seq_len = round_down_mult (seq_len , self .compress_block_size )
388+ compress_overlap_len = self .compress_block_overlap_len
385389
386390 fine_divisible_seq_len = round_up_mult (seq_len , self .selection_block_size )
387391 num_fine_blocks = fine_divisible_seq_len // self .selection_block_size
@@ -439,19 +443,20 @@ def forward_inference(
439443
440444 running_compress_seq_len = run_k .shape [- 2 ]
441445
442- if divisible_by (running_compress_seq_len , self .compress_block_size ):
443-
444- k_compress_input = self .split_compress_window (run_k )
445- v_compress_input = self .split_compress_window (run_v )
446+ if divisible_by (running_compress_seq_len , self .compress_block_size + compress_overlap_len ):
447+ k_compress_input = rearrange (run_k , 'b h n d -> b h 1 n d' )
448+ v_compress_input = rearrange (run_v , 'b h n d -> b h 1 n d' )
446449
447450 k_compress_input = einx .add ('b h w n d, h n d' , k_compress_input , self .k_intrablock_positions )
448451 v_compress_input = einx .add ('b h w n d, h n d' , v_compress_input , self .v_intrablock_positions )
449452
450453 next_ck = self .k_compress (k_compress_input )
451454 next_cv = self .v_compress (v_compress_input )
452455
453- run_k = run_k [..., 0 :0 , :]
454- run_v = run_v [..., 0 :0 , :]
456+ run_kv_slice = slice (- compress_overlap_len , None ) if compress_overlap_len > 0 else slice (0 , 0 )
457+
458+ run_k = run_k [..., run_kv_slice , :]
459+ run_v = run_v [..., run_kv_slice , :]
455460
456461 ck = cat ((ck , next_ck ), dim = - 2 )
457462 cv = cat ((cv , next_cv ), dim = - 2 )
@@ -593,6 +598,8 @@ def forward(
593598 compress_divisible_seq_len = round_down_mult (seq_len , self .compress_block_size )
594599 num_compress_blocks = compress_divisible_seq_len // self .compress_block_size
595600
601+ compress_overlap_len = self .compress_block_overlap_len
602+
596603 fine_divisible_seq_len = round_up_mult (seq_len , self .selection_block_size )
597604 num_fine_blocks = fine_divisible_seq_len // self .selection_block_size
598605
@@ -622,8 +629,14 @@ def forward(
622629 k_compress_input = einx .add ('b h w n d, h n d' , k_compress_input , self .k_intrablock_positions )
623630 v_compress_input = einx .add ('b h w n d, h n d' , v_compress_input , self .v_intrablock_positions )
624631
625- run_k = k [..., compress_divisible_seq_len :, :]
626- run_v = v [..., compress_divisible_seq_len :, :]
632+ run_k , run_v = k , v
633+
634+ if return_cache and compress_overlap_len > 0 :
635+ run_k = F .pad (run_k , (0 , 0 , compress_overlap_len , 0 ), value = 0. )
636+ run_v = F .pad (run_v , (0 , 0 , compress_overlap_len , 0 ), value = 0. )
637+
638+ run_k = run_k [..., compress_divisible_seq_len :, :]
639+ run_v = run_v [..., compress_divisible_seq_len :, :]
627640
628641 cq = q
629642 ck = self .k_compress (k_compress_input ) # Equation (7) of the Native Sparse Attention paper
0 commit comments