@@ -113,6 +113,9 @@ def round_up_mult(n, mult):
 def divisible_by(num, den):
     return (num % den) == 0

+def max_neg_value(t):
+    return -torch.finfo(t.dtype).max
+
 def pack_one_with_inverse(t, pattern):
     packed, ps = pack([t], pattern)
     def inverse(out):
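The new `max_neg_value` helper factors out the `-torch.finfo(...).max` expression that was previously inlined at each masking site. A minimal standalone sketch of what it evaluates to, assuming only a plain PyTorch install and nothing else from this repo:

```python
# Standalone sketch of the helper added above; the printed values are the
# most negative finite numbers representable in each dtype.
import torch

def max_neg_value(t):
    return -torch.finfo(t.dtype).max

print(max_neg_value(torch.zeros(1, dtype = torch.float32)))  # ~ -3.4028e+38
print(max_neg_value(torch.zeros(1, dtype = torch.float16)))  # -65504.0
```

A common reason for preferring the dtype's most negative finite value over `float('-inf')` is that a row that ends up fully masked still softmaxes to finite (uniform) values rather than NaNs; whether that was the motivation here is not stated in the commit.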
@@ -142,7 +145,7 @@ def straight_through(t, target):
 def attend(
     q, k, v,
     mask = None,
-    return_attn = False,
+    return_sim = False,
     scale = None
 ):
     scale = default(scale, q.shape[-1] ** -0.5)
@@ -154,7 +157,7 @@ def attend(

     sim = einsum(q, k, 'b h qh i d, b h j d -> b h qh i j') * scale

-    mask_value = -torch.finfo(sim.dtype).max
+    mask_value = max_neg_value(sim)

     if exists(mask):
         sim = sim.masked_fill(~mask, mask_value)
@@ -165,12 +168,12 @@ def attend(

     attn_out = rearrange(attn_out, 'b h qh ... -> b (h qh) ...')

-    if not return_attn:
+    if not return_sim:
         return attn_out

-    attn = rearrange(attn, 'b h qh ... -> b (h qh) ...')
+    sim = rearrange(sim, 'b h qh ... -> b (h qh) ...')

-    return attn_out, attn
+    return attn_out, sim

 # classes

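The `return_attn` to `return_sim` rename changes what the second return value of `attend` carries: the pre-softmax similarity logits rather than the post-softmax attention weights, which is what the compressed pathway below consumes as importance scores. A simplified, self-contained sketch of the same control flow, with the grouped query-head (`qh`) dimension and the repo's einops helpers dropped for brevity; shapes and names here are illustrative only, not the library's API:

```python
# Simplified sketch: q, k, v assumed to be (batch, heads, seq, dim);
# the real attend also carries a grouped query-head dimension.
import torch

def attend_sketch(q, k, v, mask = None, return_sim = False, scale = None):
    scale = q.shape[-1] ** -0.5 if scale is None else scale

    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale

    mask_value = -torch.finfo(sim.dtype).max

    if mask is not None:
        sim = sim.masked_fill(~mask, mask_value)

    attn = sim.softmax(dim = -1)
    out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)

    if not return_sim:
        return out

    # return the raw (pre-softmax) similarities, matching the renamed flag above
    return out, sim
```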
@@ -360,17 +363,17 @@ def forward(

         cmask = einx.less('j, i -> i j', ck_seq, cq_seq)

-        compressed_attn_out, cattn = attend(cq, ck, cv, mask = cmask, return_attn = True)
+        compressed_attn_out, csim = attend(cq, ck, cv, mask = cmask, return_sim = True)

         # for 2. and 3., will give them relative positions with rotary - compressed needs to be handled separately (even if they already have intra block absolute positions)

         rotated_q, rotated_k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)

         # 2. fine attention over selected based on compressed attention logits - variables prepended with `f` stand for the fine attention pathway

-        importance_scores = cattn[..., num_mem_compress_kv:]
+        importance_scores = csim[..., num_mem_compress_kv:]

-        num_selected = min(self.num_selected_blocks, importance_scores.shape[-1])
+        num_selected = min(self.num_selected_blocks, num_compress_blocks)
         has_selected_kv_for_fine_attn = num_selected > 0

         # maybe average the compressed attention across the grouped queries (per key / value head)
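The comment above refers to optionally pooling the compressed-attention scores over the query heads that share a key/value head (grouped-query attention). That pooling code sits outside this hunk, so the snippet below is only a hedged illustration of the idea; the head counts and the `g` grouping factor are made up, not taken from the repo:

```python
# Hedged illustration: average importance scores across grouped query heads
# so each key/value head gets a single score per block.
import torch
from einops import reduce

kv_heads, grouped_queries = 4, 2                              # 8 query heads share 4 kv heads
scores = torch.randn(1, kv_heads * grouped_queries, 128, 32)  # (b, h, i, num_blocks)

scores = reduce(scores, 'b (h g) i j -> b h i j', 'mean', g = grouped_queries)
print(scores.shape)  # torch.Size([1, 4, 128, 32])
```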
@@ -386,32 +389,37 @@ def forward(
         # cannot parse their equation, so will just improvise
         # first we expand all the compressed scores to the full sequence length, then average within each fine / selection block size - pad on the right with 0s, which should be fine as the sliding window covers the local context anyway

-        if has_selected_kv_for_fine_attn and self.compress_block_size != self.selection_block_size:
+        if has_selected_kv_for_fine_attn:

-            score_len = importance_scores.shape[-1]
-            compress_seq_len = score_len * self.compress_block_size
+            if self.compress_block_size != self.selection_block_size:

-            if self.interpolated_importance_score:
-                importance_scores = interpolate_1d(importance_scores, compress_seq_len)
-            else:
-                importance_scores = repeat(importance_scores, '... j -> ... (j block_size)', block_size = self.compress_block_size)
+                compress_seq_len = num_compress_blocks * self.compress_block_size
+
+                if self.interpolated_importance_score:
+                    importance_scores = interpolate_1d(importance_scores, compress_seq_len)
+                else:
+                    importance_scores = repeat(importance_scores, '... j -> ... (j block_size)', block_size = self.compress_block_size)
+
+                padding = fine_divisible_seq_len - compress_seq_len

-            padding = fine_divisible_seq_len - compress_seq_len
+                fine_query_seq_len = importance_scores.shape[-2]
+                fine_query_padding = fine_divisible_seq_len - importance_scores.shape[-2]

-            fine_query_seq_len = importance_scores.shape[-2]
-            fine_query_padding = fine_divisible_seq_len - importance_scores.shape[-2]
+                importance_scores = F.pad(importance_scores, (0, padding))

-            importance_scores = F.pad(importance_scores, (0, padding))
+                # mask out the diagonal since block causal is included by default for fine attending

-            # mask out the diagonal since block causal is included by default for fine attending
+                block_causal_mask = torch.ones((num_fine_blocks,) * 2, device = device, dtype = torch.bool).tril(-1)
+                block_causal_mask = repeat(block_causal_mask, 'i j -> (i n1) (j n2)', n1 = self.selection_block_size, n2 = self.selection_block_size)
+                block_causal_mask = block_causal_mask[:fine_query_seq_len]

-            block_causal_mask = torch.ones((num_fine_blocks,) * 2, device = device, dtype = torch.bool).tril(-1)
-            block_causal_mask = repeat(block_causal_mask, 'i j -> (i n1) (j n2)', n1 = self.selection_block_size, n2 = self.selection_block_size)
-            block_causal_mask = block_causal_mask[:fine_query_seq_len]
+                importance_scores = importance_scores.masked_fill(~block_causal_mask, max_neg_value(csim))

-            importance_scores = importance_scores.masked_fill(~block_causal_mask, 0.)
+                importance_scores = reduce(importance_scores, '... (j block_size) -> ... j', 'mean', block_size = self.selection_block_size)

-            importance_scores = reduce(importance_scores, '... (j block_size) -> ... j', 'mean', block_size = self.selection_block_size)
+            importance_scores = F.pad(importance_scores, (1, 0), value = -1e3)
+            importance_scores = importance_scores.softmax(dim = -1)
+            importance_scores = importance_scores[..., 1:]

         # handle if number of total blocks is less than number to select for fine attention

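The block-causal mask is now filled with `max_neg_value(csim)` instead of `0.`, and three new lines normalize the scores: a dummy slot filled with `-1e3` is prepended, a softmax is taken over the block dimension, and the slot is sliced back off. Presumably the dummy slot acts as a sink so that rows whose blocks are all masked end up with near-zero importance everywhere instead of a uniform distribution, though the commit does not say so explicitly. A toy sketch of that behavior, with made-up values:

```python
# Toy sketch: pad a -1e3 "sink" slot, softmax over blocks, then drop the slot.
import torch
import torch.nn.functional as F

masked = -torch.finfo(torch.float32).max
scores = torch.tensor([
    [2.0, masked, 0.5, masked],         # row with two visible blocks
    [masked, masked, masked, masked],   # fully masked row
])

scores = F.pad(scores, (1, 0), value = -1e3)
probs = scores.softmax(dim = -1)[..., 1:]

print(probs[0])  # ~ [0.82, 0.00, 0.18, 0.00] - mass stays on the visible blocks
print(probs[1])  # ~ [0.00, 0.00, 0.00, 0.00] - the sink absorbs fully masked rows
```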
@@ -496,7 +504,7 @@ def forward(

         fsim = einsum(fq, fk, 'b h qh i d, b h i j d -> b h qh i j') * self.scale

-        mask_value = -torch.finfo(fsim.dtype).max
+        mask_value = max_neg_value(fsim)

         fsim = fsim.masked_fill(~fmask, mask_value)
