-from typing import Iterator, List, Optional, Union
+from typing import Iterator, List, Optional, Tuple, Union
 
 import torch
 
-from outlines.fsm.fsm import FSMState
+from outlines.fsm.fsm import FSM, FSMState
 from outlines.generate.generator import sequence_generator
 
 
@@ -21,6 +21,53 @@ def __init__(
         self.device = device
         self.num_samples = sampler.samples
 
+    def align_prompt_tokens(
+        self,
+        prompt_token_ids: torch.Tensor,
+        attention_masks: torch.Tensor,
+        fsms: List[FSM],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Implement token alignment for each FSM. Return the updated token_ids and attention_masks."""
+        aligned_prompts, aligned_masks = zip(
+            *[
+                fsm.align_prompt_tokens(prompt, mask)
+                for prompt, mask, fsm in zip(prompt_token_ids, attention_masks, fsms)
+            ]
+        )
+        # We have to pad some of the prompts if they are not all of the same length after this operation
+        max_length_aligned_prompt = max(prompt.shape[0] for prompt in aligned_prompts)
+        padded_aligned_prompts = [
+            torch.cat(
+                [
+                    torch.full(
+                        (max_length_aligned_prompt - prompt.shape[0],),
+                        0,
+                        device=prompt_token_ids.device,
+                        dtype=prompt.dtype,
+                    ),
+                    prompt,
+                ]
+            )
+            for prompt in aligned_prompts
+        ]
+        padded_aligned_masks = [
+            torch.cat(
+                [
+                    torch.full(
+                        (max_length_aligned_prompt - mask.shape[0],),
+                        0,
+                        device=prompt_token_ids.device,
+                        dtype=mask.dtype,
+                    ),
+                    mask,
+                ]
+            )
+            for mask in aligned_masks
+        ]
+        aligned_prompt_token_ids = torch.stack(padded_aligned_prompts)
+        aligned_attention_masks = torch.stack(padded_aligned_masks)
+        return aligned_prompt_token_ids, aligned_attention_masks
+
     def get_generated_token_ids(
         self,
         prompt_token_ids: torch.Tensor,
@@ -189,49 +236,15 @@ def __call__(
         num_samples = self.num_samples
         batch_size = len(prompts)
 
-        fsm_states = [FSMState(0) for _ in range(batch_size * num_samples)]
-        fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]
-
         prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
         attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)
 
-        # Token alignment may shorten some of the prompts by removing tokens at their end.
-        # We have to pad some of the prompts if they are not all of the same length after this operation
-        aligned_prompts, aligned_masks = zip(
-            *[
-                fsm.align_prompt_tokens(prompt, mask)
-                for prompt, mask, fsm in zip(prompt_token_ids, attention_masks, fsms)
-            ]
+        fsm_states = [FSMState(0) for _ in range(batch_size * num_samples)]
+        fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]
+
+        aligned_prompt_token_ids, aligned_attention_masks = self.align_prompt_tokens(
+            prompt_token_ids, attention_masks, fsms
         )
-        max_length_aligned_prompt = max(prompt.shape[0] for prompt in aligned_prompts)
-        padded_aligned_prompts = [
-            torch.cat(
-                [
-                    torch.full(
-                        (max_length_aligned_prompt - prompt.shape[0],),
-                        0,
-                        dtype=prompt.dtype,
-                    ),
-                    prompt,
-                ]
-            )
-            for prompt in aligned_prompts
-        ]
-        padded_aligned_masks = [
-            torch.cat(
-                [
-                    torch.full(
-                        (max_length_aligned_prompt - mask.shape[0],),
-                        0,
-                        dtype=mask.dtype,
-                    ),
-                    mask,
-                ]
-            )
-            for mask in aligned_masks
-        ]
-        aligned_prompt_token_ids = torch.stack(padded_aligned_prompts)
-        aligned_attention_masks = torch.stack(padded_aligned_masks)
 
         weights = torch.zeros(
             (batch_size * num_samples), dtype=torch.float, device=self.device
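
For context, here is a minimal standalone sketch of the padding step that the new align_prompt_tokens helper performs. Token alignment may trim prompts by different amounts, so shorter prompts are left-padded with token id 0 and a 0 attention-mask entry before being stacked into a batch. The function and variable names below are illustrative only and are not part of the diff; the per-FSM align_prompt_tokens calls are replaced with pre-aligned tensors.

import torch

def left_pad_and_stack(prompts, masks):
    """Left-pad aligned prompts and masks with zeros so they can be stacked into one batch."""
    max_len = max(p.shape[0] for p in prompts)
    padded_prompts = [
        torch.cat([torch.full((max_len - p.shape[0],), 0, dtype=p.dtype), p])
        for p in prompts
    ]
    padded_masks = [
        torch.cat([torch.full((max_len - m.shape[0],), 0, dtype=m.dtype), m])
        for m in masks
    ]
    return torch.stack(padded_prompts), torch.stack(padded_masks)

# Two prompts whose ends were trimmed by different amounts during token alignment
prompts = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
masks = [torch.ones(3, dtype=torch.int64), torch.ones(2, dtype=torch.int64)]
token_ids, attention_masks = left_pad_and_stack(prompts, masks)
# token_ids       -> tensor([[5, 6, 7], [0, 8, 9]])
# attention_masks -> tensor([[1, 1, 1], [0, 1, 1]])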