@@ -1,6 +1,9 @@
-from typing import TYPE_CHECKING, List, NewType, Protocol, Tuple
+from collections import defaultdict
+from copy import deepcopy
+from typing import TYPE_CHECKING, Dict, List, NewType, Protocol, Tuple
 
 import interegular
+import torch
 from lark import Lark
 
 # from outlines.fsm.parsing import PartialLark
@@ -22,6 +25,11 @@ def is_final_state(self, state: FSMState) -> bool:
         """Determine whether the current state of the FSM is a final state."""
         return state == self.final_state
 
+    def align_prompt_tokens(
+        self, token_ids: torch.Tensor, attention_masks: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        ...
+
     def allowed_token_ids(self, state: FSMState) -> List[int]:
         ...
 
@@ -37,13 +45,41 @@ class StopAtEosFSM(FSM):
 
     def __init__(self, tokenizer: "Tokenizer"):
         self.eos_token_id = tokenizer.eos_token_id
-        self.vocabulary = tokenizer.vocabulary.values()
+        self.vocabulary = tokenizer.vocabulary
+        self.tokenizer = tokenizer
+        self.states_to_token_maps = self.create_states_to_tokens_map()
+
+    def create_states_to_tokens_map(self) -> Dict[int, Dict[int, int]]:
+        """Create the states_to_tokens_map. All tokens from the starting state lead
+        back to it, except for the eos_token, which leads to the final state."""
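+        # Illustrative sketch (hypothetical token ids): with a vocabulary
+        # {"a": 1, "b": 2, "<eos>": 3}, this builds {0: {1: 0, 2: 0, 3: -1}},
+        # i.e. every token loops back to the first state (0) and only the eos
+        # token reaches the final state (-1).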
+        return {
+            self.first_state: {
+                token_id: self.first_state
+                if token_id != self.eos_token_id
+                else self.final_state
+                for token_id in self.vocabulary.values()
+            }
+        }
+
+    def align_prompt_tokens(
+        self, token_ids: torch.Tensor, attention_masks: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Update the states_to_token_maps and return the aligned prompt tokens and attention masks"""
+        (
+            token_ids,
+            attention_masks,
+            self.states_to_token_maps,
+        ) = align_tokens_states_to_token_maps(
+            token_ids, attention_masks, self.vocabulary, self.states_to_token_maps
+        )
+        return token_ids, attention_masks
 
     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.
 
-        When in the initial state we allow every token to be generated.
         In the final state the only allowed token is `stop_token_id`.
+        Otherwise we allow the tokens with valid transitions from the current
+        state of the states_to_token_maps.
 
         Parameters
         ----------
@@ -57,14 +93,13 @@ def allowed_token_ids(self, state: FSMState) -> List[int]:
         """
         if self.is_final_state(state):
             return [self.eos_token_id]
-        return list(self.vocabulary)
+        return list(self.states_to_token_maps[state].keys())
 
     def next_state(self, state: FSMState, token_id: int) -> FSMState:
         """Update the state of the FSM.
 
-        The FSM stays in the initial state `0` unless the specified stop token
-        has been generated or the maximum number of tokens has been reached. In
-        which case the FSM moves to the final state `-1`.
+        The FSM transitions from one state to another through the
+        states_to_token_maps until the final state is reached.
 
         Parameters
         ----------
@@ -78,14 +113,14 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
         The new state of the FSM.
 
         """
-        if token_id == self.eos_token_id:
+        if self.is_final_state(state):
             return self.final_state
 
-        return self.first_state
+        return FSMState(self.states_to_token_maps[state][token_id])
 
     def copy(self) -> "StopAtEosFSM":
         """Create a copy of the FSM."""
-        return self
+        return deepcopy(self)
 
 
 class RegexFSM(FSM):
@@ -121,9 +156,22 @@ def create_states_mapping(
         self.states_to_token_maps, self.empty_token_ids = create_states_mapping(
             regex_string, tuple(sorted(tokenizer.vocabulary.items()))
         )
-        self.vocabulary = tokenizer.vocabulary.values()
+        self.vocabulary = tokenizer.vocabulary
         self.eos_token_id = tokenizer.eos_token_id
 
+    def align_prompt_tokens(
+        self, token_ids: torch.Tensor, attention_masks: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Update the states_to_token_maps and return the aligned prompt tokens and attention masks"""
+        (
+            token_ids,
+            attention_masks,
+            self.states_to_token_maps,
+        ) = align_tokens_states_to_token_maps(
+            token_ids, attention_masks, self.vocabulary, self.states_to_token_maps
+        )
+        return token_ids, attention_masks
+
     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.
 
@@ -184,7 +232,7 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
 
     def copy(self) -> "RegexFSM":
         """Create a copy of the FSM."""
-        return self
+        return deepcopy(self)
 
 
 class CFGFSM(FSM):
@@ -218,6 +266,12 @@ def __init__(self, cfg_string: str, tokenizer):
         self.proposal_last: List[int] = []
         self.regex_fsm_last: RegexFSM
 
+    def align_prompt_tokens(
+        self, token_ids: torch.Tensor, attention_masks: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Not applicable to this type of FSM"""
+        return token_ids, attention_masks
+
     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.
 
@@ -333,3 +387,162 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
     def copy(self) -> "CFGFSM":
         """Create a copy of the FSM."""
         return CFGFSM(self.cfg_string, self.tokenizer)
+
+
+def align_tokens_states_to_token_maps(
+    token_ids: torch.Tensor,
+    attention_masks: torch.Tensor,
+    vocabulary: Dict[str, int],
+    states_to_token_maps: Dict[int, Dict[int, int]],
+) -> Tuple[torch.Tensor, torch.Tensor, Dict[int, Dict[int, int]]]:
+    """Apply token alignment to the provided prompt tokens and attention masks given the
+    states_to_token_maps of an FSM. Return the updated tokens/masks as well as the updated
+    states_to_token_maps"""
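+    # Rough intuition (hypothetical example): if the prompt ends with the text "hel"
+    # and the vocabulary contains a token "hello" whose remaining characters "lo" are
+    # accepted by the FSM, the trailing prompt token(s) are cropped and "hello" becomes
+    # a valid first generated token in an extended states_to_token_maps.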
+    prompt_token_ids = token_ids.tolist()
+    crossing_tokens = find_crossing_tokens(prompt_token_ids, vocabulary)
+    valid_crossing_tokens = get_crossing_tokens_target_states(
+        states_to_token_maps, crossing_tokens, prompt_token_ids, vocabulary
+    )
+    if not valid_crossing_tokens:
+        return token_ids, attention_masks, states_to_token_maps
+    (
+        states_to_token_maps,
+        number_cropped_tokens,
+    ) = add_crossing_tokens_states_to_tokens_map(
+        states_to_token_maps, prompt_token_ids, valid_crossing_tokens
+    )
+    return (
+        token_ids[:-number_cropped_tokens],
+        attention_masks[:-number_cropped_tokens],
+        states_to_token_maps,
+    )
+
+
+def find_crossing_tokens(
+    token_ids: List[int], vocabulary: Dict[str, int]
+) -> Dict[int, List[int]]:
+    """Find the tokens that could replace one or more tokens at the end of token_ids
+    while preserving the same initial text (and extending it by at least one character).
+    Return a dictionary mapping indexes in token_ids to their associated crossing tokens.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    len_token_ids = len(token_ids)
+    max_length_token_text = max(len(item) for item in vocabulary.keys())
+    characters_considered = ""
+    crossing_tokens_map = {}
+
+    for index, token_id in enumerate(reversed(token_ids)):
+        characters_considered = reversed_vocabulary[token_id] + characters_considered
+        if len(characters_considered) >= max_length_token_text:
+            break
+        crossing_token_ids = [
+            token_id
+            for text, token_id in vocabulary.items()
+            if text.startswith(characters_considered)
+            and len(text) > len(characters_considered)
+        ]
+        crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids
+
+    return crossing_tokens_map
+
+
+def get_crossing_tokens_target_states(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    crossing_tokens: Dict[int, List[int]],
+    prompt_token_ids: List[int],
+    vocabulary: Dict[str, int],
+) -> Dict[int, Dict[int, int]]:
+    """For each crossing token associated with an index, check that the characters after the
+    boundary match the states_to_tokens_map and find the state it would lead to. Return a dict
+    mapping each provided index to its valid tokens and the states they would lead to.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    prompt_token_texts = [
+        reversed_vocabulary[token_id] for token_id in prompt_token_ids
+    ]
+
+    valid_crossing_tokens: Dict[int, Dict[int, int]] = defaultdict(dict)
+    for pos, tokens in crossing_tokens.items():
+        for token in tokens:
+            is_valid = True
+            characters = reversed_vocabulary[token]
+            characters_before_border = "".join(prompt_token_texts[pos:])
+            characters_after_border = characters[len(characters_before_border) :]
+            state = 0
+            for char in characters_after_border:
+                char_token = vocabulary.get(char)
+                try:
+                    state = states_to_tokens_map[state][char_token]  # type: ignore
+                except KeyError:
+                    is_valid = False
+                    break
+            if is_valid:
+                valid_crossing_tokens[pos][token] = state
+
+    return valid_crossing_tokens
+
+
+def add_crossing_tokens_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    prompt_token_ids: List[int],
+    crossing_tokens_map: Dict[int, Dict[int, int]],
+) -> Tuple[Dict[int, Dict[int, int]], int]:
+    """Modify the states_to_tokens_map to account for the crossing tokens. This operation
+    changes the starting state of the FSM, as some characters at the end of the prompt are
+    included in the states_to_tokens_map.
+    Attention! The starting state of the provided states_to_tokens_map must be 0.
+    Return the updated states_to_tokens_map and the number of cropped tokens/additional states
+    """
+    if not crossing_tokens_map:
+        return states_to_tokens_map, 0
+    first_crossing_token_pos = min(
+        [key for key, value in crossing_tokens_map.items() if value]
+    )
+    number_additional_states = len(prompt_token_ids) - first_crossing_token_pos
+    highest_state = max(
+        max(states_to_tokens_map.keys()),
+        max(max(items.values()) for items in states_to_tokens_map.values()),
+    )
+
+    for i in range(number_additional_states):
+        # add the token that was originally part of the prompt
+        if i == number_additional_states - 1:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: 0
+            }
+        else:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: highest_state + 2 + i
+            }
+        # add the crossing tokens
+        crossing_tokens = crossing_tokens_map.get(first_crossing_token_pos + i)
+        if crossing_tokens:
+            for token, target_state in crossing_tokens.items():
+                states_to_tokens_map[highest_state + 1 + i][token] = target_state
+
+    # set the id of our new initial state to 0
+    states_to_tokens_map = swap_state_ids_states_to_tokens_map(
+        states_to_tokens_map, highest_state + 1, 0
+    )
+    return states_to_tokens_map, number_additional_states
+
+
+def swap_state_ids_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    first_state_id: int,
+    second_state_id: int,
+) -> Dict[int, Dict[int, int]]:
+    """Swap the ids of two states of the states_to_tokens_map while preserving all transitions"""
+    first_state_transitions = states_to_tokens_map.pop(first_state_id)
+    second_state_transitions = states_to_tokens_map.pop(second_state_id)
+    states_to_tokens_map[first_state_id] = second_state_transitions
+    states_to_tokens_map[second_state_id] = first_state_transitions
+
+    for transitions in states_to_tokens_map.values():
+        for token, target_state_id in list(transitions.items()):
+            if target_state_id == first_state_id:
+                transitions[token] = second_state_id
+            elif target_state_id == second_state_id:
+                transitions[token] = first_state_id
+
+    return states_to_tokens_map