
Commit 81678d0

Create align_prompt_tokens method for Guide classes
1 parent 70265fa commit 81678d0

5 files changed: +565, -12 lines


outlines/fsm/guide.py

Lines changed: 251 additions & 7 deletions
@@ -1,3 +1,5 @@
+from collections import defaultdict
+from copy import deepcopy
 from dataclasses import dataclass
 from typing import (
     TYPE_CHECKING,
@@ -78,14 +80,17 @@ def get_next_state(self, state: int, token_id: int) -> int:
     def is_final_state(self, state: int) -> bool:
         ...
 
+    def align_prompt_tokens(self, prompt: str, tokenizer: "Tokenizer") -> str:
+        ...
+
     def copy(self) -> "Guide":
         ...
 
 
 class StopAtEOSGuide(Guide):
     """Guide to generate tokens until the EOS token has been generated."""
 
-    final_state = 1
+    final_state = -1
     start_state = 0
 
     def __init__(self, tokenizer: "Tokenizer"):
@@ -96,24 +101,69 @@ def __init__(self, tokenizer: "Tokenizer"):
 
         """
         self.eos_token_id = tokenizer.eos_token_id
-        self.vocabulary = tokenizer.vocabulary.values()
+        self.vocabulary = tokenizer.vocabulary
+        self.states_to_token_maps = self.create_states_to_tokens_map()
+
+    def create_states_to_tokens_map(self) -> Dict[int, Dict[int, int]]:
+        """Create the states_to_tokens_map. All tokens lead to the starting
+        state, except for the eos_token that leads to the final state."""
+        return {
+            self.start_state: {
+                token_id: self.start_state
+                if token_id != self.eos_token_id
+                else self.final_state
+                for token_id in self.vocabulary.values()
+            }
+        }
+
+    def align_prompt_tokens(self, prompt: str, tokenizer: "Tokenizer") -> str:
+        """Update the states_to_token_maps and return the aligned prompt"""
+        token_ids, _ = tokenizer.encode(prompt)
+        # possible return types of tokenizers include list, 1d Tensor and 2d Tensor
+        if not isinstance(token_ids, list):
+            token_ids = token_ids.tolist()
+        if isinstance(token_ids[0], list):
+            token_ids = token_ids[0]
+        (
+            aligned_token_ids,
+            aligned_states_to_token_maps,
+        ) = align_tokens_states_to_token_maps(
+            token_ids, self.vocabulary, deepcopy(self.states_to_token_maps)
+        )
+        # some tokenizer expect a list of lists while others expect a simple list
+        aligned_prompt: list
+        try:
+            aligned_prompt = tokenizer.decode([aligned_token_ids])
+        except TypeError:
+            aligned_prompt = tokenizer.decode(aligned_token_ids)
+        # some models do not accept an empty string as a prompt
+        # if token alignement would remove all tokens, do not apply it
+        if not aligned_prompt or not aligned_prompt[0]:
+            return prompt
+        aligned_prompt = aligned_prompt[0]
+        print(prompt, token_ids, aligned_token_ids, aligned_prompt)
+        self.states_to_token_maps = aligned_states_to_token_maps
+        # remove leading whitespace if added by the tokenizer
+        if aligned_prompt[0] == " " and prompt[0] != " ":
+            return aligned_prompt[1:]
+        return aligned_prompt
 
     def get_next_instruction(self, state: int) -> Instruction:
         if self.is_final_state(state):
             return Write([self.eos_token_id])
-        return Generate(None)
+        return Generate(list(self.states_to_token_maps[state].keys()))
 
     def get_next_state(self, state: int, token_id: int) -> int:
-        if token_id == self.eos_token_id or state == self.final_state:
+        if self.is_final_state(state):
             return self.final_state
 
-        return self.start_state
+        return self.states_to_token_maps[state][token_id]
 
     def is_final_state(self, state: int):
         return state == self.final_state
 
     def copy(self):
-        return self
+        return deepcopy(self)
 
 
 @cache()
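
A minimal sketch of the transition map that StopAtEOSGuide now builds, and of why its start state can return an explicit token list; the two-token vocabulary and eos id below are illustrative assumptions, not values from the repository.

# Sketch only: hypothetical vocabulary {"a": 1, "<eos>": 2} with eos_token_id = 2,
# mirroring the dict comprehension added in create_states_to_tokens_map above.
start_state, final_state = 0, -1
vocabulary = {"a": 1, "<eos>": 2}
eos_token_id = 2

states_to_token_maps = {
    start_state: {
        token_id: start_state if token_id != eos_token_id else final_state
        for token_id in vocabulary.values()
    }
}
# Every token loops back to the start state except <eos>, which leads to the
# final state, so the start state now yields the explicit allowed ids [1, 2].
assert states_to_token_maps == {0: {1: 0, 2: -1}}
assert list(states_to_token_maps[0].keys()) == [1, 2]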
@@ -182,9 +232,41 @@ def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
             self.empty_token_ids,
             fsm_finals,
         ) = create_states_mapping(regex_string, tokenizer)
+        self._cache_state_to_token_tensor()
+        self.vocabulary = tokenizer.vocabulary
         self.eos_token_id = tokenizer.eos_token_id
         self.final_states = fsm_finals | {-1}
+
+    def align_prompt_tokens(self, prompt: str, tokenizer: "Tokenizer") -> str:
+        """Update the states_to_token_maps and return the aligned prompt"""
+        token_ids, _ = tokenizer.encode(prompt)
+        # possible return types of tokenizers include list, 1d Tensor and 2d Tensor
+        if not isinstance(token_ids, list):
+            token_ids = token_ids.tolist()
+        if isinstance(token_ids[0], list):
+            token_ids = token_ids[0]
+        (
+            aligned_token_ids,
+            aligned_states_to_token_maps,
+        ) = align_tokens_states_to_token_maps(
+            token_ids, self.vocabulary, deepcopy(self.states_to_token_maps)
+        )
+        # some tokenizer expect a list of lists while others expect a simple list
+        aligned_prompt: str
+        try:
+            aligned_prompt = tokenizer.decode([aligned_token_ids])[0]
+        except TypeError:
+            aligned_prompt = tokenizer.decode(aligned_token_ids)[0]
+        # some models do not accept an empty string as a prompt
+        # if token alignement would remove all tokens, do not apply it
+        if not aligned_prompt:
+            return prompt
+        self.states_to_token_maps = aligned_states_to_token_maps
         self._cache_state_to_token_tensor()
+        # remove leading whitespace if added by the tokenizer
+        if aligned_prompt[0] == " " and prompt[0] != " ":
+            return aligned_prompt[1:]
+        return aligned_prompt
 
     def get_next_instruction(self, state: int) -> Instruction:
         """Return the next instruction for guided generation.
@@ -278,6 +360,7 @@ def create_states_mapping_from_interegular_fsm(
             from_interegular_instance.states_to_token_maps,
             from_interegular_instance.empty_token_ids,
         ) = create_states_mapping_from_interegular_fsm(interegular_fsm)
+        from_interegular_instance.vocabulary = tokenizer.vocabulary
         from_interegular_instance.eos_token_id = tokenizer.eos_token_id
         from_interegular_instance._cache_state_to_token_tensor()
         return from_interegular_instance
@@ -297,7 +380,7 @@ def is_final_state(self, state: int) -> bool:
         return state in self.final_states
 
     def copy(self):
-        return self
+        return deepcopy(self)
 
 
 class CFGGuide(Guide):
@@ -334,6 +417,10 @@ def __init__(self, cfg_string: str, tokenizer):
         self.start_state = 0
         self.final_state = -1
 
+    def align_prompt_tokens(self, prompt: str, tokenizer: "Tokenizer") -> str:
+        """Not applicable to this type of Guide"""
+        return prompt
+
     def get_next_instruction(self, state: int) -> Instruction:
         """Generate an instruction for the next step.
 
@@ -475,3 +562,160 @@ def is_final_state(self, state: int) -> bool:
     def copy(self) -> "CFGGuide":
         """Create a copy of the FSM."""
         return CFGGuide(self.cfg_string, self.tokenizer)
+
+
+def align_tokens_states_to_token_maps(
+    token_ids: List[int],
+    vocabulary: Dict[str, int],
+    states_to_token_maps: Dict[int, Dict[int, int]],
+) -> Tuple[List[int], Dict[int, Dict[int, int]]]:
+    """Apply token alignment to the provided prompt tokens and attention masks given the
+    states_to_token_maps of a FSM. Return the updated tokens/maps as well as the updated
+    states_to_token_maps"""
+    crossing_tokens = find_crossing_tokens(token_ids, vocabulary)
+    valid_crossing_tokens = get_crossing_tokens_target_states(
+        states_to_token_maps, crossing_tokens, token_ids, vocabulary
+    )
+    if not valid_crossing_tokens:
+        return token_ids, states_to_token_maps
+    (
+        states_to_token_maps,
+        number_cropped_tokens,
+    ) = add_crossing_tokens_states_to_tokens_map(
+        states_to_token_maps, token_ids, valid_crossing_tokens
+    )
+    return (
+        token_ids[:-number_cropped_tokens],
+        states_to_token_maps,
+    )
+
+
+def find_crossing_tokens(
+    token_ids: List[int], vocabulary: Dict[str, int]
+) -> Dict[int, List[int]]:
+    """Find the tokens that could replace one or more tokens at the end of token_ids
+    while conserving the same intial text (and extending it by at least one character).
+    Return a dictionary with, for the indexes in the token_ids with matches, the associated crossing tokens.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    len_token_ids = len(token_ids)
+    max_length_token_text = max(len(item) for item in vocabulary.keys())
+    characters_considered = ""
+    crossing_tokens_map = {}
+
+    for index, token_id in enumerate(reversed(token_ids)):
+        characters_considered = reversed_vocabulary[token_id] + characters_considered
+        if len(characters_considered) >= max_length_token_text:
+            break
+        crossing_token_ids = [
+            token_id
+            for text, token_id in vocabulary.items()
+            if text.startswith(characters_considered)
+            and len(text) > len(characters_considered)
+        ]
+        if crossing_token_ids:
+            crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids
+
+    return crossing_tokens_map
+
+
+def get_crossing_tokens_target_states(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    crossing_tokens: Dict[int, List[int]],
+    prompt_token_ids: List[int],
+    vocabulary: Dict[str, int],
+) -> Dict[int, Dict[int, int]]:
+    """For each crossing token associated to an index, check that the characters after the boundary
+    match the states_to_tokens_map and find the state it would lead to. Return a dict with, for each
+    provided indexes, the associated valid tokens with the state they would lead to.
+    """
+    reversed_vocabulary = {value: key for key, value in vocabulary.items()}
+    prompt_token_texts = [
+        reversed_vocabulary[token_id] for token_id in prompt_token_ids
+    ]
+
+    valid_crossing_tokens: Dict[int, Dict[int, int]] = defaultdict(dict)
+    for pos, tokens in crossing_tokens.items():
+        for token in tokens:
+            is_valid = True
+            characters = reversed_vocabulary[token]
+            characters_before_border = "".join(prompt_token_texts[pos:])
+            characters_after_border = characters[len(characters_before_border) :]
+            state = 0
+            for char in characters_after_border:
+                char_token = vocabulary.get(char)
+                try:
+                    state = states_to_tokens_map[state][char_token]  # type: ignore
+                except KeyError:
+                    is_valid = False
+                    break
+            if is_valid:
+                valid_crossing_tokens[pos][token] = state
+
+    return valid_crossing_tokens
+
+
+def add_crossing_tokens_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    prompt_token_ids: List[int],
+    crossing_tokens_map: Dict[int, Dict[int, int]],
+) -> Tuple[Dict[int, Dict[int, int]], int]:
+    """Modify the states_to_tokens_map to account for the crossing tokens. This operation modifies
+    the starting state of the fsm as we would include some characters at the end of the prompt in
+    the states_to_tokens_map.
+    Attention! the starting state of the states_to_tokens_map provided must be 0.
+    Return the updated states_to_tokens_map and the number of cropped tokens/additional states
+    """
+    if not crossing_tokens_map:
+        return states_to_tokens_map, 0
+    first_crossing_token_pos = min(
+        [key for key, value in crossing_tokens_map.items() if value]
+    )
+    number_additional_states = len(prompt_token_ids) - first_crossing_token_pos
+    highest_state = max(
+        max(states_to_tokens_map.keys()),
+        max(max(items.values()) for items in states_to_tokens_map.values()),
+    )
+
+    for i in range(number_additional_states):
+        # add the tokens that was originally part of the prompt
+        if i == number_additional_states - 1:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: 0
+            }
+        else:
+            states_to_tokens_map[highest_state + 1 + i] = {
+                prompt_token_ids[first_crossing_token_pos + i]: highest_state + 2 + i
+            }
+        # add the crossing tokens
+        crossing_tokens = crossing_tokens_map.get(first_crossing_token_pos + i)
+        if crossing_tokens:
+            for token, target_state in crossing_tokens.items():
+                states_to_tokens_map[highest_state + 1 + i][token] = target_state
+
+    # set the id of our new initial state to 0
+    states_to_tokens_map = swap_state_ids_states_to_tokens_map(
+        states_to_tokens_map, highest_state + 1, 0
+    )
+    return states_to_tokens_map, number_additional_states
+
+
+def swap_state_ids_states_to_tokens_map(
+    states_to_tokens_map: Dict[int, Dict[int, int]],
+    first_state_id: int,
+    second_state_id: int,
+) -> Dict[int, Dict[int, int]]:
+    """Swap the id of two states of the states_to_tokens_map while conserving all transitions"""
+    first_state_transitions = states_to_tokens_map.pop(first_state_id)
+    second_state_transitions = states_to_tokens_map.pop(second_state_id)
+    states_to_tokens_map[first_state_id] = second_state_transitions
+    states_to_tokens_map[second_state_id] = first_state_transitions
+
+    for transitions in states_to_tokens_map.values():
+        for token, target_state_id in list(transitions.items()):
+            if target_state_id == first_state_id:
+                transitions[token] = second_state_id
+            elif target_state_id == second_state_id:
+                transitions[token] = first_state_id
+
+    return states_to_tokens_map
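
A minimal sketch of how these new module-level helpers interact, using a hypothetical toy vocabulary and a tiny FSM that expects the generation to continue with "l" then "d"; the import path follows the file changed in this commit, everything else below is an illustrative assumption.

# Sketch only: vocabulary and FSM chosen so the last prompt token " wor" can be
# crossed by the longer token " world".
from outlines.fsm.guide import align_tokens_states_to_token_maps

vocabulary = {"hello": 1, " wor": 2, " world": 3, "l": 4, "d": 5}
prompt_token_ids = [1, 2]  # "hello" + " wor"
states_to_token_maps = {0: {4: 1}, 1: {5: 2}}  # FSM expecting "l" then "d"

aligned_ids, aligned_map = align_tokens_states_to_token_maps(
    prompt_token_ids, vocabulary, states_to_token_maps
)

# The last prompt token is cropped; from the new start state the model may
# either regenerate " wor" (id 2) and resume the original FSM, or pick the
# crossing token " world" (id 3), which already consumes the expected "ld".
assert aligned_ids == [1]
assert aligned_map == {0: {2: 3, 3: 2}, 1: {5: 2}, 3: {4: 1}}

This mirrors what the align_prompt_tokens methods above do with the real tokenizer's vocabulary before swapping in the updated map.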

tests/fsm/test_fsm.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ class MockTokenizer:
     with pytest.warns(UserWarning):
         fsm = StopAtEosFSM(MockTokenizer())
 
-    assert fsm.allowed_token_ids(fsm.start_state) is None
+    assert fsm.allowed_token_ids(fsm.start_state) == [1, 2]
     assert fsm.allowed_token_ids(fsm.final_state) == [2]
     assert fsm.next_state(fsm.start_state, 2) == fsm.final_state
     assert fsm.next_state(fsm.start_state, 1) == fsm.start_state
