@@ -58,7 +58,7 @@ def sliding_mask(_, __, q_idx, kv_idx):
5858 block_mask = create_block_mask (sliding_mask , B = None , H = None , Q_LEN = seq_len , KV_LEN = seq_len , _compile = True )
5959 return block_mask
6060
61- def create_compress_mask (seq_len , kv_seq_len , compress_block_size , mem_kv_len = 0 , causal = True ):
61+ def create_compress_mask (seq_len , kv_seq_len , compress_block_sliding_stride , mem_kv_len = 0 , causal = True ):
6262
6363 if not causal :
6464 return None
@@ -70,7 +70,7 @@ def compress_mask(_, __, q_idx, kv_idx):
7070 is_mem_kv = kv_idx < mem_kv_len
7171
7272 kv_without_mem = kv_idx - mem_kv_len
73- compress_kv_idx = (kv_without_mem * compress_block_size ) + (compress_block_size - 1 )
73+ compress_kv_idx = (kv_without_mem * compress_block_sliding_stride ) + (compress_block_sliding_stride - 1 )
7474
7575 causal_mask = q_idx > compress_kv_idx
7676 return causal_mask | is_mem_kv
@@ -193,9 +193,9 @@ def __init__(
193193 heads ,
194194 sliding_window_size ,
195195 compress_block_size ,
196+ compress_block_sliding_stride ,
196197 selection_block_size ,
197198 num_selected_blocks ,
198- compress_block_overlap_len = 0 , # the amount of overlap of a given compression block to the previous block
199199 kv_heads = None ,
200200 num_compressed_mem_kv = 1 ,
201201 causal = False ,
@@ -261,40 +261,28 @@ def __init__(
261261 # compress strategy
262262
263263 self .compress_block_size = compress_block_size
264+ self .compress_block_sliding_stride = compress_block_sliding_stride
265+ assert self .compress_block_size >= self .compress_block_sliding_stride , 'compress_block_size must be >= compress_block_sliding_stride'
266+ assert self .compress_block_sliding_stride > 0 , 'compress_block_sliding_stride must be greater than 0'
267+ assert divisible_by (selection_block_size , self .compress_block_sliding_stride ), f'selection_block_size { selection_block_size } must be divisible by compress_block_sliding_stride { self .compress_block_sliding_stride } '
268+
269+ # Compression window splitting
270+ self .split_compress_window = nn .Sequential (
271+ Rearrange ('b h n d -> (b h) d 1 n' ),
272+ nn .ZeroPad2d (((compress_block_size - compress_block_sliding_stride ), 0 , 0 , 0 )),
273+ nn .Unfold (kernel_size = (1 , self .compress_block_size ), stride = (1 , self .compress_block_sliding_stride )),
274+ Rearrange ('(b h) (d n) w -> b h w n d' , d = dim_head , h = kv_heads , n = self .compress_block_size )
275+ )
264276
265277 assert num_compressed_mem_kv > 0
266-
267- # the function for splitting out the compression windows for the mlp
268-
269- compress_block_has_overlap = compress_block_overlap_len > 0
270- compress_window_size = compress_block_size + compress_block_overlap_len
271-
272- if not compress_block_has_overlap :
273- split_compress_window_fn = Rearrange ('b h (w n) d -> b h w n d' , n = compress_block_size )
274- else :
275- split_compress_window_fn = nn .Sequential (
276- Rearrange ('b h n d -> (b h) d 1 n' ),
277- nn .ZeroPad2d ((compress_block_overlap_len , 0 , 0 , 0 )),
278- nn .Unfold (kernel_size = (1 , compress_window_size ), stride = (1 , compress_block_size )),
279- Rearrange ('(b h) (d n) w -> b h w n d' , d = dim_head , h = kv_heads )
280- )
281-
282- self .split_compress_window = split_compress_window_fn
283- self .compress_window_size = compress_window_size
284-
285- assert compress_block_overlap_len <= compress_block_size
286- self .compress_block_overlap_len = compress_block_overlap_len
287-
288- # compression attention related parameters
289-
290278 self .num_mem_compress_kv = num_compressed_mem_kv
291279 self .compress_mem_kv = nn .Parameter (torch .zeros (2 , kv_heads , num_compressed_mem_kv , dim_head ))
292280
293- self .k_intrablock_positions = nn .Parameter (torch .zeros (kv_heads , compress_window_size , dim_head ))
294- self .v_intrablock_positions = nn .Parameter (torch .zeros (kv_heads , compress_window_size , dim_head ))
281+ self .k_intrablock_positions = nn .Parameter (torch .zeros (kv_heads , self . compress_block_size , dim_head ))
282+ self .v_intrablock_positions = nn .Parameter (torch .zeros (kv_heads , self . compress_block_size , dim_head ))
295283
296284 if not exists (compress_mlp ):
297- compress_dim = compress_window_size * dim_head
285+ compress_dim = self . compress_block_size * dim_head
298286 compress_mlp_dim_hidden = int (compress_mlp_expand_factor * compress_dim )
299287
300288 compress_mlp = nn .Sequential (
@@ -310,11 +298,7 @@ def __init__(
310298 # selection related
311299
312300 self .use_diff_topk = use_diff_topk
313-
314301 self .query_heads_share_selected_kv = query_heads_share_selected_kv
315-
316- assert divisible_by (selection_block_size , compress_block_size ), f'selection block size { selection_block_size } must be greater than or equal to compress block size { compress_block_size } , as well as divisible by the compress block size'
317-
318302 self .selection_block_size = selection_block_size
319303
320304 assert num_selected_blocks >= 0
@@ -376,8 +360,6 @@ def forward_inference(
376360 seq_len = cache_len + 1
377361
378362 sliding_window = self .sliding_window_size
379- compress_divisible_seq_len = round_down_mult (seq_len , self .compress_block_size )
380- compress_overlap_len = self .compress_block_overlap_len
381363
382364 fine_divisible_seq_len = round_up_mult (seq_len , self .selection_block_size )
383365 num_fine_blocks = fine_divisible_seq_len // self .selection_block_size
@@ -435,7 +417,7 @@ def forward_inference(
435417
436418 running_compress_seq_len = run_k .shape [- 2 ]
437419
438- if divisible_by (running_compress_seq_len , self .compress_block_size + compress_overlap_len ):
420+ if divisible_by (running_compress_seq_len , self .compress_block_size ):
439421 k_compress_input = rearrange (run_k , 'b h n d -> b h 1 n d' )
440422 v_compress_input = rearrange (run_v , 'b h n d -> b h 1 n d' )
441423
@@ -445,6 +427,7 @@ def forward_inference(
445427 next_ck = self .k_compress (k_compress_input )
446428 next_cv = self .v_compress (v_compress_input )
447429
430+ compress_overlap_len = self .compress_block_size - self .compress_block_sliding_stride
448431 run_kv_slice = slice (- compress_overlap_len , None ) if compress_overlap_len > 0 else slice (0 , 0 )
449432
450433 run_k = run_k [..., run_kv_slice , :]
@@ -461,9 +444,9 @@ def forward_inference(
461444 importance_scores = csim [..., self .num_mem_compress_kv :]
462445
463446 num_compress_blocks = importance_scores .shape [- 1 ]
464- num_compress_per_fine = self .selection_block_size // self .compress_block_size
447+ num_compress_per_fine = self .selection_block_size // self .compress_block_sliding_stride
465448
466- if self .compress_block_size != self .selection_block_size :
449+ if self .compress_block_sliding_stride != self .selection_block_size :
467450 compress_seq_len = round_down_mult (num_compress_blocks , num_compress_per_fine )
468451 importance_scores = importance_scores [..., :compress_seq_len ]
469452 importance_scores = reduce (importance_scores , '... (j num_compress_per_fine) -> ... j' , 'mean' , num_compress_per_fine = num_compress_per_fine )
@@ -582,10 +565,10 @@ def forward(
582565
583566 batch , seq_len , scale , heads , kv_heads , device = * inp .shape [:2 ], self .scale , self .heads , self .kv_heads , inp .device
584567
585- compress_divisible_seq_len = round_down_mult (seq_len , self .compress_block_size )
586- num_compress_blocks = compress_divisible_seq_len // self .compress_block_size
568+ compress_divisible_seq_len = round_down_mult (seq_len , self .compress_block_sliding_stride )
569+ num_compress_blocks = compress_divisible_seq_len // self .compress_block_sliding_stride
587570
588- compress_overlap_len = self .compress_block_overlap_len
571+ compress_overlap_len = self .compress_block_size - self . compress_block_sliding_stride
589572 has_compress_overlap = compress_overlap_len > 0
590573
591574 fine_divisible_seq_len = round_up_mult (seq_len , self .selection_block_size )
@@ -609,7 +592,7 @@ def forward(
609592 k_compress_input = self .split_compress_window (k_compress_input )
610593 v_compress_input = self .split_compress_window (v_compress_input )
611594 else :
612- k_compress_input , v_compress_input = tuple (t .reshape (batch , kv_heads , 0 , self .compress_window_size , self .dim_head ) for t in (k_compress_input , v_compress_input ))
595+ k_compress_input , v_compress_input = tuple (t .reshape (batch , kv_heads , 0 , self .compress_block_size , self .dim_head ) for t in (k_compress_input , v_compress_input ))
613596
614597 # add the intra block positions
615598
@@ -645,10 +628,10 @@ def forward(
645628 # compressed masking
646629
647630 cmask = None
648-
631+ # TODO
649632 if self .causal :
650633 cq_seq = arange (seq_len , device = device )
651- ck_seq = ((arange (num_compress_blocks , device = device ) + 1 ) * self .compress_block_size ) - 1
634+ ck_seq = ((arange (num_compress_blocks , device = device ) + 1 ) * self .compress_block_sliding_stride ) - 1
652635 ck_seq = F .pad (ck_seq , (num_mem_compress_kv , 0 ), value = - 1 )
653636
654637 cmask = einx .less ('j, i -> i j' , ck_seq , cq_seq )
@@ -686,9 +669,9 @@ def forward(
686669
687670 if has_selected_kv_for_fine_attn :
688671
689- if self .compress_block_size != self .selection_block_size :
672+ if self .compress_block_sliding_stride != self .selection_block_size :
690673
691- num_compress_per_fine = self .selection_block_size // self .compress_block_size
674+ num_compress_per_fine = self .selection_block_size // self .compress_block_sliding_stride
692675
693676 round_down_score_len = round_down_mult (importance_scores .shape [- 1 ], num_compress_per_fine )
694677 importance_scores = importance_scores [..., :round_down_score_len ]
@@ -729,10 +712,7 @@ def forward(
729712
730713 selected_importance_values , selected_block_indices = importance_scores .topk (num_selected , dim = - 1 )
731714
732- gates = None
733-
734- if self .use_diff_topk :
735- gates = straight_through (selected_importance_values , 1. )
715+ gates = straight_through (selected_importance_values , 1. ) if self .use_diff_topk else None
736716
737717 if self .use_triton_kernel and not disable_triton_kernel :
738718
@@ -762,9 +742,7 @@ def forward(
762742 fmask = selected_importance_values > 1e-10
763743
764744 if seq_len < fine_divisible_seq_len :
765- fk = pad_to_multiple (fk )
766- fv = pad_to_multiple (fv )
767- fq = pad_to_multiple (fq )
745+ fk , fv , fq = map (pad_to_multiple , (fk , fv , fq ))
768746
769747 fmask = pad_at_dim (fmask , (0 , remainder ), value = False , dim = - 2 )
770748
@@ -846,9 +824,7 @@ def forward(
846824 seq_len = fk .shape [- 2 ]
847825 fmask = None
848826
849- fk = pad_to_multiple (fk )
850- fv = pad_to_multiple (fv )
851- fq = pad_to_multiple (fq )
827+ fk , fv , fq = map (pad_to_multiple , (fk , fv , fq ))
852828
853829 fq , fk , fv = tuple (rearrange (t , 'b h (w n) d -> (b w) h n d' , n = self .selection_block_size ) for t in (fq , fk , fv ))
854830
0 commit comments