Description
Describe the issue as clearly as possible:
Encountered while working on PR 531.
When generating several samples for a prompt with the transformers model, input_ids gets modified during generation without this affecting the final result. This causes problems in the FSMLogitsProcessor.process_logits method, as the keys of self._fsm_states depend on the values of input_ids. This is easier to understand by just looking at the example below.
I suppose it's related to "Ensure FSMLogitsProcessor allows unstable sequence ordering (beam search in transformers and vLLM change the order of sequences)" mentioned in this commit by @lapp0, but here it's not just a change of order between the 1st and the 2nd sequence.
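For context, a minimal sketch of why this breaks, assuming a cache keyed by the hash of the token sequence (the names fsm_states, lookup_state and store_state are illustrative, not the actual outlines code):

# Hypothetical cache: hash of the token ids of a sequence -> FSM state
fsm_states = {}

def lookup_state(input_ids):
    # Return the cached FSM state for this exact token sequence, falling back to the initial state 0
    return fsm_states.get(hash(tuple(input_ids)), 0)

def store_state(input_ids, state):
    # Cache the FSM state under the hash of the current token sequence
    fsm_states[hash(tuple(input_ids))] = state

# At an earlier step the 2nd beam is [17, 18, 19, 320] and its state is cached
store_state([17, 18, 19, 320], 3)

# At a later step transformers presents the 2nd beam with 320 rewritten to 491,
# so the prefix no longer hashes to the cached key and the state is lost
print(lookup_state([17, 18, 19, 491]))  # prints 0 instead of 3

If the cache is keyed this way, any in-place rewrite of input_ids (as beam search does here) silently invalidates the stored FSM state, even though the final decoded output is unchanged.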
Steps/code to reproduce the bug:
import outlines.models as models
import outlines.generate as generate
import outlines.samplers as samplers
model = models.transformers("hf-internal-testing/tiny-random-gpt2", device="cpu")
generator = generate.regex(model, r"([a-z]{3})@", sampler=samplers.beam_search(2))
output = generator(["123"], max_tokens=40)
print(output)
At the beginning of OutlinesLogitsProcessor.__call__, add:
print(input_ids[0])
print(input_ids[1])
print(self.tokenizer.decode(input_ids[0]))
print(self.tokenizer.decode(input_ids[1]))
print('')
Expected result:
320 should not turn into 491, and 'ers' should not turn into 'age', at the 4th step of the generation for the 2nd sample.
Error message:
tensor([17, 18, 19])
tensor([17, 18, 19])
['1', '2', '3']
['1', '2', '3']
tensor([ 17, 18, 19, 491])
tensor([ 17, 18, 19, 320])
['1', '2', '3', 'age']
['1', '2', '3', 'ers']
tensor([ 17, 18, 19, 491, 32])
tensor([ 17, 18, 19, 320, 32])
['1', '2', '3', 'age', '@']
['1', '2', '3', 'ers', '@']
tensor([ 17, 18, 19, 491, 32, 2])
tensor([ 17, 18, 19, 491, 32, 1])
['1', '2', '3', 'age', '@', '"']
['1', '2', '3', 'age', '@', '!']
[['age@', 'ers@']]
Outlines/Python version information:
Current main branch (latest commit 9ce0df3)
Context for the issue:
No response