
# Please refer to README.md in the same folder for more information.

+import logging
+from collections import defaultdict, deque
from dataclasses import dataclass
from functools import partial
from typing import Dict, List, Optional, Tuple

from torch import nn

+logger = logging.getLogger(__name__)
+

def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
@@ -507,6 +511,24 @@ def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):


class InputManager:
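+    # Bounded FIFO cache of n-gram continuations; lookahead_decode keeps one
+    # cache per leading token and reads entries back as verification candidates.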
+    class NGramCache:
+        def __init__(self, max_size: int):
+            self.cache = deque()
+            self.max_size = max_size
+
+        def add(self, ngram: List[int]):
+            if ngram in self.cache:
+                return
+            if len(self.cache) == self.max_size:
+                self.cache.popleft()
+            self.cache.append(ngram)
+
+        def __iter__(self):
+            return iter(self.cache)
+
+        def __str__(self):
+            return str(self.cache)
+
    def __init__(
        self,
        n_layers: int,
@@ -519,6 +541,7 @@ def __init__(
        dtype=torch.float16,
        minus_infinity=-torch.inf,
        cache_size=None,
+        lookahead_enabled: bool = False,
    ):
        if cache_size is None:
            cache_size = max_seq_length - seq_length
@@ -532,6 +555,8 @@ def __init__(

        self.seq_length = seq_length
        self.use_cache_list = use_cache_list
+        self.lookahead_enabled = lookahead_enabled
+        self.minus_infinity = minus_infinity

        if self.use_cache_list:
            self.k_caches = [
@@ -609,10 +634,10 @@ def _update_cache(self, start, length, new_k_caches, new_v_caches):
        if self.cache_pos == self.cache_size:
            self.cache_pos = 0

-    def update(self, input_length, new_k_caches, new_v_caches):
+    def update(self, input_length, new_k_caches, new_v_caches, update_pos=0):
        # Copy as much new cache data into cache as possible without wrapping
        amount_to_copy = min(input_length, self.cache_size - self.cache_pos)
-        self._update_cache(0, amount_to_copy, new_k_caches, new_v_caches)
+        self._update_cache(update_pos, amount_to_copy, new_k_caches, new_v_caches)
        if self.input_pos <= self.cache_size:
            self.attn_mask[:, (self.input_pos) : (self.input_pos + amount_to_copy)] = (
                0.0
@@ -625,7 +650,7 @@ def update(self, input_length, new_k_caches, new_v_caches):
            )
        if remaining_to_copy > 0:
            self._update_cache(
-                amount_to_copy, remaining_to_copy, new_k_caches, new_v_caches
+                update_pos + amount_to_copy, remaining_to_copy, new_k_caches, new_v_caches
            )

        self.input_pos += input_length
@@ -661,3 +686,192 @@ def get_inputs_and_remaining_tokens(self, tokens: List[int]):
            self.get_inputs(tokens[0:processed_tokens]),
            tokens[processed_tokens:],
        )
+
+    def _get_lookahead_decoding_mask(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> torch.Tensor:
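+        # The seq_length x seq_length mask covers the new-token region: the
+        # first (ngram_size - 1) * window_size slots are the lookahead branches
+        # (slot 0 doubles as the current token), followed by n_verifications
+        # verification branches of (ngram_size - 1) tokens each.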
+        mask = torch.full((self.seq_length, self.seq_length), self.minus_infinity)
+        mask[0][0] = 0.0
+
+        lookahead_submask = torch.triu(
+            torch.full((window_size, window_size), self.minus_infinity),
+            diagonal=1,
+        )
+        for i in range(ngram_size - 1):
+            offset = window_size * i
+            mask[offset : offset + window_size, :window_size] = lookahead_submask
+            for j in range(1, i + 1):
+                mask[
+                    offset : offset + window_size,
+                    window_size * j : window_size * (j + 1),
+                ].fill_diagonal_(0.0)
+
+        verification_offset = max(window_size * (ngram_size - 1), 1)
+        verification_submask = torch.triu(
+            torch.full((ngram_size - 1, ngram_size - 1), self.minus_infinity),
+            diagonal=1,
+        )
+        for i in range(n_verifications):
+            mask[
+                verification_offset + i * (ngram_size - 1) : verification_offset
+                + (i + 1) * (ngram_size - 1),
+                verification_offset + i * (ngram_size - 1) : verification_offset
+                + (i + 1) * (ngram_size - 1),
+            ] = verification_submask
+        mask[verification_offset:, :1] = 0.0
+
+        return mask
+
+    def _get_lookahead_position_offsets(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> torch.Tensor:
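+        # RoPE offsets relative to input_pos: lookahead level i, window slot j
+        # sits at offset i + j; each verification branch runs 1..ngram_size - 1
+        # after the current token at offset 0.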
+        pos_offsets = torch.zeros(self.seq_length, dtype=torch.int32)
+        idx = 0
+        if window_size > 0:
+            for i in range(ngram_size - 1):
+                for j in range(window_size):
+                    pos_offsets[idx] = i + j
+                    idx += 1
+        else:
+            pos_offsets[0] = 0
+            idx += 1
+
+        # Verification branches: [1, 2, ..., ngram_size - 1].
+        for _ in range(n_verifications):
+            for j in range(1, ngram_size):
+                pos_offsets[idx] = j
+                idx += 1
+
+        return pos_offsets
+
+    def lookahead_decode(
+        self,
+        model,
+        init_token: int,
+        n: int,
+        ngram_size: int,
+        window_size: int,
+        n_verifications: int,
+        stop_tokens: Optional[List[int]] = None,
+        ngram_caches: Optional[Dict[int, "InputManager.NGramCache"]] = None,
+    ) -> List[int]:
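+        # Lookahead (n-gram speculative) decoding: each forward pass both
+        # advances a window of parallel guesses and checks previously cached
+        # n-grams, so several tokens can be accepted per inference.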
+        if not self.lookahead_enabled:
+            raise RuntimeError("Lookahead decoding is not enabled")
+
+        if (ngram_size - 1) * (window_size + n_verifications) > self.seq_length:
+            raise RuntimeError(
+                f"Lookahead decoding configuration not compatible with seq_length {self.seq_length}. "
+                f"Required: {(ngram_size - 1) * (window_size + n_verifications)}"
+            )
+
+        self.attn_mask[:, self.cache_size :] = self._get_lookahead_decoding_mask(
+            ngram_size, window_size, n_verifications
+        )
+        logger.debug("Lookahead decoding mask: ")
+        for i in range(self.seq_length):
+            logger.debug(
+                " ".join(
+                    ("X" if x == 0.0 else " ")
+                    for x in self.attn_mask[i][self.cache_size :]
+                )
+            )
+
+        offsets = self._get_lookahead_position_offsets(
+            ngram_size, window_size, n_verifications
+        )
+
+        stop_tokens = stop_tokens or []
+        verification_offset = window_size * (ngram_size - 1)
+
+        if ngram_caches is None:
+            ngram_caches = defaultdict(lambda: InputManager.NGramCache(n_verifications))
+        new_tokens = [init_token]
+        x = [init_token] * self.seq_length
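+        # x is the model input buffer: slot 0 holds the last accepted token
+        # (also the first lookahead slot), then the rest of the lookahead
+        # window, then the verification-branch slots.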
+        inference_count = 0
+
+        while len(new_tokens) < n + 1:
+            cache = ngram_caches[x[0]]
+            for i, ngram in enumerate(cache):
+                for j, token in enumerate(ngram):
+                    x[verification_offset + i * (ngram_size - 1) + j] = token
+
+            logits, new_k, new_v = model(
+                tokens=torch.tensor([x], dtype=torch.int64),
+                input_pos=torch.tensor([self.input_pos], dtype=torch.long),
+                k_caches=self.k_caches,
+                v_caches=self.v_caches,
+                attn_mask=self.attn_mask,
+                input_len=torch.tensor([len(x)], dtype=torch.long),
+                rope_indices=self.input_pos + offsets,
+            )
+            inference_count += 1
+
+            # Greedy only
+            y = logits[0].argmax(dim=-1).tolist()
+            new_tokens.append(y[0])
+            logger.debug(f"{self.input_pos}: x = {x[0]}, y = {y[0]}")
+            if new_tokens[-1] in stop_tokens:
+                break
+            # Collect new n-grams.
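+            # Each window column i, read down the lookahead levels plus the new
+            # prediction, yields an (ngram_size - 1)-token suffix for key x[i].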
+            for i in range(window_size):
+                key = x[i]
+                suffix = []
+                for j in range(1, ngram_size - 1):
+                    suffix.append(x[i + j * window_size])
+                suffix.append(y[i + window_size * (ngram_size - 2)])
+                ngram_caches[key].add(suffix)
+
+            # Verification.
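+            # Compare each cached candidate branch against the model's greedy
+            # predictions and keep the longest agreeing run.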
+            longest_match = []
+            matched_branch = None
+            for i in range(n_verifications):
+                match = [y[0]]
+                j = 0
+                while (
+                    j < ngram_size - 1
+                    and x[verification_offset + (ngram_size - 1) * i + j] == match[-1]
+                ):
+                    match.append(y[verification_offset + (ngram_size - 1) * i + j])
+                    j += 1
+                if len(match) - 1 > len(longest_match):
+                    longest_match = match[1:]
+                    matched_branch = i
+
+            if matched_branch is not None:
+                logger.debug(
+                    f"Matched {len(longest_match)} additional tokens from n-grams: {longest_match}"
+                )
+                for stop in stop_tokens:
+                    if stop in longest_match:
+                        longest_match = longest_match[: longest_match.index(stop) + 1]
+
+                new_tokens.extend(longest_match)
+                branch_offset = verification_offset + (ngram_size - 1) * matched_branch
+                self.update(
+                    input_length=len(longest_match),
+                    new_k_caches=new_k,
+                    new_v_caches=new_v,
+                    update_pos=branch_offset,
+                )
+            else:
+                self.update(input_length=1, new_k_caches=new_k, new_v_caches=new_v)
+
+            # Update lookahead branch.
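+            # Shift each lookahead level up by one and refill the last level
+            # with this pass's greedy predictions.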
+            for i in range(ngram_size - 2):
+                for j in range(window_size):
+                    x[window_size * i + j] = x[window_size * (i + 1) + j]
+            for j in range(window_size):
+                x[window_size * (ngram_size - 2) + j] = y[
+                    window_size * (ngram_size - 2) + j
+                ]
+
+            x[0] = new_tokens[-1]
+            if new_tokens[-1] in stop_tokens:
+                break
+
+        logger.info(
+            f"Generated {len(new_tokens) - 1} tokens with {inference_count} inference(s)."
+        )
+        return new_tokens
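
For context, a minimal sketch of how the new API might be driven; `manager`, `model`, `prompt_tokens`, and `eos_token_id` are assumed to exist, and the parameter values are illustrative rather than part of this diff:

# Sketch only: assumes `manager` is an InputManager created with
# lookahead_enabled=True and seq_length >= (ngram_size - 1) *
# (window_size + n_verifications), and that the prompt has already been
# prefilled through manager.update().
generated = manager.lookahead_decode(
    model,                         # assumed callable with the signature used above
    init_token=prompt_tokens[-1],  # last prompt token seeds decoding
    n=128,                         # generate at most 128 new tokens
    ngram_size=3,
    window_size=8,
    n_verifications=8,
    stop_tokens=[eos_token_id],    # assumed tokenizer EOS id
)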