@@ -79,13 +79,12 @@ def fine_mask(b_idx, h_idx, q_idx, kv_idx):
         compressed_q_idx = q_idx // fine_block_size
         compressed_kv_idx = kv_idx // fine_block_size

-        block_causal_mask = compressed_q_idx > compressed_kv_idx
         is_selected = one_hot_selected_block_indices[b_idx, h_idx, q_idx, compressed_kv_idx]

         causal_mask = q_idx >= kv_idx
         block_diagonal = compressed_q_idx == compressed_kv_idx

-        return (causal_mask & block_diagonal) | (block_causal_mask & is_selected)
+        return (causal_mask & (block_diagonal | is_selected))

     block_mask = create_block_mask(fine_mask, B = batch, H = heads, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
     return block_mask
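The simplified return above is pointwise equivalent to the removed two-term form: inside the diagonal block the token-level causal check is the binding constraint, and for any strictly earlier selected block `q_idx >= kv_idx` holds automatically, so the separate `block_causal_mask` term was redundant. A minimal sanity check of that equivalence, using toy sizes and a made-up selection pattern rather than the module's real inputs:

import torch

fine_block_size = 4
seq_len = 16
num_blocks = seq_len // fine_block_size

# hypothetical selection: every query position selects block 0
selected = torch.zeros(seq_len, num_blocks, dtype = torch.bool)
selected[:, 0] = True

q_idx = torch.arange(seq_len)[:, None]
kv_idx = torch.arange(seq_len)[None, :]

compressed_q_idx = q_idx // fine_block_size
compressed_kv_idx = kv_idx // fine_block_size

causal_mask = q_idx >= kv_idx
block_diagonal = compressed_q_idx == compressed_kv_idx
block_causal_mask = compressed_q_idx > compressed_kv_idx
is_selected = selected[q_idx, compressed_kv_idx]

# old two-term predicate vs. the simplified one from the diff
old_mask = (causal_mask & block_diagonal) | (block_causal_mask & is_selected)
new_mask = causal_mask & (block_diagonal | is_selected)

assert torch.equal(old_mask, new_mask)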
@@ -344,76 +343,87 @@ def forward(
             selected_importance_values, selected_block_indices = importance_scores.topk(num_selected, dim = -1)

             if self.use_diff_topk:
+                assert not exists(fine_selection_flex_mask)
                 gates = straight_through(selected_importance_values, 1.)

-            fmask = selected_importance_values > 1e-10
+            if exists(fine_selection_flex_mask):
+                # flex attention for the selection for fine attention

-            if seq_len < fine_divisible_seq_len:
-                remainder = fine_divisible_seq_len - seq_len
-                fk = pad_at_dim(fk, (0, remainder), value = 0., dim = -2)
-                fv = pad_at_dim(fv, (0, remainder), value = 0., dim = -2)
-                fq = pad_at_dim(fq, (0, remainder), value = 0., dim = -2)
+                fk, fv, selected_block_indices = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv, selected_block_indices))

-                fmask = pad_at_dim(fmask, (0, remainder), value = False, dim = -2)
+                fine_block_mask = fine_selection_flex_mask(selected_block_indices)

-                selected_block_indices = pad_at_dim(selected_block_indices, (0, remainder), value = 0, dim = -2)
+                fine_attn_out = flex_attention(fq, fk, fv, block_mask = fine_block_mask)

-                if self.use_diff_topk:
-                    gates = pad_at_dim(gates, (0, remainder), value = 1., dim = -2)
+            else:
+                fmask = selected_importance_values > 1e-10

-            # handle block causal diagonal in the diagram, but run experiments without to see
+                if seq_len < fine_divisible_seq_len:
+                    remainder = fine_divisible_seq_len - seq_len
+                    fk = pad_at_dim(fk, (0, remainder), value = 0., dim = -2)
+                    fv = pad_at_dim(fv, (0, remainder), value = 0., dim = -2)
+                    fq = pad_at_dim(fq, (0, remainder), value = 0., dim = -2)

-            fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
-            fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = self.kv_heads)
-            selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2
+                    fmask = pad_at_dim(fmask, (0, remainder), value = False, dim = -2)

-            fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)
+                    selected_block_indices = pad_at_dim(selected_block_indices, (0, remainder), value = 0, dim = -2)

-            causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
-            causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = self.kv_heads)
+                    if self.use_diff_topk:
+                        gates = pad_at_dim(gates, (0, remainder), value = 1., dim = -2)

-            fmask = cat((fmask, causal_mask), dim = -2)
-            fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')
+                # handle block causal diagonal in the diagram, but run experiments without to see

-            # select out the spatial crops of keys / values for fine attention
+                fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
+                fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = self.kv_heads)
+                selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2

-            fk = rearrange(fk, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
-            fv = rearrange(fv, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
+                fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)

-            # get_at("b h [w] j d, b h i selected -> b h i selected j d", fkv, selected_block_indices)
+                causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
+                causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = self.kv_heads)

-            fk = repeat(fk, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])
-            fv = repeat(fv, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])
+                fmask = cat((fmask, causal_mask), dim = -2)
+                fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')

-            selected_block_indices = repeat(selected_block_indices, 'b h i sel -> b h i sel j d', j = fk.shape[-2], d = fk.shape[-1])
+                # select out the spatial crops of keys / values for fine attention

-            fk = fk.gather(3, selected_block_indices)
-            fv = fv.gather(3, selected_block_indices)
+                fk = rearrange(fk, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
+                fv = rearrange(fv, 'b h (w n) d -> b h w n d', w = num_fine_blocks)

-            # handle maybe gating
+                # get_at("b h [w] j d, b h i selected -> b h i selected j d", fkv, selected_block_indices)

-            if self.use_diff_topk:
-                gates = F.pad(gates, (0, 1), value = 1.)
+                fk = repeat(fk, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])
+                fv = repeat(fv, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])

-                fk = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fk)
-                fv = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fv)
+                selected_block_indices = repeat(selected_block_indices, 'b h i sel -> b h i sel j d', j = fk.shape[-2], d = fk.shape[-1])

-            fk = rearrange(fk, 'b h i w j d -> b h i (w j) d')
-            fv = rearrange(fv, 'b h i w j d -> b h i (w j) d')
+                fk = fk.gather(3, selected_block_indices)
+                fv = fv.gather(3, selected_block_indices)

-            # fine attention
+                # handle maybe gating

-            fk, fv, fmask = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv, fmask))
+                if self.use_diff_topk:
+                    gates = F.pad(gates, (0, 1), value = 1.)

-            fsim = einsum(fq, fk, 'b h i d, b h i j d -> b h i j') * self.scale
+                    fk = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fk)
+                    fv = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fv)

-            fsim = fsim.masked_fill(~fmask, mask_value)
+                fk = rearrange(fk, 'b h i w j d -> b h i (w j) d')
+                fv = rearrange(fv, 'b h i w j d -> b h i (w j) d')

-            fattn = fsim.softmax(dim = -1)
+                # fine attention
+
+                fk, fv, fmask = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv, fmask))
+
+                fsim = einsum(fq, fk, 'b h i d, b h i j d -> b h i j') * self.scale
+
+                fsim = fsim.masked_fill(~fmask, mask_value)
+
+                fattn = fsim.softmax(dim = -1)

-            fine_attn_out = einsum(fattn, fv, 'b h i j, b h i j d -> b h i d')
+                fine_attn_out = einsum(fattn, fv, 'b h i j, b h i j d -> b h i d')

-            fine_attn_out = fine_attn_out[..., :seq_len, :]
+                fine_attn_out = fine_attn_out[..., :seq_len, :]
        else:
            # if only first block, just do a simple block causal

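For readers following the non-flex `else` branch above, the repeat + gather sequence pulls out, for every query position, the key/value rows of its selected blocks before the dense masked attention. A standalone shape sketch with toy sizes (the names and sizes here are illustrative, not the module's actual attributes or configuration):

import torch
from einops import rearrange, repeat

# toy sizes, assumed for illustration only
b, h, d = 1, 2, 8
num_fine_blocks, block_size = 4, 16
seq_len = num_fine_blocks * block_size
num_selected = 2

fk = torch.randn(b, h, seq_len, d)
selected = torch.randint(0, num_fine_blocks, (b, h, seq_len, num_selected))

# split keys into blocks: (b, h, w, n, d)
fk_blocks = rearrange(fk, 'b h (w n) d -> b h w n d', w = num_fine_blocks)

# expand so every query position i sees all blocks, then gather only its selected ones
fk_blocks = repeat(fk_blocks, 'b h w n d -> b h i w n d', i = seq_len)
index = repeat(selected, 'b h i sel -> b h i sel n d', n = block_size, d = d)

fk_selected = fk_blocks.gather(3, index)   # (b, h, i, num_selected, block_size, d)
fk_selected = rearrange(fk_selected, 'b h i sel n d -> b h i (sel n) d')

print(fk_selected.shape)  # torch.Size([1, 2, 64, 32, 8])

Gathering per-query crops like this duplicates keys and values for every query row; the newly added flex-attention branch avoids that materialization by expressing the same selection as a block mask that flex_attention applies directly.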