 
 # flex attn sliding attention mask
 
-def create_sliding_mask(seq_len, window_size):
+def create_sliding_mask(seq_len, window_size, causal = True):
     def sliding_mask(_, __, q_idx, kv_idx):
-        causal_mask = q_idx >= kv_idx
 
-        sliding_mask = (q_idx - kv_idx) <= window_size
-        causal_mask = causal_mask & sliding_mask
+        distance = q_idx - kv_idx
+        mask = distance <= window_size
 
-        return causal_mask
+        if causal:
+            mask = mask & (q_idx >= kv_idx)
+        else:
+            mask = mask & (distance >= -window_size)
+
+        return mask
 
     block_mask = create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
     return block_mask
 
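# --- illustrative sketch (not part of the commit): densify the sliding predicate above with
# plain tensors to check what the new `causal` flag changes. toy sizes are arbitrary.

import torch

def dense_sliding_mask(seq_len, window_size, causal = True):
    q_idx = torch.arange(seq_len)[:, None]
    kv_idx = torch.arange(seq_len)[None, :]

    distance = q_idx - kv_idx
    mask = distance <= window_size

    if causal:
        mask = mask & (q_idx >= kv_idx)
    else:
        mask = mask & (distance >= -window_size)

    return mask

print(dense_sliding_mask(6, 2, causal = True))   # banded lower triangle
print(dense_sliding_mask(6, 2, causal = False))  # symmetric band around the diagonal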
-def create_compress_mask(seq_len, kv_seq_len, compress_block_size, mem_kv_len = 0):
+def create_compress_mask(seq_len, kv_seq_len, compress_block_size, mem_kv_len = 0, causal = True):
+
+    if not causal:
+        return None
+
     # cannot be used as using attention logits for importance score
     # but just to show the immense potential of flex attention
 
@@ -69,7 +77,7 @@ def compress_mask(_, __, q_idx, kv_idx):
     block_mask = create_block_mask(compress_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len + mem_kv_len, _compile = True)
     return block_mask
 
-def create_fine_mask(seq_len, fine_block_size):
+def create_fine_mask(seq_len, fine_block_size, causal = True):
 
     def inner(selected_block_indices: Tensor, num_grouped_queries = 1):
         device = selected_block_indices.device
@@ -86,6 +94,9 @@ def fine_mask(b_idx, h_idx, q_idx, kv_idx):
 
             is_selected = one_hot_selected_block_indices[b_idx, kv_head_idx, q_idx, compressed_kv_idx]
 
+            if not causal:
+                return is_selected
+
             causal_mask = q_idx >= kv_idx
             block_diagonal = compressed_q_idx == compressed_kv_idx
 
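# --- illustrative sketch (assumed toy construction, not the repo's `inner`): a dense version of
# the fine mask predicate, contrasting the causal return value with the non-causal one, which
# keeps only the "is this kv block among the selected blocks for this query" term.

import torch

seq_len, fine_block_size, num_blocks, num_selected = 8, 4, 2, 1

# per-query selected block ids, shape (q, num_selected), single batch / head for brevity
selected = torch.randint(0, num_blocks, (seq_len, num_selected))
one_hot_selected = torch.nn.functional.one_hot(selected, num_blocks).sum(dim = -2) > 0  # (q, num_blocks)

q_idx = torch.arange(seq_len)[:, None]
kv_idx = torch.arange(seq_len)[None, :]

compressed_q_idx = q_idx // fine_block_size
compressed_kv_idx = kv_idx // fine_block_size

is_selected = one_hot_selected[q_idx, compressed_kv_idx]   # (q, kv)

causal_mask = q_idx >= kv_idx
block_diagonal = compressed_q_idx == compressed_kv_idx

causal_fine_mask = causal_mask & (block_diagonal | is_selected)
bidirectional_fine_mask = is_selected                      # what the `not causal` branch returns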
@@ -189,6 +200,7 @@ def __init__(
         num_selected_blocks,
         kv_heads = None,
         num_compressed_mem_kv = 1,
+        causal = False,
         norm = True,
         use_diff_topk = False,
         use_triton_kernel = False,
@@ -219,6 +231,10 @@ def __init__(
 
         self.norm = nn.RMSNorm(dim) if norm else nn.Identity()
 
+        # autoregressive or not - will extend this work for long context video / genomics use-cases
+
+        self.causal = causal
+
         # rotary
 
         self.rotary_emb = RotaryEmbedding(dim_head)
@@ -236,7 +252,7 @@ def __init__(
         self.sliding_window = LocalAttention(
             dim = dim_head,
             window_size = sliding_window_size,
-            causal = True,
+            causal = causal,
             exact_windowsize = True,
             autopad = True,
             use_rotary_pos_emb = False
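# --- hypothetical usage sketch: the keyword names follow the __init__ signature in this diff,
# but the class name / import path and the chosen hyperparameter values are assumptions.

import torch
from native_sparse_attention_pytorch import SparseAttention

attn = SparseAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    sliding_window_size = 32,
    compress_block_size = 16,
    selection_block_size = 16,
    num_selected_blocks = 4,
    causal = False           # bidirectional variant, e.g. long context video / genomics
)

tokens = torch.randn(1, 1024, 512)
out = attn(tokens)           # attended output, same shape as the input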
@@ -322,6 +338,8 @@ def forward_inference(
         cache,
         return_cache = True
     ):
+        assert self.causal, 'inference only relevant for autoregressive'
+
         # destruct cache
 
         (
@@ -515,6 +533,8 @@ def forward(
             assert inp.shape[1] == 1, 'input must be single tokens if inferencing with cache key values'
             return self.forward_inference(inp, cache, return_cache = return_cache)
 
+        assert not (not self.causal and return_cache) # returning a kv cache only makes sense for autoregressive
+
         batch, seq_len, scale, heads, device = *inp.shape[:2], self.scale, self.heads, inp.device
 
         compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
@@ -560,11 +580,16 @@ def forward(
         ck = cat((mem_ck, ck), dim = -2)
         cv = cat((mem_cv, cv), dim = -2)
 
-        cq_seq = arange(seq_len, device = device)
-        ck_seq = ((arange(num_compress_blocks, device = device) + 1) * self.compress_block_size) - 1
-        ck_seq = F.pad(ck_seq, (num_mem_compress_kv, 0), value = -1)
+        # compressed masking
+
+        cmask = None
 
-        cmask = einx.less('j, i -> i j', ck_seq, cq_seq)
+        if self.causal:
+            cq_seq = arange(seq_len, device = device)
+            ck_seq = ((arange(num_compress_blocks, device = device) + 1) * self.compress_block_size) - 1
+            ck_seq = F.pad(ck_seq, (num_mem_compress_kv, 0), value = -1)
+
+            cmask = einx.less('j, i -> i j', ck_seq, cq_seq)
 
         compressed_attn_out, csim = attend(cq, ck, cv, mask = cmask, return_sim = True)
 
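# --- illustrative sketch (toy sizes): the causal mask over compressed kv written with plain
# broadcasting instead of einx. each compressed block is stamped with the position of its last
# token, and the memory kv slots are padded with -1 so every query can see them.

import torch
import torch.nn.functional as F

seq_len, compress_block_size, num_mem_compress_kv = 8, 4, 1
num_compress_blocks = seq_len // compress_block_size

cq_seq = torch.arange(seq_len)
ck_seq = ((torch.arange(num_compress_blocks) + 1) * compress_block_size) - 1
ck_seq = F.pad(ck_seq, (num_mem_compress_kv, 0), value = -1)

cmask = ck_seq[None, :] < cq_seq[:, None]   # same as einx.less('j, i -> i j', ck_seq, cq_seq)
print(cmask)                                # query i sees compressed block j only if it ended before position i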
@@ -657,7 +682,8 @@ def forward(
                 self.selection_block_size,
                 selected_block_indices,
                 fmask,
-                sel_scale = gates
+                sel_scale = gates,
+                include_block_diagonal = self.causal
             )
 
         elif exists(fine_selection_flex_mask):
@@ -685,19 +711,23 @@ def forward(
                 if exists(gates):
                     gates = pad_at_dim(gates, (0, remainder), value = 0, dim = -2)
 
-                # handle block causal diagonal in the diagram, but run experiments without to see
+                if self.causal:
+                    # handle block causal diagonal in the diagram, but run experiments without to see
+
+                    fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
+                    fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = selected_block_indices.shape[1])
+                    selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2
 
-                fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
-                fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = selected_block_indices.shape[1])
-                selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2
+                    fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)
 
-                fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)
+                    causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
+                    causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = fmask.shape[1])
 
-                causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
-                causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = fmask.shape[1])
+                    fmask = cat((fmask, causal_mask), dim = -2)
+                    fmask = rearrange(fmask, 'b h i w j -> b h 1 i (w j)')
 
-                fmask = cat((fmask, causal_mask), dim = -2)
-                fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')
+                else:
+                    fmask = repeat(fmask, 'b h i w -> b h 1 i (w j)', j = self.selection_block_size)
 
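# --- illustrative shape walk-through (toy sizes, not the module's tensors) of the two branches
# above: both end with an fmask of shape (b, h, 1, i, sel * block_size), the causal branch with
# one extra selected block accounting for the block-causal diagonal.

import torch
from einops import repeat, rearrange

b, h, block = 1, 2, 4
num_fine_blocks = 3
i = num_fine_blocks * block                  # fine_divisible_seq_len
w = 2                                        # number of selected blocks per query

fmask = torch.rand(b, h, i, w) > 0.5

# causal branch: expand the per-block mask, append a tril for the query's own block
causal_fmask = repeat(fmask, 'b h i w -> b h i w j', j = block)
causal_tril = torch.ones((block, block), dtype = torch.bool).tril()
causal_tril = repeat(causal_tril, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = b, h = h)
causal_fmask = torch.cat((causal_fmask, causal_tril), dim = -2)
causal_fmask = rearrange(causal_fmask, 'b h i w j -> b h 1 i (w j)')
print(causal_fmask.shape)   # (1, 2, 1, 12, (2 + 1) * 4)

# non-causal branch: just broadcast the block-level mask over the block width
bidir_fmask = repeat(fmask, 'b h i w -> b h 1 i (w j)', j = block)
print(bidir_fmask.shape)    # (1, 2, 1, 12, 2 * 4)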
                 # select out the spatial crops of keys / values for fine attention
 
@@ -721,7 +751,9 @@ def forward(
                 # differential topk gating
 
                 if self.use_diff_topk:
-                    gates = F.pad(gates, (0, 1), value = 1.)
+                    if self.causal:
+                        gates = F.pad(gates, (0, 1), value = 1.)
+
                     fk = einx.multiply('b h i sel, b h i sel j d -> b h i sel j d', gates, fk)
 
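# --- illustrative sketch: the einx.multiply above is a broadcasted multiply of the per-selected-block
# gates over the gathered keys. in the causal branch the gates are first padded with a 1. for the
# block-diagonal block that was appended to the selection earlier.

import torch

b, h, i, sel, j, d = 1, 2, 8, 3, 4, 16
gates = torch.rand(b, h, i, sel)
fk = torch.randn(b, h, i, sel, j, d)

gated_fk = gates[..., None, None] * fk   # einx.multiply('b h i sel, b h i sel j d -> b h i sel j d', gates, fk)
print(gated_fk.shape)                    # (1, 2, 8, 3, 4, 16)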
                 # merge selected key values
@@ -730,8 +762,6 @@ def forward(
 
                 # fine attention
 
-                fmask = rearrange(fmask, 'b h ... -> b h 1 ...')
-
                 fq = rearrange(fq, 'b (h qh) ... -> b h qh ...', qh = fine_num_grouped_queries)
 
                 fsim = einsum(fq, fk, 'b h qh i d, b h i j d -> b h qh i j') * self.scale
@@ -752,7 +782,10 @@ def forward(
                 # if only first block, just do a simple block causal
 
                 seq_len = fk.shape[-2]
-                fmask = causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).tril()
+                fmask = None
+
+                if self.causal:
+                    fmask = causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).tril()
 
                 fine_attn_out = attend(fq, fk, fv, mask = fmask)
 
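# --- minimal stand-in (an assumption, not the repo's attend helper) for how an optional boolean
# mask is typically consumed: mask = None means full bidirectional attention over the single fine
# block, while a tril reproduces the block causal behaviour of the original code.

import torch

def simple_attend(q, k, v, mask = None):
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale

    if mask is not None:
        sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

    attn = sim.softmax(dim = -1)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)

q = k = v = torch.randn(1, 2, 6, 16)

causal_out = simple_attend(q, k, v, mask = torch.ones(6, 6, dtype = torch.bool).tril())
bidirectional_out = simple_attend(q, k, v, mask = None)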