 import torch.nn.functional as F
 from torch.nn import Module, ModuleList
 
-from colt5_attention import topk as differentiable_topk
-
 from local_attention import LocalAttention
 
 from rotary_embedding_torch import RotaryEmbedding
@@ -87,7 +85,6 @@ def __init__(
         num_compressed_mem_kv = 4,
         norm = True,
         use_diff_topk = False,
-        diff_topk_coor_descent_iters = 10.
     ):
         super().__init__()
         self.heads = heads
@@ -142,7 +139,6 @@ def __init__(
         # selection related
 
         self.use_diff_topk = use_diff_topk
-        self.diff_topk_coor_descent_iters = diff_topk_coor_descent_iters
 
         self.selection_block_size = selection_block_size
         self.num_selected_blocks = num_selected_blocks
@@ -222,12 +218,12 @@ def forward(
 
         # 2. fine attention over selected based on compressed attention logits
 
-        importance_scores = csim[..., num_mem_compress_kv:]
+        importance_scores = cattn[..., num_mem_compress_kv:]
+
+        selected_importance_values, selected_block_indices = importance_scores.topk(self.num_selected_blocks, dim = -1)
 
         if self.use_diff_topk:
-            selected_importance_values, selected_block_indices, _, gates = differentiable_topk(importance_scores, self.num_selected_blocks, fused = True)
-        else:
-            selected_importance_values, selected_block_indices = importance_scores.topk(self.num_selected_blocks, dim = -1)
+            gates = selected_importance_values + (1. - selected_importance_values).detach()
 
         fmask = selected_importance_values > mask_value
 
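The replacement for `differentiable_topk` above is a straight-through estimator: `topk` now always runs, and when `use_diff_topk` is set, each gate evaluates to exactly `1.` in the forward pass while the backward pass routes gradient into `selected_importance_values`. A minimal standalone sketch of the trick (hypothetical tensors, not repo code):

```python
import torch

scores = torch.tensor([0.7, 0.2, 0.9], requires_grad = True)

# value is exactly 1. (the scores cancel), but the detached term carries no
# gradient, so d(gates)/d(scores) = 1. and gradients flow straight through
gates = scores + (1. - scores).detach()

gates.sum().backward()

print(gates.detach())   # tensor([1., 1., 1.])
print(scores.grad)      # tensor([1., 1., 1.])
```

This keeps block selection trainable without the colt5-attention coordinate-descent dependency removed in this commit.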
@@ -247,6 +243,9 @@ def forward(
 
         selected_block_indices = pad_at_dim(selected_block_indices, (0, remainder), value = 0, dim = -2)
 
+        if self.use_diff_topk:
+            gates = pad_at_dim(gates, (0, remainder), value = 1., dim = -2)
+
         # handle block causal diagonal in the diagram, but run experiments without to see
 
         fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
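The new hunk fills the remainder positions of `gates` with `1.` so the straight-through gates stay neutral on blocks that exist only to make the sequence length divisible. `pad_at_dim` is a repo helper; if it follows the usual convention, it would look roughly like this (an assumed sketch, not the repo's definition):

```python
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # translate `dim` into F.pad's right-to-left (left, right) pair ordering
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)
```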
@@ -272,7 +271,7 @@ def forward(
         # handle maybe gating
 
         if self.use_diff_topk:
-            gates = F.pad(gates, (0, 1, 0, remainder), value = 1.)
+            gates = F.pad(gates, (0, 1), value = 1.)
 
             fk = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fk)
             fv = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fv)
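With the remainder now padded earlier via `pad_at_dim`, the `F.pad` here only needs to append a single neutral gate along the last dimension for the extra block-causal diagonal block. The `einx.multiply` pattern is plain broadcasting; an equivalent in bare PyTorch (hypothetical shapes for illustration):

```python
import torch

b, h, i, w, j, d = 1, 2, 4, 3, 16, 8    # hypothetical sizes

gates = torch.ones(b, h, i, w)          # one gate per selected block
fk = torch.randn(b, h, i, w, j, d)      # per-block fine keys

# same as einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fk)
out = gates[..., None, None] * fk
assert out.shape == fk.shape
```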