+from collections import defaultdict
+from copy import deepcopy
 from dataclasses import dataclass
 from typing import (
     TYPE_CHECKING,
@@ -78,14 +80,17 @@ def get_next_state(self, state: int, token_id: int) -> int:
     def is_final_state(self, state: int) -> bool:
         ...
 
+    def align_prompt_tokens(self, prompt: str, tokenizer: "Tokenizer") -> str:
+        ...
+
     def copy(self) -> "Guide":
         ...
 
 
 class StopAtEOSGuide(Guide):
     """Guide to generate tokens until the EOS token has been generated."""
 
-    final_state = 1
+    final_state = -1
     start_state = 0
 
     def __init__(self, tokenizer: "Tokenizer"):
@@ -96,24 +101,69 @@ def __init__(self, tokenizer: "Tokenizer"):
 
         """
         self.eos_token_id = tokenizer.eos_token_id
-        self.vocabulary = tokenizer.vocabulary.values()
+        self.vocabulary = tokenizer.vocabulary
+        self.states_to_token_maps = self.create_states_to_tokens_map()
+
+    def create_states_to_tokens_map(self) -> Dict[int, Dict[int, int]]:
+        """Create the states_to_tokens_map. All tokens lead to the starting
+        state, except for the eos_token that leads to the final state."""
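+        # e.g. (hypothetical token ids) with vocabulary {"a": 0, "b": 1, "<eos>": 2} and
+        # eos_token_id == 2, this builds {0: {0: 0, 1: 0, 2: -1}}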
+        return {
+            self.start_state: {
+                token_id: self.start_state
+                if token_id != self.eos_token_id
+                else self.final_state
+                for token_id in self.vocabulary.values()
+            }
+        }
+
+    def align_prompt_tokens(self, prompt: str, tokenizer: "Tokenizer") -> str:
+        """Update the states_to_token_maps and return the aligned prompt"""
+        token_ids, _ = tokenizer.encode(prompt)
+        # possible return types of tokenizers include list, 1d Tensor and 2d Tensor
+        if not isinstance(token_ids, list):
+            token_ids = token_ids.tolist()
+        if isinstance(token_ids[0], list):
+            token_ids = token_ids[0]
+        (
+            aligned_token_ids,
+            aligned_states_to_token_maps,
+        ) = align_tokens_states_to_token_maps(
+            token_ids, self.vocabulary, deepcopy(self.states_to_token_maps)
+        )
+        # some tokenizers expect a list of lists while others expect a simple list
+        aligned_prompt: list
+        try:
+            aligned_prompt = tokenizer.decode([aligned_token_ids])
+        except TypeError:
+            aligned_prompt = tokenizer.decode(aligned_token_ids)
+        # some models do not accept an empty string as a prompt
+        # if token alignment would remove all tokens, do not apply it
+        if not aligned_prompt or not aligned_prompt[0]:
+            return prompt
+        aligned_prompt = aligned_prompt[0]
+        self.states_to_token_maps = aligned_states_to_token_maps
+        # remove leading whitespace if added by the tokenizer
+        if aligned_prompt[0] == " " and prompt[0] != " ":
+            return aligned_prompt[1:]
+        return aligned_prompt
 
     def get_next_instruction(self, state: int) -> Instruction:
         if self.is_final_state(state):
             return Write([self.eos_token_id])
-        return Generate(None)
+        return Generate(list(self.states_to_token_maps[state].keys()))
 
     def get_next_state(self, state: int, token_id: int) -> int:
-        if token_id == self.eos_token_id or state == self.final_state:
+        if self.is_final_state(state):
             return self.final_state
 
-        return self.start_state
+        return self.states_to_token_maps[state][token_id]
 
     def is_final_state(self, state: int):
         return state == self.final_state
 
     def copy(self):
-        return self
+        return deepcopy(self)
 
 
 @cache()
@@ -182,9 +232,41 @@ def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
             self.empty_token_ids,
             fsm_finals,
         ) = create_states_mapping(regex_string, tokenizer)
+        self._cache_state_to_token_tensor()
+        self.vocabulary = tokenizer.vocabulary
         self.eos_token_id = tokenizer.eos_token_id
         self.final_states = fsm_finals | {-1}
+
+    def align_prompt_tokens(self, prompt: str, tokenizer: "Tokenizer") -> str:
+        """Update the states_to_token_maps and return the aligned prompt"""
+        token_ids, _ = tokenizer.encode(prompt)
+        # possible return types of tokenizers include list, 1d Tensor and 2d Tensor
+        if not isinstance(token_ids, list):
+            token_ids = token_ids.tolist()
+        if isinstance(token_ids[0], list):
+            token_ids = token_ids[0]
+        (
+            aligned_token_ids,
+            aligned_states_to_token_maps,
+        ) = align_tokens_states_to_token_maps(
+            token_ids, self.vocabulary, deepcopy(self.states_to_token_maps)
+        )
+        # some tokenizers expect a list of lists while others expect a simple list
+        aligned_prompt: str
+        try:
+            aligned_prompt = tokenizer.decode([aligned_token_ids])[0]
+        except TypeError:
+            aligned_prompt = tokenizer.decode(aligned_token_ids)[0]
+        # some models do not accept an empty string as a prompt
+        # if token alignment would remove all tokens, do not apply it
+        if not aligned_prompt:
+            return prompt
+        self.states_to_token_maps = aligned_states_to_token_maps
         self._cache_state_to_token_tensor()
+        # remove leading whitespace if added by the tokenizer
+        if aligned_prompt[0] == " " and prompt[0] != " ":
+            return aligned_prompt[1:]
+        return aligned_prompt
 
     def get_next_instruction(self, state: int) -> Instruction:
         """Return the next instruction for guided generation.
@@ -278,6 +360,7 @@ def create_states_mapping_from_interegular_fsm(
             from_interegular_instance.states_to_token_maps,
             from_interegular_instance.empty_token_ids,
         ) = create_states_mapping_from_interegular_fsm(interegular_fsm)
+        from_interegular_instance.vocabulary = tokenizer.vocabulary
         from_interegular_instance.eos_token_id = tokenizer.eos_token_id
         from_interegular_instance._cache_state_to_token_tensor()
         return from_interegular_instance
@@ -297,7 +380,7 @@ def is_final_state(self, state: int) -> bool:
         return state in self.final_states
 
     def copy(self):
-        return self
+        return deepcopy(self)
 
 
 class CFGGuide(Guide):
@@ -334,6 +417,10 @@ def __init__(self, cfg_string: str, tokenizer):
         self.start_state = 0
         self.final_state = -1
 
+    def align_prompt_tokens(self, prompt: str, tokenizer: "Tokenizer") -> str:
+        """Not applicable to this type of Guide"""
+        return prompt
+
     def get_next_instruction(self, state: int) -> Instruction:
         """Generate an instruction for the next step.
 
@@ -475,3 +562,160 @@ def is_final_state(self, state: int) -> bool:
     def copy(self) -> "CFGGuide":
         """Create a copy of the FSM."""
         return CFGGuide(self.cfg_string, self.tokenizer)
+
+
+def align_tokens_states_to_token_maps(
+    token_ids: List[int],
+    vocabulary: Dict[str, int],
+    states_to_token_maps: Dict[int, Dict[int, int]],
+) -> Tuple[List[int], Dict[int, Dict[int, int]]]:
+    """Apply token alignment to the provided prompt token_ids given the states_to_token_maps
+    of an FSM. Return the updated token_ids as well as the updated states_to_token_maps."""
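+    # three steps: find tokens crossing the prompt/generation boundary, keep only those
+    # compatible with the FSM, then graft states for the cropped prompt tokens onto the map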
+    crossing_tokens = find_crossing_tokens(token_ids, vocabulary)
+    valid_crossing_tokens = get_crossing_tokens_target_states(
+        states_to_token_maps, crossing_tokens, token_ids, vocabulary
+    )
+    if not valid_crossing_tokens:
+        return token_ids, states_to_token_maps
+    (
+        states_to_token_maps,
+        number_cropped_tokens,
+    ) = add_crossing_tokens_states_to_tokens_map(
+        states_to_token_maps, token_ids, valid_crossing_tokens
+    )
+    return (
+        token_ids[:-number_cropped_tokens],
+        states_to_token_maps,
+    )
+
+
+def find_crossing_tokens(
+    token_ids: List[int], vocabulary: Dict[str, int]
+) -> Dict[int, List[int]]:
+    """Find the tokens that could replace one or more tokens at the end of token_ids
+    while conserving the same initial text (and extending it by at least one character).
+    Return a dictionary mapping each matching index of token_ids to its crossing tokens.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    len_token_ids = len(token_ids)
+    max_length_token_text = max(len(item) for item in vocabulary.keys())
+    characters_considered = ""
+    crossing_tokens_map = {}
+
+    for index, token_id in enumerate(reversed(token_ids)):
+        characters_considered = reversed_vocabulary[token_id] + characters_considered
+        if len(characters_considered) >= max_length_token_text:
+            break
+        crossing_token_ids = [
+            token_id
+            for text, token_id in vocabulary.items()
+            if text.startswith(characters_considered)
+            and len(text) > len(characters_considered)
+        ]
+        if crossing_token_ids:
+            crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids
+
+    return crossing_tokens_map
+
+
+def get_crossing_tokens_target_states(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    crossing_tokens: Dict[int, List[int]],
+    prompt_token_ids: List[int],
+    vocabulary: Dict[str, int],
+) -> Dict[int, Dict[int, int]]:
+    """For each crossing token associated with an index, check that the characters after the
+    boundary match the states_to_tokens_map and find the state they would lead to. Return a
+    dict with, for each provided index, the valid tokens and the state each would lead to.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    prompt_token_texts = [
+        reversed_vocabulary[token_id] for token_id in prompt_token_ids
+    ]
+
+    valid_crossing_tokens: Dict[int, Dict[int, int]] = defaultdict(dict)
+    for pos, tokens in crossing_tokens.items():
+        for token in tokens:
+            is_valid = True
+            characters = reversed_vocabulary[token]
+            characters_before_border = "".join(prompt_token_texts[pos:])
+            characters_after_border = characters[len(characters_before_border) :]
+            state = 0
+            for char in characters_after_border:
+                char_token = vocabulary.get(char)
+                try:
+                    state = states_to_tokens_map[state][char_token]  # type: ignore
+                except KeyError:
+                    is_valid = False
+                    break
+            if is_valid:
+                valid_crossing_tokens[pos][token] = state
+
+    return valid_crossing_tokens
+
+
+def add_crossing_tokens_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    prompt_token_ids: List[int],
+    crossing_tokens_map: Dict[int, Dict[int, int]],
+) -> Tuple[Dict[int, Dict[int, int]], int]:
+    """Modify the states_to_tokens_map to account for the crossing tokens. This operation
+    modifies the starting state of the FSM, as some characters at the end of the prompt are
+    now included in the states_to_tokens_map.
+    Attention! The starting state of the states_to_tokens_map provided must be 0.
+    Return the updated states_to_tokens_map and the number of cropped tokens/additional states.
+    """
+    if not crossing_tokens_map:
+        return states_to_tokens_map, 0
+    first_crossing_token_pos = min(
+        [key for key, value in crossing_tokens_map.items() if value]
+    )
+    number_additional_states = len(prompt_token_ids) - first_crossing_token_pos
+    highest_state = max(
+        max(states_to_tokens_map.keys()),
+        max(max(items.values()) for items in states_to_tokens_map.values()),
+    )
+
+    for i in range(number_additional_states):
+        # add the token that was originally part of the prompt
+        if i == number_additional_states - 1:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: 0
+            }
+        else:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: highest_state + 2 + i
+            }
+        # add the crossing tokens
+        crossing_tokens = crossing_tokens_map.get(first_crossing_token_pos + i)
+        if crossing_tokens:
+            for token, target_state in crossing_tokens.items():
+                states_to_tokens_map[highest_state + 1 + i][token] = target_state
+
+    # set the id of our new initial state to 0
+    states_to_tokens_map = swap_state_ids_states_to_tokens_map(
+        states_to_tokens_map, highest_state + 1, 0
+    )
+    return states_to_tokens_map, number_additional_states
+
+
+def swap_state_ids_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    first_state_id: int,
+    second_state_id: int,
+) -> Dict[int, Dict[int, int]]:
+    """Swap the id of two states of the states_to_tokens_map while conserving all transitions"""
+    first_state_transitions = states_to_tokens_map.pop(first_state_id)
+    second_state_transitions = states_to_tokens_map.pop(second_state_id)
+    states_to_tokens_map[first_state_id] = second_state_transitions
+    states_to_tokens_map[second_state_id] = first_state_transitions
+
+    for transitions in states_to_tokens_map.values():
+        for token, target_state_id in list(transitions.items()):
+            if target_state_id == first_state_id:
+                transitions[token] = second_state_id
+            elif target_state_id == second_state_id:
+                transitions[token] = first_state_id
+
+    return states_to_tokens_map