+from collections import defaultdict
+from copy import deepcopy
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional, Protocol, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Protocol, Tuple, Union

 import interegular
+import torch
 from lark import Lark

 from outlines import grammars
@@ -67,14 +70,17 @@ def get_next_state(self, state: int, token_id: int) -> int:
     def is_final_state(self, state: int) -> bool:
         ...

+    def align_prompt_tokens(self, prompt: str) -> str:
+        ...
+
     def copy(self) -> "Guide":
         ...


 class StopAtEOSGuide(Guide):
     """Guide to generate tokens until the EOS token has been generated."""

-    final_state = 1
+    final_state = -1
     start_state = 0

     def __init__(self, tokenizer: "Tokenizer"):
@@ -85,24 +91,49 @@ def __init__(self, tokenizer: "Tokenizer"):

         """
         self.eos_token_id = tokenizer.eos_token_id
-        self.vocabulary = tokenizer.vocabulary.values()
+        self.tokenizer = tokenizer
+        self.states_to_token_maps = self.create_states_to_tokens_map()
+
+    def create_states_to_tokens_map(self) -> Dict[int, Dict[int, int]]:
+        """Create the states_to_tokens_map. All tokens lead to the starting
+        state, except for the eos_token, which leads to the final state."""
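+        # Illustration (hypothetical token ids): with vocabulary ids {5, 6} and
+        # eos_token_id 2, the map is {0: {5: 0, 6: 0, 2: -1}}: every token loops on
+        # the start state and only the EOS token moves to the final state.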
+        return {
+            self.start_state: {
+                token_id: self.start_state
+                if token_id != self.eos_token_id
+                else self.final_state
+                for token_id in self.tokenizer.vocabulary.values()
+            }
+        }
+
+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Update the states_to_token_maps and return the aligned prompt"""
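+        # Token alignment: re-encoding the prompt may allow vocabulary tokens that
+        # straddle the prompt/generation boundary, so the last prompt token(s) may be
+        # cropped and the FSM extended to accept those crossing tokens (see
+        # align_tokens_states_to_token_maps below).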
+        token_ids, _ = self.tokenizer.encode(prompt)
+        (
+            aligned_token_ids,
+            self.states_to_token_maps,
+        ) = align_tokens_states_to_token_maps(
+            token_ids[0], self.tokenizer.vocabulary, self.states_to_token_maps
+        )
+        decoded_aligned_token_ids = self.tokenizer.decode(aligned_token_ids)
+        return "".join(decoded_aligned_token_ids)

     def get_next_instruction(self, state: int) -> Instruction:
         if self.is_final_state(state):
             return Write([self.eos_token_id])
-        return Generate(None)
+        return Generate(list(self.states_to_token_maps[state].keys()))

     def get_next_state(self, state: int, token_id: int) -> int:
-        if token_id == self.eos_token_id or state == self.final_state:
+        if self.is_final_state(state):
             return self.final_state

-        return self.start_state
+        return self.states_to_token_maps[state][token_id]

     def is_final_state(self, state: int):
         return state == self.final_state

     def copy(self):
-        return self
+        return deepcopy(self)


 @cache()
@@ -143,9 +174,22 @@ def __init__(self, regex_string: str, tokenizer):
             self.empty_token_ids,
             fsm_finals,
         ) = create_states_mapping(regex_string, tokenizer)
+        self.tokenizer = tokenizer
         self.eos_token_id = tokenizer.eos_token_id
         self.final_states = fsm_finals | {-1}

+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Update the states_to_token_maps and return the aligned prompt"""
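+        # Same alignment as StopAtEOSGuide.align_prompt_tokens, applied here to the
+        # transition maps produced by create_states_mapping for the regex FSM.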
+        token_ids, _ = self.tokenizer.encode(prompt)
+        (
+            aligned_token_ids,
+            self.states_to_token_maps,
+        ) = align_tokens_states_to_token_maps(
+            token_ids[0], self.tokenizer.vocabulary, self.states_to_token_maps
+        )
+        decoded_aligned_token_ids = self.tokenizer.decode(aligned_token_ids)
+        return "".join(decoded_aligned_token_ids)
+
     def get_next_instruction(self, state: int) -> Instruction:
         """Return the next instruction for guided generation.

@@ -246,7 +290,7 @@ def is_final_state(self, state: int) -> bool:
         return state in self.final_states

     def copy(self):
-        return self
+        return deepcopy(self)


 class CFGGuide(Guide):
@@ -283,6 +327,10 @@ def __init__(self, cfg_string: str, tokenizer):
         self.start_state = 0
         self.final_state = -1

+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Not applicable to this type of Guide"""
+        return prompt
+
     def get_next_instruction(self, state: int) -> Instruction:
         """Generate an instruction for the next step.

@@ -424,3 +472,161 @@ def is_final_state(self, state: int) -> bool:
     def copy(self) -> "CFGGuide":
         """Create a copy of the FSM."""
         return CFGGuide(self.cfg_string, self.tokenizer)
+
+
+def align_tokens_states_to_token_maps(
+    token_ids: torch.Tensor,
+    vocabulary: Dict[str, int],
+    states_to_token_maps: Dict[int, Dict[int, int]],
+) -> Tuple[torch.Tensor, Dict[int, Dict[int, int]]]:
+    """Apply token alignment to the provided prompt token ids given the
+    states_to_token_maps of an FSM. Return the possibly cropped token ids as well
+    as the updated states_to_token_maps."""
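+    # The alignment proceeds in three steps: (1) find vocabulary tokens that start
+    # with the end of the prompt and extend it ("crossing" tokens), (2) keep only
+    # those whose extra characters are accepted by the FSM, recording the state they
+    # reach, and (3) graft new states for the cropped prompt tokens and the crossing
+    # tokens onto the states_to_token_maps, then crop those tokens from the prompt.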
+    prompt_token_ids = token_ids.tolist()
+    crossing_tokens = find_crossing_tokens(prompt_token_ids, vocabulary)
+    valid_crossing_tokens = get_crossing_tokens_target_states(
+        states_to_token_maps, crossing_tokens, prompt_token_ids, vocabulary
+    )
+    if not valid_crossing_tokens:
+        return token_ids, states_to_token_maps
+    (
+        states_to_token_maps,
+        number_cropped_tokens,
+    ) = add_crossing_tokens_states_to_tokens_map(
+        states_to_token_maps, prompt_token_ids, valid_crossing_tokens
+    )
+    return (
+        token_ids[:-number_cropped_tokens],
+        states_to_token_maps,
+    )
+
+
+def find_crossing_tokens(
+    token_ids: List[int], vocabulary: Dict[str, int]
+) -> Dict[int, List[int]]:
+    """Find the tokens that could replace one or more tokens at the end of token_ids
+    while preserving the same initial text (and extending it by at least one character).
+    Return a dictionary with, for each index in token_ids that has matches, the associated crossing tokens.
+    """
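+    # Illustration (hypothetical vocabulary): if the prompt ends with the tokens
+    # "hel" and "lo", a vocabulary entry "lowest" is a crossing token for the index
+    # of "lo", since it starts with "lo" and extends it by at least one character.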
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    len_token_ids = len(token_ids)
+    max_length_token_text = max(len(item) for item in vocabulary.keys())
+    characters_considered = ""
+    crossing_tokens_map = {}
+
+    for index, token_id in enumerate(reversed(token_ids)):
+        characters_considered = reversed_vocabulary[token_id] + characters_considered
+        if len(characters_considered) >= max_length_token_text:
+            break
+        crossing_token_ids = [
+            token_id
+            for text, token_id in vocabulary.items()
+            if text.startswith(characters_considered)
+            and len(text) > len(characters_considered)
+        ]
+        if crossing_token_ids:
+            crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids
+
+    return crossing_tokens_map
+
+
+def get_crossing_tokens_target_states(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    crossing_tokens: Dict[int, List[int]],
+    prompt_token_ids: List[int],
+    vocabulary: Dict[str, int],
+) -> Dict[int, Dict[int, int]]:
+    """For each crossing token associated with an index, check that the characters after the boundary
+    match the states_to_tokens_map and find the state they would lead to. Return a dict with, for each
+    provided index, the associated valid tokens with the state they would lead to.
+    """
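+    # A crossing token is kept only if every character it adds past the prompt
+    # boundary is itself a vocabulary token with a valid transition in
+    # states_to_tokens_map starting from state 0; the state reached after the last
+    # character is the recorded target state.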
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    prompt_token_texts = [
+        reversed_vocabulary[token_id] for token_id in prompt_token_ids
+    ]
+
+    valid_crossing_tokens: Dict[int, Dict[int, int]] = defaultdict(dict)
+    for pos, tokens in crossing_tokens.items():
+        for token in tokens:
+            is_valid = True
+            characters = reversed_vocabulary[token]
+            characters_before_border = "".join(prompt_token_texts[pos:])
+            characters_after_border = characters[len(characters_before_border) :]
+            state = 0
+            for char in characters_after_border:
+                char_token = vocabulary.get(char)
+                try:
+                    state = states_to_tokens_map[state][char_token]  # type: ignore
+                except KeyError:
+                    is_valid = False
+                    break
+            if is_valid:
+                valid_crossing_tokens[pos][token] = state
+
+    return valid_crossing_tokens
+
+
+def add_crossing_tokens_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    prompt_token_ids: List[int],
+    crossing_tokens_map: Dict[int, Dict[int, int]],
+) -> Tuple[Dict[int, Dict[int, int]], int]:
+    """Modify the states_to_tokens_map to account for the crossing tokens. This operation modifies
+    the starting state of the FSM, as some characters at the end of the prompt are included in
+    the states_to_tokens_map.
+    Attention! The starting state of the states_to_tokens_map provided must be 0.
+    Return the updated states_to_tokens_map and the number of cropped tokens/additional states.
+    """
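+    # One new state is appended per cropped prompt token, each accepting the original
+    # prompt token (chaining to the next new state, the last one back to the old start
+    # state 0) plus any crossing tokens valid at that position; the first new state is
+    # then swapped with state 0 so it becomes the FSM's entry point.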
+    if not crossing_tokens_map:
+        return states_to_tokens_map, 0
+    first_crossing_token_pos = min(
+        [key for key, value in crossing_tokens_map.items() if value]
+    )
+    number_additional_states = len(prompt_token_ids) - first_crossing_token_pos
+    highest_state = max(
+        max(states_to_tokens_map.keys()),
+        max(max(items.values()) for items in states_to_tokens_map.values()),
+    )
+
+    for i in range(number_additional_states):
+        # add the token that was originally part of the prompt
+        if i == number_additional_states - 1:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: 0
+            }
+        else:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: highest_state + 2 + i
+            }
+        # add the crossing tokens
+        crossing_tokens = crossing_tokens_map.get(first_crossing_token_pos + i)
+        if crossing_tokens:
+            for token, target_state in crossing_tokens.items():
+                states_to_tokens_map[highest_state + 1 + i][token] = target_state
+
+    # set the id of our new initial state to 0
+    states_to_tokens_map = swap_state_ids_states_to_tokens_map(
+        states_to_tokens_map, highest_state + 1, 0
+    )
+    return states_to_tokens_map, number_additional_states
+
+
+def swap_state_ids_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    first_state_id: int,
+    second_state_id: int,
+) -> Dict[int, Dict[int, int]]:
+    """Swap the ids of two states of the states_to_tokens_map while preserving all transitions"""
+    first_state_transitions = states_to_tokens_map.pop(first_state_id)
+    second_state_transitions = states_to_tokens_map.pop(second_state_id)
+    states_to_tokens_map[first_state_id] = second_state_transitions
+    states_to_tokens_map[second_state_id] = first_state_transitions
+
+    for transitions in states_to_tokens_map.values():
+        for token, target_state_id in list(transitions.items()):
+            if target_state_id == first_state_id:
+                transitions[token] = second_state_id
+            elif target_state_id == second_state_id:
+                transitions[token] = first_state_id
+
+    return states_to_tokens_map