+from collections import defaultdict
+from copy import copy, deepcopy
 from dataclasses import dataclass
 from typing import (
     TYPE_CHECKING,
@@ -69,6 +71,9 @@ class Guide(Protocol):
 
     """
 
+    start_state: int = 0
+    final_state: int = -1
+
     def get_next_instruction(self, state: int) -> Instruction:
         ...
 
@@ -82,11 +87,39 @@ def copy(self) -> "Guide":
         ...
 
 
-class StopAtEOSGuide(Guide):
-    """Guide to generate tokens until the EOS token has been generated. """
+class TokenHealerMixin:
+    """Mixin that adds the token alignment feature to a Guide."""
 
-    final_state = 1
-    start_state = 0
+    states_to_token_maps: Dict[int, Dict[int, int]]
+    tokenizer: "Tokenizer"
+
+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Update the states_to_token_maps and return the aligned prompt."""
+        token_ids, _ = self.tokenizer.encode(prompt)
+        (
+            aligned_token_ids,
+            aligned_states_to_token_maps,
+        ) = align_tokens_states_to_token_maps(
+            token_ids.tolist()[0],
+            self.tokenizer.vocabulary,
+            deepcopy(self.states_to_token_maps),
+        )
+        aligned_prompt = self.tokenizer.decode([aligned_token_ids])[0]
+        # some models do not accept an empty string as a prompt
+        # if token alignment would remove all tokens, do not apply it
+        if not aligned_prompt:
+            return prompt
+        self.states_to_token_maps = aligned_states_to_token_maps
+        if hasattr(self, "_cache_state_to_token_tensor"):
+            self._cache_state_to_token_tensor()
+        # remove leading whitespace if added by the tokenizer
+        if aligned_prompt[0] == " " and prompt[0] != " ":
+            return aligned_prompt[1:]
+        return aligned_prompt
+
+
+class StopAtEOSGuide(Guide, TokenHealerMixin):
+    """Guide to generate tokens until the EOS token has been generated."""
 
     def __init__(self, tokenizer: "Tokenizer"):
         """Initialize the generation guide.
@@ -95,25 +128,37 @@ def __init__(self, tokenizer: "Tokenizer"):
             The logit generator used to generate the next token.
 
         """
-        self.eos_token_id = tokenizer.eos_token_id
-        self.vocabulary = tokenizer.vocabulary.values()
+        self.tokenizer = tokenizer
+        self.states_to_token_maps = self.create_states_to_tokens_map()
+
+    def create_states_to_tokens_map(self) -> Dict[int, Dict[int, int]]:
+        """Create the states_to_tokens_map. All tokens lead to the starting
+        state, except for the eos_token that leads to the final state."""
+        return {
+            self.start_state: {
+                token_id: self.start_state
+                if token_id != self.tokenizer.eos_token_id
+                else self.final_state
+                for token_id in self.tokenizer.vocabulary.values()
+            }
+        }
 
     def get_next_instruction(self, state: int) -> Instruction:
         if self.is_final_state(state):
-            return Write([self.eos_token_id])
-        return Generate(None)
+            return Write([self.tokenizer.eos_token_id])
+        return Generate(list(self.states_to_token_maps[state].keys()))
 
     def get_next_state(self, state: int, token_id: int) -> int:
-        if token_id == self.eos_token_id or state == self.final_state:
+        if self.is_final_state(state):
             return self.final_state
 
-        return self.start_state
+        return self.states_to_token_maps[state][token_id]
 
     def is_final_state(self, state: int):
         return state == self.final_state
 
     def copy(self):
-        return self
+        return copy(self)
 
 
 @cache()
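
For illustration only (not part of the diff), the snippet below shows the transition map that create_states_to_tokens_map builds for a hypothetical three-token vocabulary: every token loops back to start_state, and only the EOS token moves the guide to final_state.

    vocabulary = {"a": 0, "b": 1, "<eos>": 2}  # hypothetical toy vocabulary
    eos_token_id = 2
    start_state, final_state = 0, -1
    states_to_token_maps = {
        start_state: {
            token_id: start_state if token_id != eos_token_id else final_state
            for token_id in vocabulary.values()
        }
    }
    assert states_to_token_maps == {0: {0: 0, 1: 0, 2: -1}}
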
@@ -171,20 +216,20 @@ def create_states_mapping(
     return states_to_token_maps, empty_token_ids, regex_fsm.finals
 
 
-class RegexGuide(Guide):
+class RegexGuide(Guide, TokenHealerMixin):
     """Guide to generate text in the language of a regular expression."""
 
-    initial_state = 0
+    states_to_token_mask: Dict[int, torch.Tensor]
 
     def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
+        self.tokenizer = tokenizer
         (
             self.states_to_token_maps,
             self.empty_token_ids,
             fsm_finals,
         ) = create_states_mapping(regex_string, tokenizer)
-        self.eos_token_id = tokenizer.eos_token_id
-        self.final_states = fsm_finals | {-1}
         self._cache_state_to_token_tensor()
+        self.final_states = fsm_finals | {self.final_state}
 
     def get_next_instruction(self, state: int) -> Instruction:
         """Return the next instruction for guided generation.
@@ -211,7 +256,7 @@ def get_next_instruction(self, state: int) -> Instruction:
211256 """
212257 next_tokens_mask = self .states_to_token_mask .get (state )
213258 if next_tokens_mask is None :
214- return Write (torch .tensor ([self .eos_token_id ]))
259+ return Write (torch .tensor ([self .tokenizer . eos_token_id ]))
215260
216261 return Generate (next_tokens_mask )
217262
@@ -233,13 +278,16 @@ def get_next_state(self, state: int, token_id: int) -> int:
             The new state of the guide.
 
         """
-        if token_id == self.eos_token_id or state not in self.states_to_token_maps:
-            return -1
+        if (
+            token_id == self.tokenizer.eos_token_id
+            or state not in self.states_to_token_maps
+        ):
+            return self.final_state
 
         last_token_to_end_state = self.states_to_token_maps[state]
         next_state = last_token_to_end_state.get(token_id)
         if next_state is None:
-            next_state = -1
+            next_state = self.final_state
 
         return next_state
 
@@ -278,11 +326,11 @@ def create_states_mapping_from_interegular_fsm(
             from_interegular_instance.states_to_token_maps,
             from_interegular_instance.empty_token_ids,
         ) = create_states_mapping_from_interegular_fsm(interegular_fsm)
-        from_interegular_instance.eos_token_id = tokenizer.eos_token_id
+        from_interegular_instance.tokenizer = tokenizer
         from_interegular_instance._cache_state_to_token_tensor()
         return from_interegular_instance
 
-    def _cache_state_to_token_tensor(self):
+    def _cache_state_to_token_tensor(self) -> None:
         """
         cache state -> token int tensor
         this increases performance of mask construction substantially
@@ -297,7 +345,7 @@ def is_final_state(self, state: int) -> bool:
         return state in self.final_states
 
     def copy(self):
-        return self
+        return copy(self)
 
 
 class CFGGuide(Guide):
@@ -331,9 +379,6 @@ def __init__(self, cfg_string: str, tokenizer):
         self.proposal_last: List[int] = []
         self.regex_fsm_last: RegexGuide
 
-        self.start_state = 0
-        self.final_state = -1
-
     def get_next_instruction(self, state: int) -> Instruction:
         """Generate an instruction for the next step.
 
@@ -475,3 +520,163 @@ def is_final_state(self, state: int) -> bool:
     def copy(self) -> "CFGGuide":
         """Create a copy of the FSM."""
         return CFGGuide(self.cfg_string, self.tokenizer)
+
+
+def align_tokens_states_to_token_maps(
+    token_ids: List[int],
+    vocabulary: Dict[str, int],
+    states_to_token_maps: Dict[int, Dict[int, int]],
+) -> Tuple[List[int], Dict[int, Dict[int, int]]]:
+    """Apply token alignment to the provided prompt token ids given the
+    states_to_token_maps of an FSM. Return the aligned token ids along with the
+    updated states_to_token_maps. You can find an explanation from Guidance on
+    why token healing is necessary here:
+    https://github.com/guidance-ai/guidance/blob/main/notebooks/tutorials/token_healing.ipynb
+    """
+    crossing_tokens = find_crossing_tokens(token_ids, vocabulary)
+    valid_crossing_tokens = get_crossing_tokens_target_states(
+        states_to_token_maps, crossing_tokens, token_ids, vocabulary
+    )
+    if not valid_crossing_tokens:
+        return token_ids, states_to_token_maps
+    (
+        states_to_token_maps,
+        number_cropped_tokens,
+    ) = add_crossing_tokens_states_to_tokens_map(
+        states_to_token_maps, token_ids, valid_crossing_tokens
+    )
+    return (
+        token_ids[:-number_cropped_tokens],
+        states_to_token_maps,
+    )
+
+
+def find_crossing_tokens(
+    token_ids: List[int], vocabulary: Dict[str, int]
+) -> Dict[int, List[int]]:
+    """Find the tokens that could replace one or more tokens at the end of
+    token_ids while conserving the same initial text (and extending it by at
+    least one character). Return a dictionary mapping each matching index of
+    token_ids to its associated crossing tokens.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    len_token_ids = len(token_ids)
+    max_length_token_text = max(len(item) for item in vocabulary.keys())
+    characters_considered = ""
+    crossing_tokens_map = {}
+
+    for index, token_id in enumerate(reversed(token_ids)):
+        characters_considered = reversed_vocabulary[token_id] + characters_considered
+        if len(characters_considered) >= max_length_token_text:
+            break
+        crossing_token_ids = [
+            token_id
+            for text, token_id in vocabulary.items()
+            if text.startswith(characters_considered)
+            and len(text) > len(characters_considered)
+        ]
+        if crossing_token_ids:
+            crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids
+
+    return crossing_tokens_map
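# Illustrative example, not part of the diff: with a hypothetical toy
# vocabulary, a prompt tokenized as ["he", "l"] has two positions whose
# trailing text could be re-tokenized into a longer "crossing" token.
_example_vocabulary = {"he": 0, "hel": 1, "hello": 2, "l": 3, "lo": 4}
assert find_crossing_tokens([0, 3], _example_vocabulary) == {
    1: [4],  # "lo" extends the trailing "l"
    0: [2],  # "hello" extends "he" + "l" ("hel" is excluded: it adds nothing)
}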
+
+
+def get_crossing_tokens_target_states(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    crossing_tokens: Dict[int, List[int]],
+    prompt_token_ids: List[int],
+    vocabulary: Dict[str, int],
+) -> Dict[int, Dict[int, int]]:
589+ """For each crossing token associated to an index, check that the characters after the boundary
590+ match the states_to_tokens_map and find the state it would lead to. Return a dict with, for each
591+ provided indexes, the associated valid tokens with the state they would lead to.
592+ """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    prompt_token_texts = [
+        reversed_vocabulary[token_id] for token_id in prompt_token_ids
+    ]
+
+    valid_crossing_tokens: Dict[int, Dict[int, int]] = defaultdict(dict)
+    for pos, tokens in crossing_tokens.items():
+        for token in tokens:
+            is_valid = True
+            characters = reversed_vocabulary[token]
+            characters_before_border = "".join(prompt_token_texts[pos:])
+            characters_after_border = characters[len(characters_before_border) :]
+            state = 0
+            for char in characters_after_border:
+                char_token = vocabulary.get(char)
+                try:
+                    state = states_to_tokens_map[state][char_token]  # type: ignore
+                except KeyError:
+                    is_valid = False
+                    break
+            if is_valid:
+                valid_crossing_tokens[pos][token] = state
+
+    return valid_crossing_tokens
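# Illustrative example, not part of the diff: keep only the crossing tokens
# whose characters past the prompt boundary are accepted by the FSM, and record
# the state they reach. Assumed toy setup: prompt "hel" -> ["he", "l"], and an
# FSM for the regex "ok" ("o" -> state 1, "ok" -> state 2, then "k" -> state 2).
_vocab = {"he": 0, "l": 1, "lo": 2, "o": 3, "k": 4, "ok": 5}
_fsm_map = {0: {3: 1, 5: 2}, 1: {4: 2}}
assert get_crossing_tokens_target_states(_fsm_map, {1: [2]}, [0, 1], _vocab) == {
    1: {2: 1}  # "lo" is valid: its trailing "o" drives the FSM from state 0 to 1
}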
+
+
+def add_crossing_tokens_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    prompt_token_ids: List[int],
+    crossing_tokens_map: Dict[int, Dict[int, int]],
+) -> Tuple[Dict[int, Dict[int, int]], int]:
624+ """Modify the states_to_tokens_map to account for the crossing tokens. This operation modifies
625+ the starting state of the fsm as we would include some characters at the end of the prompt in
626+ the states_to_tokens_map.
627+ Attention! the starting state of the states_to_tokens_map provided must be 0.
628+ Return the updated states_to_tokens_map and the number of cropped tokens/additional states
629+ """
+    if not crossing_tokens_map:
+        return states_to_tokens_map, 0
+    first_crossing_token_pos = min(
+        [key for key, value in crossing_tokens_map.items() if value]
+    )
+    number_additional_states = len(prompt_token_ids) - first_crossing_token_pos
+    highest_state = max(
+        max(states_to_tokens_map.keys()),
+        max(max(items.values()) for items in states_to_tokens_map.values()),
+    )
+
+    for i in range(number_additional_states):
+        # add the token that was originally part of the prompt
+        if i == number_additional_states - 1:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: 0
+            }
+        else:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: highest_state + 2 + i
+            }
+        # add the crossing tokens
+        crossing_tokens = crossing_tokens_map.get(first_crossing_token_pos + i)
+        if crossing_tokens:
+            for token, target_state in crossing_tokens.items():
+                states_to_tokens_map[highest_state + 1 + i][token] = target_state
+
+    # set the id of our new initial state to 0
+    states_to_tokens_map = swap_state_ids_states_to_tokens_map(
+        states_to_tokens_map, highest_state + 1, 0
+    )
+    return states_to_tokens_map, number_additional_states
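# Illustrative example, not part of the diff, continuing the toy setup above:
# one extra state is created for the cropped prompt token "l", the crossing
# token "lo" is wired into it, and the new entry state is renumbered to 0 by
# swap_state_ids_states_to_tokens_map.
_new_map, _n_cropped = add_crossing_tokens_states_to_tokens_map(
    {0: {3: 1, 5: 2}, 1: {4: 2}}, [0, 1], {1: {2: 1}}
)
assert _n_cropped == 1
assert _new_map == {
    0: {1: 3, 2: 1},  # new start: re-emit "l" (-> old start) or emit "lo" (-> state 1)
    3: {3: 1, 5: 2},  # the old start state, renumbered to 3
    1: {4: 2},
}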
+
+
+def swap_state_ids_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    first_state_id: int,
+    second_state_id: int,
+) -> Dict[int, Dict[int, int]]:
+    """Swap the ids of two states of the states_to_tokens_map while preserving all transitions."""
+    first_state_transitions = states_to_tokens_map.pop(first_state_id)
+    second_state_transitions = states_to_tokens_map.pop(second_state_id)
+    states_to_tokens_map[first_state_id] = second_state_transitions
+    states_to_tokens_map[second_state_id] = first_state_transitions
+
+    for transitions in states_to_tokens_map.values():
+        for token, target_state_id in list(transitions.items()):
+            if target_state_id == first_state_id:
+                transitions[token] = second_state_id
+            elif target_state_id == second_state_id:
+                transitions[token] = first_state_id
+
+    return states_to_tokens_map
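
Putting the helpers together, here is a minimal end-to-end sketch (not part of the diff) using the same hypothetical toy vocabulary as in the comments above: align_tokens_states_to_token_maps crops the prompt's trailing "l" and returns an FSM whose new start state accepts either the cropped token or the crossing token "lo".

    vocabulary = {"he": 0, "l": 1, "lo": 2, "o": 3, "k": 4, "ok": 5}
    states_to_token_maps = {0: {3: 1, 5: 2}, 1: {4: 2}}  # assumed FSM for the regex "ok"

    # Prompt "hel" is tokenized as ["he", "l"].
    aligned_ids, aligned_map = align_tokens_states_to_token_maps(
        [0, 1], vocabulary, states_to_token_maps
    )
    assert aligned_ids == [0]  # the trailing "l" is handed back to the FSM
    assert aligned_map == {0: {1: 3, 2: 1}, 1: {4: 2}, 3: {3: 1, 5: 2}}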