@@ -67,9 +67,9 @@ def compress_mask(_, __, q_idx, kv_idx):
 
 def create_fine_mask(seq_len, fine_block_size):
 
-    def inner(selected_block_indices: Tensor):
+    def inner(selected_block_indices: Tensor, num_grouped_queries = 1):
         device = selected_block_indices.device
-        batch, heads = selected_block_indices.shape[:2]
+        batch, kv_heads = selected_block_indices.shape[:2]
 
         one_hot_selected_block_indices = torch.zeros((*selected_block_indices.shape[:-1], seq_len // fine_block_size), device = device, dtype = torch.bool)
         one_hot_selected_block_indices.scatter_(-1, selected_block_indices, True)
@@ -78,15 +78,16 @@ def fine_mask(b_idx, h_idx, q_idx, kv_idx):
 
             compressed_q_idx = q_idx // fine_block_size
             compressed_kv_idx = kv_idx // fine_block_size
+            kv_head_idx = h_idx // num_grouped_queries
 
-            is_selected = one_hot_selected_block_indices[b_idx, h_idx, q_idx, compressed_kv_idx]
+            is_selected = one_hot_selected_block_indices[b_idx, kv_head_idx, q_idx, compressed_kv_idx]
 
             causal_mask = q_idx >= kv_idx
             block_diagonal = compressed_q_idx == compressed_kv_idx
 
             return (causal_mask & (block_diagonal | is_selected))
 
-        block_mask = create_block_mask(fine_mask, B = batch, H = heads, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
+        block_mask = create_block_mask(fine_mask, B = batch, H = kv_heads * num_grouped_queries, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
         return block_mask
 
     return inner
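
The mask function now receives query-head indices while selected_block_indices is indexed per kv head, so this change assumes grouped query heads are laid out contiguously per kv head (query head = kv_head * num_grouped_queries + group). A toy check of that mapping, with hypothetical head counts, not part of the commit:

# illustrative check of the head-index layout assumed by fine_mask above
kv_heads, num_grouped_queries = 4, 2
for h_idx in range(kv_heads * num_grouped_queries):
    kv_head_idx = h_idx // num_grouped_queries   # same mapping used in fine_mask
    print(h_idx, '->', kv_head_idx)              # query heads 0,1 -> kv head 0; 2,3 -> 1; ...
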
@@ -349,11 +350,9 @@ def forward(
         if exists(fine_selection_flex_mask):
             # flex attention for the selection for fine attention
 
-            fk, fv, selected_block_indices = tuple(repeat(t, 'b h ... -> b (h num_grouped_queries) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv, selected_block_indices))
+            fine_block_mask = fine_selection_flex_mask(selected_block_indices, num_grouped_queries = self.num_grouped_queries)
 
-            fine_block_mask = fine_selection_flex_mask(selected_block_indices)
-
-            fine_attn_out = flex_attention(fq, fk, fv, block_mask = fine_block_mask)
+            fine_attn_out = flex_attention(fq, fk, fv, block_mask = fine_block_mask, enable_gqa = True)
 
         else:
             fmask = selected_importance_values > 1e-10
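
Together the two hunks let fine attention handle grouped-query attention without repeating keys/values across query heads: the block mask is built over all kv_heads * num_grouped_queries query heads, and flex_attention broadcasts the kv heads via enable_gqa. A minimal usage sketch follows; the shapes and sizes are made up for illustration, and only create_fine_mask and the flex_attention call mirror the commit:

import torch
from torch.nn.attention.flex_attention import flex_attention

# hypothetical sizes: 8 query heads sharing 2 kv heads
batch, kv_heads, num_grouped_queries = 1, 2, 4
seq_len, fine_block_size, num_selected = 256, 32, 4
dim_head = 64

# queries carry every head; keys / values keep only the kv heads (GQA layout)
fq = torch.randn(batch, kv_heads * num_grouped_queries, seq_len, dim_head)
fk = torch.randn(batch, kv_heads, seq_len, dim_head)
fv = torch.randn(batch, kv_heads, seq_len, dim_head)

# fine blocks are selected per kv head: (batch, kv_heads, seq_len, num_selected)
selected_block_indices = torch.randint(0, seq_len // fine_block_size, (batch, kv_heads, seq_len, num_selected))

fine_selection_flex_mask = create_fine_mask(seq_len, fine_block_size)
fine_block_mask = fine_selection_flex_mask(selected_block_indices, num_grouped_queries = num_grouped_queries)

# enable_gqa lets flex_attention broadcast the 2 kv heads across the 8 query heads
fine_attn_out = flex_attention(fq, fk, fv, block_mask = fine_block_mask, enable_gqa = True)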