@@ -252,6 +252,7 @@ def __init__(

         self.split_compress_window = Rearrange('b h (w n) d -> b h w n d', n = compress_block_size)

+        self.num_mem_compress_kv = num_compressed_mem_kv
         self.compress_mem_kv = nn.Parameter(torch.zeros(2, kv_heads, num_compressed_mem_kv, dim_head))

         self.k_intrablock_positions = nn.Parameter(torch.zeros(kv_heads, compress_block_size, dim_head))
@@ -332,7 +333,6 @@ def forward_inference(

         sliding_window = self.sliding_window_size
         compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
-        num_compress_blocks = compress_divisible_seq_len // self.compress_block_size

         fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
         num_fine_blocks = fine_divisible_seq_len // self.selection_block_size
@@ -361,6 +361,14 @@ def forward_inference(
         ck = cache_ck
         cv = cache_cv

+        repeated_ck = repeat(ck, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+        repeated_cv = repeat(cv, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+
+        csim = einsum(q, repeated_ck, 'b h i d, b h j d -> b h i j') * scale
+        cattn = csim.softmax(dim = -1)
+
+        compressed_attn_out = einsum(cattn, repeated_cv, 'b h i j, b h j d -> b h i d')
+
         if divisible_by(seq_len, self.compress_block_size):
             k_compress_input = self.split_compress_window(k[..., -self.compress_block_size:, :] + self.k_intrablock_positions)
             v_compress_input = self.split_compress_window(v[..., -self.compress_block_size:, :] + self.v_intrablock_positions)
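(Note on the hunk above: compressed attention is now computed before the fine and sliding-window branches so that its pre-softmax logits, csim, can be reused further down as block importance scores. A minimal standalone sketch of this grouped-query expansion and compressed attention step follows; the tensor sizes and the num_grouped_queries value are illustrative assumptions, not values from the repository.)

import torch
from einops import repeat, einsum

# illustrative sizes (assumptions)
batch, kv_heads, num_grouped_queries, dim_head = 1, 2, 4, 64
heads = kv_heads * num_grouped_queries
num_compressed_kv = 7                      # memory kv + compressed blocks so far
scale = dim_head ** -0.5

q  = torch.randn(batch, heads, 1, dim_head)                      # single decoding step
ck = torch.randn(batch, kv_heads, num_compressed_kv, dim_head)
cv = torch.randn(batch, kv_heads, num_compressed_kv, dim_head)

# expand the grouped kv heads so they line up one-to-one with the query heads
repeated_ck = repeat(ck, 'b h ... -> b (h gh) ...', gh = num_grouped_queries)
repeated_cv = repeat(cv, 'b h ... -> b (h gh) ...', gh = num_grouped_queries)

csim = einsum(q, repeated_ck, 'b h i d, b h j d -> b h i j') * scale
cattn = csim.softmax(dim = -1)
compressed_attn_out = einsum(cattn, repeated_cv, 'b h i j, b h j d -> b h i d')

assert compressed_attn_out.shape == (batch, heads, 1, dim_head)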
@@ -374,17 +382,64 @@ def forward_inference(
         if return_cache:
             cache_compressed_kv = (ck, cv)

-        ck = repeat(ck, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
-        cv = repeat(cv, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+        # 2. fine attention inference (todo - compress and fine diff block sizes)

-        csim = einsum(q, ck, 'b h i d, b h j d -> b h i j') * scale
-        cattn = csim.softmax(dim = -1)
+        assert self.compress_block_size == self.selection_block_size
+
+        importance_scores = csim[..., self.num_mem_compress_kv:]
+        importance_scores += torch.randn_like(importance_scores) * 100
+
+        num_compress_blocks = importance_scores.shape[-1]
+        num_selected = min(self.num_selected_blocks, num_compress_blocks)
+        has_selected_kv_for_fine_attn = num_selected > 0
+
+        # block causal diagonal
+
+        fine_sliding_window = (seq_len % self.selection_block_size) + 1
+        fk = k[..., -fine_sliding_window:, :]
+        fv = v[..., -fine_sliding_window:, :]
+
+        # select out the sparse kv segments as defined by compressed attention map as importance score
+
+        if has_selected_kv_for_fine_attn:
+            if self.query_heads_share_selected_kv:
+                importance_scores = reduce(importance_scores, 'b (h grouped_queries) ... -> b h ...', 'mean', grouped_queries = self.num_grouped_queries)
+
+            sel_scores, sel_indices = importance_scores.topk(num_selected, dim = -1)
+
+            fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
+            remainder = fine_divisible_seq_len - k.shape[-2]
+
+            sel_fk = pad_at_dim(k, (0, remainder), dim = -2)
+            sel_fv = pad_at_dim(v, (0, remainder), dim = -2)
+
+            sel_fk = rearrange(sel_fk, 'b h (w j) d -> b h w j d', j = self.selection_block_size)
+            sel_fv = rearrange(sel_fv, 'b h (w j) d -> b h w j d', j = self.selection_block_size)
+
+            sel_fk = einx.get_at('b h [w] j d, b h 1 sel -> b h (sel j) d', sel_fk, sel_indices)
+            sel_fv = einx.get_at('b h [w] j d, b h 1 sel -> b h (sel j) d', sel_fv, sel_indices)
+
+            fmask = sel_scores > 1e-10
+
+            fmask = repeat(fmask, 'b h i sel -> b h i (sel j)', j = self.selection_block_size)
+
+            fk = cat((sel_fk, fk), dim = -2)
+            fv = cat((sel_fv, fv), dim = -2)
+
+            fmask = F.pad(fmask, (0, fk.shape[-2] - fmask.shape[-1]), value = True)
+
+        # remove later
+
+        fq = rearrange(q, 'b (h gh) ... -> b h gh ...', gh = self.num_grouped_queries)
+
+        fsim = einsum(fq, fk, 'b h gh i d, b h j d -> b h gh i j') * scale

-        compressed_attn_out = einsum(cattn, cv, 'b h i j, b h j d -> b h i d')
+        fsim = einx.where('b h i j, b h gh i j, -> b h gh i j', fmask, fsim, max_neg_value(fsim))

-        # 2. fine attention inference (todo)
+        fattn = fsim.softmax(dim = -1)

-        # not implemented
+        fine_attn_out = einsum(fattn, fv, 'b h gh i j, b h j d -> b h gh i d')
+        fine_attn_out = rearrange(fine_attn_out, 'b h gh ... -> b (h gh) ...')

         # 3. sliding window

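(The block selection above uses einx.get_at to gather whole key/value blocks by their top-k indices. For readers unfamiliar with that pattern, here is a rough standalone equivalent of the gather step written with plain torch.topk and torch.gather; the sizes are illustrative assumptions, and it covers only the single-query decoding case, matching the 'b h 1 sel' index shape in the hunk.)

import torch

# illustrative sizes (assumptions)
batch, heads, num_blocks, block_size, dim_head = 1, 4, 6, 16, 32
num_selected = 2

importance_scores = torch.randn(batch, heads, 1, num_blocks)           # e.g. csim sliced past the memory kv
k_blocks = torch.randn(batch, heads, num_blocks, block_size, dim_head) # keys grouped into selection blocks

# pick the highest scoring blocks per (batch, head) for the single query position
sel_scores, sel_indices = importance_scores.topk(num_selected, dim = -1)   # (b, h, 1, sel)

# gather those blocks along the block axis, then flatten back to a token axis: 'b h (sel j) d'
gather_idx = sel_indices[..., 0, :, None, None].expand(-1, -1, -1, block_size, dim_head)
sel_k = k_blocks.gather(2, gather_idx)      # (b, h, sel, j, d)
sel_k = sel_k.flatten(2, 3)                 # (b, h, sel * j, d)

assert sel_k.shape == (batch, heads, num_selected * block_size, dim_head)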
@@ -402,7 +457,7 @@ def forward_inference(

         strategy_weighted_combine = self.to_strategy_combine(inp)

-        out = einsum(strategy_weighted_combine, stack([compressed_attn_out, sliding_window_attn_out, sliding_window_attn_out]), 'b h n s, s b h n d -> b h n d')
+        out = einsum(strategy_weighted_combine, stack([compressed_attn_out, fine_attn_out, sliding_window_attn_out]), 'b h n s, s b h n d -> b h n d')

         # merge heads and combine them

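(With fine attention now implemented, the final gate combines all three branches. A small standalone sketch of this strategy-weighted combine is below; the shapes and the uniform random stand-in for self.to_strategy_combine(inp) are illustrative assumptions.)

import torch
from torch import stack
from einops import einsum

# illustrative shapes (assumptions)
batch, heads, seq_len, dim_head = 1, 8, 1, 64

compressed_attn_out     = torch.randn(batch, heads, seq_len, dim_head)
fine_attn_out           = torch.randn(batch, heads, seq_len, dim_head)
sliding_window_attn_out = torch.randn(batch, heads, seq_len, dim_head)

# stand-in for the learned gate: one weight per (head, position, strategy)
strategy_weighted_combine = torch.rand(batch, heads, seq_len, 3)

out = einsum(
    strategy_weighted_combine,
    stack([compressed_attn_out, fine_attn_out, sliding_window_attn_out]),
    'b h n s, s b h n d -> b h n d'
)

assert out.shape == (batch, heads, seq_len, dim_head)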