@@ -398,8 +398,9 @@ def forward(
398398 selected_importance_values , selected_block_indices = importance_scores .topk (num_selected , dim = - 1 )
399399
400400 if self .use_diff_topk :
401- assert not exists (fine_selection_flex_mask )
402401 gates = straight_through (selected_importance_values , 1. )
402+ gates = gates .cumsum (dim = - 1 )[..., - 1 ]
403+ gates = repeat (gates , 'b h ... -> b (h qh) ...' , qh = self .num_grouped_queries )
403404
404405 if exists (fine_selection_flex_mask ):
405406 # flex attention for the selection for fine attention
@@ -422,7 +423,7 @@ def forward(
422423 selected_block_indices = pad_at_dim (selected_block_indices , (0 , remainder ), value = 0 , dim = - 2 )
423424
424425 if self .use_diff_topk :
425- gates = pad_at_dim (gates , (0 , remainder ), value = 1. , dim = - 2 )
426+ gates = pad_at_dim (gates , (0 , remainder ), value = 1. )
426427
427428 # handle block causal diagonal in the diagram, but run experiments without to see
428429
@@ -453,16 +454,7 @@ def forward(
453454 fk = fk .gather (3 , selected_block_indices )
454455 fv = fv .gather (3 , selected_block_indices )
455456
456- # handle maybe gating
457-
458- if self .use_diff_topk :
459- gates = F .pad (gates , (0 , 1 ), value = 1. )
460-
461- fk = einx .multiply ('b h i w, b h i w j d -> b h i w j d' , gates , fk )
462- fv = einx .multiply ('b h i w, b h i w j d -> b h i w j d' , gates , fv )
463-
464- fk = rearrange (fk , 'b h i w j d -> b h i (w j) d' )
465- fv = rearrange (fv , 'b h i w j d -> b h i (w j) d' )
457+ fk , fv = tuple (rearrange (t , 'b h i w j d -> b h i (w j) d' ) for t in (fk , fv ))
466458
467459 # fine attention
468460
@@ -483,6 +475,13 @@ def forward(
483475 fine_attn_out = rearrange (fine_attn_out , 'b h qh ... -> b (h qh) ...' )
484476
485477 fine_attn_out = fine_attn_out [..., :seq_len , :]
478+
479+ # handle maybe gating
480+
481+ if self .use_diff_topk :
482+ gates = gates [..., :seq_len ]
483+ fine_attn_out = einx .multiply ('b h n, b h n d -> b h n d' , gates , fine_attn_out )
484+
486485 else :
487486 # if only first block, just do a simple block causal
488487
0 commit comments