@@ -90,6 +90,43 @@ def is_final_state(self, state: Any) -> bool:
     def copy(self) -> "Guide":
         ...
 
+    def accepts(self, token_ids: List[int], state=None) -> bool:
+        """
+        Determine whether the sequence `token_ids` is accepted by the Guide.
+        `token_ids` doesn't need to complete the guide to be accepted.
+        """
+        try:
+            self.derive(token_ids, state)
+            return True
+        except ValueError:
+            return False
+
+    def derive(self, token_ids: List[int], state=None) -> Any:
+        """
+        Advance the guide through `token_ids`, starting from `state` (or the initial state if None), and return the resulting state; raise ValueError if a token is not allowed.
+        """
+        if state is None:
+            state = self.initial_state
+        for token_id in token_ids:
+            instruction = self.get_next_instruction(state)
+
+            # determine if token_id is allowed by the instruction
+            if isinstance(instruction, Write):
+                raise NotImplementedError("TODO")
+            elif isinstance(instruction, Generate):
+                if (
+                    instruction.tokens is not None
+                    and token_id not in instruction.tokens
+                ):
+                    raise ValueError("Cannot advance state with provided token_ids")
+            else:
+                raise TypeError(f"Expected instruction, got {instruction}")
+
+            # advance state
+            state = self.get_next_state(state, token_id)
+
+        return state
+
 
 class StopAtEOSGuide(Guide):
     """Guide to generate tokens until the EOS token has been generated."""
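
The `accepts`/`derive` contract added above can be exercised with any concrete `Guide`. Below is a minimal sketch using a toy guide written only for illustration (the `FixedSequenceGuide` class and the `outlines.fsm.guide` import path are assumptions, not part of this patch): the guide only allows the fixed token sequence 1, 2, 3, so `accepts` approves legal prefixes and `derive` returns the advanced state.

from outlines.fsm.guide import Generate, Guide  # assumed import path for this module


class FixedSequenceGuide(Guide):
    """Toy guide that only allows the token sequence 1 -> 2 -> 3."""

    initial_state = 0
    sequence = (1, 2, 3)

    def get_next_instruction(self, state):
        # the only legal token is the next one in the fixed sequence
        return Generate([self.sequence[state]])

    def get_next_state(self, state, token_id):
        return state + 1

    def is_final_state(self, state):
        return state == len(self.sequence)

    def copy(self):
        return self


guide = FixedSequenceGuide()
assert guide.accepts([1, 2])         # a legal, incomplete prefix is accepted
assert not guide.accepts([2, 3])     # an out-of-order sequence is rejected
assert guide.derive([1, 2, 3]) == 3  # derive() returns the state reached after the tokens
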
@@ -487,3 +524,214 @@ def must_terminate_state(self, state: CFGState) -> bool:
     def copy(self) -> "CFGGuide":
         """Create a copy of the Guide."""
         return CFGGuide(self.cfg_string, self.tokenizer)
+
+
+@cache()
+def build_vocab_prefix_map(tokenizer: "Tokenizer") -> Dict[str, Set[Tuple[str, Tuple]]]:
+    """Build a map from token prefix to Set[Tuple[suffix, suffix_token_ids]]"""
+
+    # precompute the token ids of all vocab suffixes
+    suffixes = list(
+        {tok[i:] for tok in tokenizer.vocabulary for i in range(1, len(tok))}
+    )
+    encoded_suffixes, _ = tokenizer.encode(suffixes)
+    encoded_suffixes = [
+        [tok for tok in seq_ids if tok != tokenizer.pad_token_id]
+        for seq_ids in encoded_suffixes.tolist()
+    ]
+    suffix_map = dict(zip(suffixes, map(tuple, encoded_suffixes)))
+    suffix_map[""] = tuple()
+
+    # compute prefix-suffix map for all tokens, s.t. prefix + suffix = token
+    prefix_map = collections.defaultdict(set)
+    for token, token_id in tokenizer.vocabulary.items():
+        for i in range(1, len(token) + 1):
+            prefix_map[token[:i]].add((token[i:], suffix_map[token[i:]]))
+    return prefix_map
+
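
The docstring above describes the shape of the map; for intuition, here is a hand-written illustration of a few entries (the vocabulary and token ids are hypothetical and not produced by a real tokenizer):

# Hypothetical entries for a vocabulary that contains the tokens " world" and "rld";
# entries contributed by other vocabulary tokens are omitted.
prefix_map = {
    " wo": {("rld", (81,))},   # " wo" + "rld" = " world"; "rld" hypothetically encodes to (81,)
    " wor": {("ld", (392,))},  # the id 392 for "ld" is likewise made up for illustration
    " world": {("", ())},      # a whole vocabulary token needs no suffix
}
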
+
+AlignmentGuideState = collections.namedtuple(
+    "AlignmentGuideState", ["legal_path_map", "child_guide_state"]
+)
+
+
+class AlignmentGuide(Guide):
+    def __init__(
+        self, prompt: str, tokenizer: "Tokenizer", child_guide: Optional[Guide] = None
+    ):
+        """
+        Initialize the AlignmentGuide with a prompt, tokenizer, and an optional child guide.
+
+        Parameters
+        ----------
+        prompt : str
+            The prompt text to be aligned with the generated tokens.
+        tokenizer : Tokenizer
+            Tokenizer used to align the prompt.
+        child_guide : Guide, optional
+            A guide which takes control once alignment is complete. If None, generation is unconstrained after alignment.
+        """
+        self.prompt = prompt
+        self.tokenizer = tokenizer
+        self.child_guide = child_guide
+
+        alignment_seqs, child_guide_ids = self._get_alignment_sequences(
+            prompt, tokenizer, child_guide
+        )
+        alignment_prompt_ids, common_prompt_len = self._get_longest_common_prompt_ids(
+            alignment_seqs
+        )
+
+        self.alignment_prompt = self.tokenizer.decode(
+            [alignment_seqs[0, :common_prompt_len]]
+        )[0]
+
+        # calculate map of alignment_prompt continuation tokens -> child_guide advancement tokens
+        legal_paths = [
+            tuple([t for t in seq if t != tokenizer.pad_token_id])
+            for seq in alignment_seqs[:, common_prompt_len:].tolist()
+        ]
+        legal_path_map = dict(zip(legal_paths, child_guide_ids))
+
+        self.initial_state = AlignmentGuideState(
+            legal_path_map=legal_path_map, child_guide_state=None
+        )
+
+    @staticmethod
+    def _get_alignment_sequences(
+        prompt: str, tokenizer: "Tokenizer", child_guide: Optional[Guide] = None
+    ):
+        """
+        Calculate all possible sequences which are valid with a prompt + child_guide.
+        E.g. prompt="hello wo", child guide accepts "rld" -> tokenization ["hello", " world"] is valid
+
+        Returns a tuple of (alignment_seqs, child_guide_ids) of the same length
+        - alignment_seqs:
+            All token sequences which can represent `prompt` + start of generation. The last token
+            must represent the end of the prompt and can extend beyond the prompt to start generation.
+            Sequences are only included if the start-of-generation portion is legal with the child guide.
+        - child_guide_ids:
+            Tokens to send to the child guide to simulate the start of generation. In the example above
+            " world" is the last alignment seq token, therefore we must advance the state of the child
+            guide with the tokenization of "rld" in order to continue generation with the child guide.
+        """
+        guide_accepts: Dict[
+            Tuple[int, ...], bool
+        ] = {}  # cache of suffix acceptance for child_guide.accepts()
+
+        # prompts with alignment tokens at end
+        aligned_prompt_completions: List[str] = []
+        # tokens to feed child guide once alignment completes
+        child_guide_ids: List[Tuple] = []
+
+        # compute alignment seqs which are valid with prompt and child guide
+        for prefix, alignment_details in build_vocab_prefix_map(tokenizer).items():
+            if prompt.endswith(prefix):
+                for suffix, suffix_ids in alignment_details:
+                    if child_guide is None:
+                        aligned_prompt_completions.append(prompt + suffix)
+                        child_guide_ids.append(tuple())
+                    elif guide_accepts.setdefault(
+                        suffix_ids, child_guide.accepts(suffix_ids)
+                    ):
+                        aligned_prompt_completions.append(prompt + suffix)
+                        child_guide_ids.append(suffix_ids)
+
+        alignment_seqs, _ = tokenizer.encode(aligned_prompt_completions)
+        return alignment_seqs, child_guide_ids
+
+    @staticmethod
+    def _get_longest_common_prompt_ids(alignment_seqs):
+        """
+        Among all candidate prompt alignment seqs, get the longest shared prefix and its length
+        """
+        # get longest common prefix among alignment sequences, which will form our alignment prompt
+        common = (
+            (alignment_seqs.unsqueeze(1) == alignment_seqs.unsqueeze(0))
+            .all(0)
+            .cumprod(1)
+        )
+        common_len = common.sum(1).max().item()
+        return alignment_seqs[0, :common_len], common_len
+
+    def get_next_instruction(self, state: AlignmentGuideState) -> Instruction:
+        """
+        Return the next set of valid tokens for generation based on the current state.
+
+        If alignment hasn't completed:
+            tokens which continue one of the candidate alignment paths are legal
+        If alignment has completed:
+            get the instruction from the child guide
+        """
+        if state.legal_path_map is not None:
+            return Generate(
+                sorted({token_ids[0] for token_ids in state.legal_path_map.keys()})
+            )
+        elif self.child_guide is None:
+            return Generate(None)
+        else:
+            return self.child_guide.get_next_instruction(state.child_guide_state)
+
+    def get_next_state(
+        self, state: AlignmentGuideState, token_id: int
+    ) -> AlignmentGuideState:
+        """
+        Get the AlignmentGuideState advanced by a token ID.
+
+        If alignment has completed:
+            advance the child guide's state
+        If alignment hasn't completed:
+            Filter out alignment paths which don't start with token_id
+            Remove the first token from the remaining paths
+            If advancing the state completes alignment:
+                Advance the child_guide state
+        """
+        if state.legal_path_map is None:
+            if self.child_guide is not None:
+                return AlignmentGuideState(
+                    legal_path_map=None,
+                    child_guide_state=self.child_guide.get_next_state(
+                        state.child_guide_state, token_id
+                    ),
+                )
+            else:
+                return AlignmentGuideState(None, None)
+        else:
+            next_state_legal_path_map = {
+                key[1:]: value
+                for key, value in state.legal_path_map.items()
+                if key[0] == token_id
+            }
+            # if no tokens remain in any path, alignment is complete: advance the child guide
+            if not any(next_state_legal_path_map):
+                if self.child_guide is not None:
+                    child_guide_advancement_ids = next(
+                        iter(next_state_legal_path_map.values())
+                    )
+                    return AlignmentGuideState(
+                        legal_path_map=None,
+                        child_guide_state=self.child_guide.derive(
+                            child_guide_advancement_ids, state.child_guide_state
+                        ),
+                    )
+                else:
+                    return AlignmentGuideState(None, None)
+
+            # if paths remain, return the advanced legal_path_map
+            else:
+                return AlignmentGuideState(
+                    legal_path_map=next_state_legal_path_map,
+                    child_guide_state=state.child_guide_state,
+                )
+
+    def is_final_state(self, state: AlignmentGuideState) -> bool:
+        if state.legal_path_map is not None:
+            return False
+        elif self.child_guide is None:
+            return True
+        else:
+            return self.child_guide.is_final_state(state.child_guide_state)
+
+    def copy(self):
+        """AlignmentGuide is never mutated, so return self."""
+        return self
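
Tying the pieces together, the sketch below walks through the alignment flow for the docstring's "hello wo" example. The token ids, the single candidate path, and the import path are assumptions made for illustration, not part of this patch.

from outlines.fsm.guide import AlignmentGuideState  # assumed import path

# prompt = "hello wo", child guide accepts text starting with "rld".
# _get_alignment_sequences finds (among others) the tokenization ["hello", " world"],
# because " world" both completes the prompt and overhangs it with "rld", which the
# child guide accepts. With hypothetical ids "hello" = 31373, " world" = 995 and
# "rld" -> (81,), the longest common prefix of the candidate sequences is (31373,),
# so the alignment prompt is "hello" and generation starts from:
initial_state = AlignmentGuideState(
    legal_path_map={(995,): (81,)},  # continuation token ids -> child-guide advancement ids
    child_guide_state=None,
)

# get_next_instruction(initial_state) then returns Generate([995]): only tokens that
# continue a legal alignment path are allowed. Once token 995 has been generated the
# path is exhausted, so get_next_state() advances the child guide with (81,) and hands
# over control (legal_path_map becomes None).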