@@ -392,16 +392,25 @@ def forward(
         compress_seq_len = score_len * self.compress_block_size

         if self.interpolated_importance_score:
-            mask = importance_scores > 1e-10
-            mask = repeat(mask, '... j -> ... (j block_size)', block_size = self.compress_block_size)
             importance_scores = interpolate_1d(importance_scores, compress_seq_len)
-            importance_scores = importance_scores.masked_fill(~mask, 0.)
         else:
             importance_scores = repeat(importance_scores, '... j -> ... (j block_size)', block_size = self.compress_block_size)

         padding = fine_divisible_seq_len - compress_seq_len
+
+        fine_query_seq_len = importance_scores.shape[-2]
+        fine_query_padding = fine_divisible_seq_len - importance_scores.shape[-2]
+
         importance_scores = F.pad(importance_scores, (0, padding))

+        # mask out the diagonal since block causal is included by default for fine attending
+
+        block_causal_mask = torch.ones((num_fine_blocks,) * 2, device = device, dtype = torch.bool).tril(-1)
+        block_causal_mask = repeat(block_causal_mask, 'i j -> (i n1) (j n2)', n1 = self.selection_block_size, n2 = self.selection_block_size)
+        block_causal_mask = block_causal_mask[:fine_query_seq_len]
+
+        importance_scores = importance_scores.masked_fill(~block_causal_mask, 0.)
+
         importance_scores = reduce(importance_scores, '... (j block_size) -> ... j', 'mean', block_size = self.selection_block_size)

         # handle if number of total blocks is less than number to select for fine attention
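For context, a minimal standalone sketch of what the new masking step computes, assuming `repeat` and `reduce` are the einops ops used elsewhere in the file; the toy sizes (`num_fine_blocks = 3`, `selection_block_size = 4`) and the single 2D score matrix are hypothetical, not from this commit:

```python
# standalone sketch (toy shapes) mirroring the block-causal masking added above
import torch
from einops import repeat, reduce

num_fine_blocks = 3        # hypothetical
selection_block_size = 4   # hypothetical
seq_len = num_fine_blocks * selection_block_size

# per-token importance scores for a single (batch, head) slice
importance_scores = torch.rand(seq_len, seq_len)

# strict lower-triangular block mask - True only for blocks strictly in the past
block_causal_mask = torch.ones((num_fine_blocks,) * 2, dtype = torch.bool).tril(-1)

# expand each block entry to a selection_block_size x selection_block_size tile of tokens
block_causal_mask = repeat(block_causal_mask, 'i j -> (i n1) (j n2)', n1 = selection_block_size, n2 = selection_block_size)

# zero the scores on the diagonal block and in the future
importance_scores = importance_scores.masked_fill(~block_causal_mask, 0.)

# mean-pool back to one score per selection block, as in the diff
importance_scores = reduce(importance_scores, '... (j block_size) -> ... j', 'mean', block_size = selection_block_size)

print(importance_scores.shape)  # torch.Size([12, 3])
```

Because `tril(-1)` is strict, the diagonal block is zeroed along with future blocks, so the later top-k only spends its selection budget on past blocks; the diagonal block is already attended to unconditionally by the block-causal fine attention, per the comment in the diff.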
@@ -411,6 +420,9 @@ def forward(
         fv = v

         if has_selected_kv_for_fine_attn:
+
+            # get the top-n kv segments for fine attention
+
             selected_importance_values, selected_block_indices = importance_scores.topk(num_selected, dim = -1)

             if self.use_diff_topk:
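And a correspondingly small sketch of the top-n block selection that the second hunk annotates; `num_selected = 2` and the score shape are toy values, and `use_diff_topk` plus the actual gathering of the selected key/value blocks are omitted:

```python
# sketch of picking the top-n selection blocks per query from the pooled scores (toy values)
import torch

num_selected = 2                         # hypothetical
importance_scores = torch.rand(12, 3)    # (query positions, selection blocks), as pooled above

selected_importance_values, selected_block_indices = importance_scores.topk(num_selected, dim = -1)
print(selected_block_indices.shape)  # torch.Size([12, 2])
```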