@@ -315,13 +315,121 @@ def __init__(
 
         self.combine_heads = nn.Linear(dim_inner, dim, bias = False)
 
+    def forward_inference(
+        self,
+        inp,
+        cache,
+        return_cache = True
+    ):
+        # destruct cache
+
+        (cache_k, cache_v), (cache_ck, cache_cv) = cache
+
+        # variables
+
+        batch, scale, heads, device = inp.shape[0], self.scale, self.heads, inp.device
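+        # total sequence length is the cached length plus the single incoming token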
+        seq_len = cache_k.shape[-2] + 1
+
+        sliding_window = self.sliding_window_size
+        compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
+        num_compress_blocks = compress_divisible_seq_len // self.compress_block_size
+
+        fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
+        num_fine_blocks = fine_divisible_seq_len // self.selection_block_size
+
+        # maybe prenorm
+
+        inp = self.norm(inp)
+
+        # queries, keys, values
+
+        q, k, v = self.to_qkv(inp).split(self.qkv_split, dim = -1)
+
+        q, k, v = map(self.split_heads, (q, k, v))
+
+        # handle cache
+
+        k = cat((cache_k, k), dim = -2)
+        v = cat((cache_v, v), dim = -2)
+
+        if return_cache:
+            cache_kv = (k, v)
+
+        # 1. compressed attn inference
+
+        cq = q
+        ck = cache_ck
+        cv = cache_cv
+
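+        # a new compressed key / value pair is emitted only when the sequence length lands exactly on a compress block boundary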
+        if divisible_by(seq_len, self.compress_block_size):
+            k_compress_input = self.split_compress_window(k[..., -self.compress_block_size:, :] + self.k_intrablock_positions)
+            v_compress_input = self.split_compress_window(v[..., -self.compress_block_size:, :] + self.v_intrablock_positions)
+
+            next_ck = self.k_compress(k_compress_input)
+            next_cv = self.v_compress(v_compress_input)
+
+            ck = cat((ck, next_ck), dim = -2)
+            cv = cat((cv, next_cv), dim = -2)
+
+        if return_cache:
+            cache_compressed_kv = (ck, cv)
+
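+        # expand compressed key / values across the grouped query heads (gqa)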
+        ck = repeat(ck, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+        cv = repeat(cv, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+
+        csim = einsum(q, ck, 'b h i d, b h j d -> b h i j') * scale
+        cattn = csim.softmax(dim = -1)
+
+        compressed_attn_out = einsum(cattn, cv, 'b h i j, b h j d -> b h i d')
+
+        # 2. fine attention inference (todo)
+
+        # not implemented
+
+        # 3. sliding window
+
+        k = repeat(k, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+        v = repeat(v, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)
+
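+        # restrict sliding window attention to the most recent (sliding_window + 1) cached tokens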
+        sliding_slice = (Ellipsis, slice(-(sliding_window + 1), None), slice(None))
+        rotated_q, rotated_k = self.rotary_emb.rotate_queries_with_cached_keys(q, k[sliding_slice])
+
+        sim = einsum(rotated_q, rotated_k, 'b h i d, b h j d -> b h i j') * scale
+        attn = sim.softmax(dim = -1)
+        sliding_window_attn_out = einsum(attn, v[sliding_slice], 'b h i j, b h j d -> b h i d')
+
+        # combine strategies
+
+        strategy_weighted_combine = self.to_strategy_combine(inp)
+
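+        # the fine attention slot reuses the sliding window output for now, since fine inference is still todo above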
+        out = einsum(strategy_weighted_combine, stack([compressed_attn_out, sliding_window_attn_out, sliding_window_attn_out]), 'b h n s, s b h n d -> b h n d')
+
+        # merge heads and combine them
+
+        out = self.merge_heads(out)
+
+        out = self.combine_heads(out)
+
+        if not return_cache:
+            return out
+
+        return out, (cache_kv, cache_compressed_kv)
+
     def forward(
         self,
         inp,
+        cache = None,
         disable_triton_kernel = False,
         sliding_window_flex_mask = None,
-        fine_selection_flex_mask = None
+        fine_selection_flex_mask = None,
+        return_cache = False
     ):
+        is_inferencing = exists(cache)
+
+        if is_inferencing:
+            assert inp.shape[1] == 1, 'input must be single tokens if inferencing with cache key values'
+            return self.forward_inference(inp, cache, return_cache = return_cache)
+
         batch, seq_len, scale, heads, device = *inp.shape[:2], self.scale, self.heads, inp.device
 
         compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
@@ -340,6 +448,11 @@ def forward(
 
         q, k, v = map(self.split_heads, (q, k, v))
 
+        # handle cache
+
+        if return_cache:
+            cache_kv = (k, v)
+
         # compressed key / values - variables prepended with `c` stands for compressed
 
         k_pos = repeat(self.k_intrablock_positions, 'h n d -> h (r n) d', r = num_compress_blocks)
@@ -352,6 +465,9 @@ def forward(
         ck = self.k_compress(k_compress_input) # Equation (7) of the Native Sparse Attention paper
         cv = self.v_compress(v_compress_input)
 
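+        # also cache the compressed key / values so forward_inference does not recompress past blocks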
+        if return_cache:
+            cache_compressed_kv = (ck, cv)
+
         # 1. coarse attention over compressed
         mem_ck, mem_cv = repeat(self.compress_mem_kv, 'kv ... -> kv b ...', b = batch)
 
@@ -570,4 +686,9 @@ def forward(
 
         out = self.merge_heads(out)
 
-        return self.combine_heads(out)
+        out = self.combine_heads(out)
+
+        if not return_cache:
+            return out
+
+        return out, (cache_kv, cache_compressed_kv)