@@ -243,83 +243,97 @@ def forward(
 
         importance_scores = cattn[..., num_mem_compress_kv:]
 
-        topk = min(self.num_selected_blocks, importance_scores.shape[-1])
-
-        selected_importance_values, selected_block_indices = importance_scores.topk(topk, dim = -1)
-
-        if self.use_diff_topk:
-            gates = selected_importance_values + (1. - selected_importance_values).detach()
-
-        fmask = selected_importance_values > 1e-10
+        num_selected = min(self.num_selected_blocks, importance_scores.shape[-1])
 
         fq = rotated_q
         fk = rotated_k
         fv = v
 
-        if seq_len < fine_divisible_seq_len:
-            remainder = fine_divisible_seq_len - seq_len
-            fk = pad_at_dim(fk, (0, remainder), value = 0., dim = -2)
-            fv = pad_at_dim(fv, (0, remainder), value = 0., dim = -2)
-            fq = pad_at_dim(fq, (0, remainder), value = 0., dim = -2)
+        if num_selected > 0:
+            selected_importance_values, selected_block_indices = importance_scores.topk(num_selected, dim = -1)
+
+            if self.use_diff_topk:
+                gates = selected_importance_values + (1. - selected_importance_values).detach()
 
-            fmask = pad_at_dim(fmask, (0, remainder), value = False, dim = -2)
+            fmask = selected_importance_values > 1e-10
 
-            selected_block_indices = pad_at_dim(selected_block_indices, (0, remainder), value = 0, dim = -2)
+            if seq_len < fine_divisible_seq_len:
+                remainder = fine_divisible_seq_len - seq_len
+                fk = pad_at_dim(fk, (0, remainder), value = 0., dim = -2)
+                fv = pad_at_dim(fv, (0, remainder), value = 0., dim = -2)
+                fq = pad_at_dim(fq, (0, remainder), value = 0., dim = -2)
 
-            if self.use_diff_topk:
-                gates = pad_at_dim(gates, (0, remainder), value = 1., dim = -2)
+                fmask = pad_at_dim(fmask, (0, remainder), value = False, dim = -2)
 
-        # handle block causal diagonal in the diagram, but run experiments without to see
+                selected_block_indices = pad_at_dim(selected_block_indices, (0, remainder), value = 0, dim = -2)
 
-        fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
-        fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = heads)
-        selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2
+                if self.use_diff_topk:
+                    gates = pad_at_dim(gates, (0, remainder), value = 1., dim = -2)
 
-        fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)
+            # handle block causal diagonal in the diagram, but run experiments without to see
 
-        causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
-        causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = heads)
+            fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
+            fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = heads)
+            selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2
 
-        fmask = cat((fmask, causal_mask), dim = -2)
-        fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')
+            fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)
 
-        # select out the spatial crops of keys / values for fine attention
+            causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
+            causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = heads)
 
-        fk = rearrange(fk, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
-        fv = rearrange(fv, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
+            fmask = cat((fmask, causal_mask), dim = -2)
+            fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')
 
-        # get_at("b h [w] j d, b h i selected -> b h i selected j d", fkv, selected_block_indices)
+            # select out the spatial crops of keys / values for fine attention
 
-        fk = repeat(fk, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])
-        fv = repeat(fv, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])
+            fk = rearrange(fk, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
+            fv = rearrange(fv, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
 
-        selected_block_indices = repeat(selected_block_indices, 'b h i sel -> b h i sel j d', j = fk.shape[-2], d = fk.shape[-1])
+            # get_at("b h [w] j d, b h i selected -> b h i selected j d", fkv, selected_block_indices)
 
-        fk = fk.gather(3, selected_block_indices)
-        fv = fv.gather(3, selected_block_indices)
+            fk = repeat(fk, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])
+            fv = repeat(fv, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])
 
-        # handle maybe gating
+            selected_block_indices = repeat(selected_block_indices, 'b h i sel -> b h i sel j d', j = fk.shape[-2], d = fk.shape[-1])
 
-        if self.use_diff_topk:
-            gates = F.pad(gates, (0, 1), value = 1.)
+            fk = fk.gather(3, selected_block_indices)
+            fv = fv.gather(3, selected_block_indices)
 
-            fk = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fk)
-            fv = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fv)
+            # handle maybe gating
+
+            if self.use_diff_topk:
+                gates = F.pad(gates, (0, 1), value = 1.)
 
-        fk = rearrange(fk, 'b h i w j d -> b h i (w j) d')
-        fv = rearrange(fv, 'b h i w j d -> b h i (w j) d')
+                fk = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fk)
+                fv = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fv)
 
-        # fine attention
+            fk = rearrange(fk, 'b h i w j d -> b h i (w j) d')
+            fv = rearrange(fv, 'b h i w j d -> b h i (w j) d')
+
+            # fine attention
+
+            fsim = einsum(fq, fk, 'b h i d, b h i j d -> b h i j') * self.scale
+
+            fsim = fsim.masked_fill(~fmask, mask_value)
+
+            fattn = fsim.softmax(dim = -1)
+
+            fine_attn_out = einsum(fattn, fv, 'b h i j, b h i j d -> b h i d')
+
+            fine_attn_out = fine_attn_out[..., :seq_len, :]
+        else:
+            # if only first block, just do a simple block causal
 
-        fsim = einsum(fq, fk, 'b h i d, b h i j d -> b h i j') * self.scale
+            seq_len = fk.shape[-2]
+            fmask = causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).tril()
 
-        fsim = fsim.masked_fill(~fmask, mask_value)
+            fsim = einsum(fq, fk, 'b h i d, b h j d -> b h i j') * self.scale
 
-        fattn = fsim.softmax(dim = -1)
+            fsim = fsim.masked_fill(~fmask, mask_value)
 
-        fine_attn_out = einsum(fattn, fv, 'b h i j, b h i j d -> b h i d')
+            fattn = fsim.softmax(dim = -1)
 
-        fine_attn_out = fine_attn_out[..., :seq_len, :]
+            fine_attn_out = einsum(fattn, fv, 'b h i j, b h j d -> b h i d')
 
         # 3. overlapping sliding window, this is unsurprising and expected
 
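What the hunk changes, in short: the fine block-selection branch now only runs when at least one block can be selected (`num_selected > 0`), and otherwise falls back to plain block-causal attention over all keys; inside the selection branch, the differentiable top-k path still gates the gathered key/value blocks with a straight-through estimator. The sketch below illustrates just that gating trick in isolation. It is a minimal example with toy shapes and a hypothetical helper name, not the module's API; only the `gates = ... .detach()` line is taken directly from the code above.

```python
import torch

def straight_through_topk_gates(importance_scores, k):
    # importance_scores: (batch, heads, seq, num_blocks), assumed to lie in [0, 1]
    selected_values, selected_indices = importance_scores.topk(k, dim = -1)

    # same trick as `gates = selected_importance_values + (1. - selected_importance_values).detach()`:
    # the forward value is exactly 1, but d(gates)/d(selected_values) is 1, so gradient
    # still reaches the importance scores of the selected blocks
    gates = selected_values + (1. - selected_values).detach()
    return gates, selected_indices

scores = torch.rand(1, 2, 8, 4, requires_grad = True)
gates, indices = straight_through_topk_gates(scores, k = 2)

assert torch.allclose(gates, torch.ones_like(gates))
gates.sum().backward()    # scores.grad is nonzero only at the selected blocks
```

In the forward pass every selected gate is exactly 1, so multiplying the gathered `fk` / `fv` by the gates leaves them unchanged, while the backward pass still routes gradient into the importance scores, which is what lets the block selection itself be trained.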