 import torch.nn.functional as F
 from torch.nn import Module, ModuleList

+from colt5_attention import topk as differentiable_topk
+
 from local_attention import LocalAttention

 from rotary_embedding_torch import RotaryEmbedding
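Aside: `differentiable_topk` is the coordinate descent based top-k from lucidrains' `colt5_attention` package. As the unpacking later in this commit shows, it returns four values; the commit keeps the top-k values, their indices, and the soft `gates` (which carry gradient back to the scores), discarding the third. A minimal sketch of the call shape, with tensor shapes as illustrative assumptions:

    import torch
    from colt5_attention import topk as differentiable_topk

    scores = torch.randn(2, 8, 512, 64)   # (batch, heads, seq, selectable blocks) - assumed shapes

    # hard top-k in the forward pass, with differentiable gates for the backward pass
    values, indices, _, gates = differentiable_topk(scores, 16, fused = True)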
@@ -84,8 +86,11 @@ def __init__(
         num_selected_blocks,
         num_compressed_mem_kv = 4,
         norm = True,
+        use_diff_topk = False,
+        diff_topk_coor_descent_iters = 10.
     ):
         super().__init__()
+        self.heads = heads
         self.scale = dim_head ** -0.5

         assert compress_block_size == selection_block_size, 'start off with compressed being equal to selection block sizes'
@@ -136,6 +141,9 @@ def __init__(

         # selection related

+        self.use_diff_topk = use_diff_topk
+        self.diff_topk_coor_descent_iters = diff_topk_coor_descent_iters
+
         self.selection_block_size = selection_block_size
         self.num_selected_blocks = num_selected_blocks
@@ -160,7 +168,7 @@ def forward(
         self,
         inp
     ):
-        batch, seq_len, scale, device = *inp.shape[:2], self.scale, inp.device
+        batch, seq_len, scale, heads, device = *inp.shape[:2], self.scale, self.heads, inp.device

         compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
         num_compress_blocks = compress_divisible_seq_len // self.compress_block_size
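`round_down_mult` is a small helper presumably defined earlier in the file; a sketch consistent with how it is used here:

    def round_down_mult(n, mult):
        # round n down to the nearest multiple of mult
        return n // mult * mult

so `compress_divisible_seq_len` is the longest prefix of the sequence that splits into whole compression blocks.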
@@ -216,7 +224,10 @@ def forward(

         importance_scores = csim[..., num_mem_compress_kv:]

-        selected_importance_values, selected_block_indices = importance_scores.topk(self.num_selected_blocks, dim = -1)
+        if self.use_diff_topk:
+            selected_importance_values, selected_block_indices, _, gates = differentiable_topk(importance_scores, self.num_selected_blocks, fused = True)
+        else:
+            selected_importance_values, selected_block_indices = importance_scores.topk(self.num_selected_blocks, dim = -1)

         fmask = selected_importance_values > mask_value

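The reason for the new branch: a plain `.topk` is piecewise constant in which blocks it selects, so no gradient reaches the compressed attention scores through the selection itself. The coordinate descent top-k instead returns soft `gates`, later multiplied onto the gathered fine keys/values, which restores a gradient path. For intuition only, here is the general straight-through shape of such an estimator (not colt5_attention's implementation, which uses coordinate descent):

    import torch

    def straight_through_topk(scores, k):
        # soft relaxation over the candidate blocks
        soft = scores.softmax(dim = -1)

        # hard selection, same as the non-differentiable branch
        values, indices = scores.topk(k, dim = -1)

        # gates equal 1. in the forward pass, but gradients flow through the soft probabilities
        soft_gates = soft.gather(-1, indices)
        gates = soft_gates + (1. - soft_gates).detach()

        return values, indices, gates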
@@ -239,13 +250,13 @@ def forward(
         # handle block causal diagonal in the diagram, but run experiments without to see

         fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
-        fine_window_seq = rearrange(fine_window_seq, 'n -> n 1').expand_as(selected_block_indices)
+        fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = heads)
         selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2

         fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)

         causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
-        causal_mask = repeat(causal_mask, 'i j -> (w i) 1 j', w = num_fine_blocks).expand_as(fmask)
+        causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = heads)

         fmask = cat((fmask, causal_mask), dim = -2)
         fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')
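Both `.expand_as(...)` calls above are replaced by explicit `repeat(...)` over batch and heads. A plausible motivation, not stated in the commit, is simply to make the batch and head dims explicit and the tensors concrete (an expanded view shares memory across the broadcast dims) before they are concatenated and, on the gated path, indexed per head. A toy example of the rewritten pattern:

    import torch
    from einops import repeat

    fine_window_seq = torch.arange(12) // 4                      # block id per position
    fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = 2, h = 8)
    print(fine_window_seq.shape)                                 # torch.Size([2, 8, 12, 1])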
@@ -255,8 +266,19 @@ def forward(

         fk = rearrange(fk, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
         fv = rearrange(fv, 'b h (w n) d -> b h w n d', w = num_fine_blocks)

-        fk = einx.get_at('b h [w] j d, b h i selected -> b h i (selected j) d', fk, selected_block_indices)
-        fv = einx.get_at('b h [w] j d, b h i selected -> b h i (selected j) d', fv, selected_block_indices)
+        fk = einx.get_at('b h [w] j d, b h i selected -> b h i selected j d', fk, selected_block_indices)
+        fv = einx.get_at('b h [w] j d, b h i selected -> b h i selected j d', fv, selected_block_indices)
+
+        # handle maybe gating
+
+        if self.use_diff_topk:
+            gates = F.pad(gates, (0, 1, 0, remainder), value = 1.)
+
+            fk = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fk)
+            fv = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fv)
+
+        fk = rearrange(fk, 'b h i w j d -> b h i (w j) d')
+        fv = rearrange(fv, 'b h i w j d -> b h i (w j) d')

         # fine attention
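On the new gating block: the `F.pad(gates, (0, 1, 0, remainder), value = 1.)` right-pads the selected-blocks dim by one, so the block-causal diagonal entry just appended to `selected_block_indices` passes through unscaled, and right-pads the sequence dim by `remainder` to match the fine-block padding; padded positions get a gate of `1.`. Under the shapes used here, the two `einx.multiply` calls reduce to plain broadcasting, e.g. (a sketch with assumed shapes):

    import torch

    b, h, i, w, j, d = 1, 2, 4, 3, 16, 64
    gates = torch.rand(b, h, i, w)
    fk = torch.randn(b, h, i, w, j, d)

    # same as einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fk)
    fk = fk * gates[..., None, None]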