Unexpected change of input_ids during generation of several samples with transformers #1048

@RobinPicard

Describe the issue as clearly as possible:

Encountered while working on PR 531

When generating several samples for a prompt with the transformers model, input_ids is modified during generation: tokens already generated for one sample change at a later step, even though the final result is unaffected. This causes problems in the FSMLogitsProcessor.process_logits method because the keys of self._fsm_states depend on the values of input_ids. The example below makes this easier to see.

I suppose it's related to the note "Ensure FSMLogitsProcessor allows unstable sequence ordering (beam search in transformers and vLLM change the order of sequences)" mentioned in this commit by @lapp0, but here it is not just a change of order between the 1st and the 2nd sequence: a token that was already generated for the 2nd sequence is replaced.

Steps/code to reproduce the bug:

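import outlines.models as models
import outlines.generate as generate
import outlines.samplers as samplers

# Tiny random GPT-2 checkpoint, run on CPU to keep the reproduction fast
model = models.transformers("hf-internal-testing/tiny-random-gpt2", device="cpu")
# Constrain the output to three lowercase letters followed by "@", using 2 beams
generator = generate.regex(model, r"([a-z]{3})@", sampler=samplers.beam_search(2))
# One prompt with 2 beams yields two samples for that prompt
output = generator(["123"], max_tokens=40)
print(output)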

At the beginning of OutlinesLogitsProcessor.__call__, add:

print(input_ids[0])
print(input_ids[1])
print(self.tokenizer.decode(input_ids[0]))
print(self.tokenizer.decode(input_ids[1]))
print('')

Expected result:

At the 4th step of the generation, the 2nd sample should still contain token 320 ('ers'); it should not turn into 491 ('age').

Error message:

tensor([17, 18, 19])
tensor([17, 18, 19])
['1', '2', '3']
['1', '2', '3']

tensor([ 17,  18,  19, 491])
tensor([ 17,  18,  19, 320])
['1', '2', '3', 'age']
['1', '2', '3', 'ers']

tensor([ 17,  18,  19, 491,  32])
tensor([ 17,  18,  19, 320,  32])
['1', '2', '3', 'age', '@']
['1', '2', '3', 'ers', '@']

tensor([ 17,  18,  19, 491,  32,   2])
tensor([ 17,  18,  19, 491,  32,   1])
['1', '2', '3', 'age', '@', '"']
['1', '2', '3', 'age', '@', '!']

[['age@', 'ers@']]
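
To pinpoint the step at which a row stops being an extension of what it contained one step earlier, a small check like the one below can be run over the printed token IDs. This is only a sketch: `find_rewritten_rows` is a hypothetical helper, not part of Outlines or transformers, and the `steps` list is copied by hand from the log above.

```python
from typing import List, Sequence


def find_rewritten_rows(
    prev: Sequence[Sequence[int]], curr: Sequence[Sequence[int]]
) -> List[int]:
    """Return the indices of rows in `curr` that do not extend the same row in `prev`."""
    rewritten = []
    for i, (before, after) in enumerate(zip(prev, curr)):
        # A well-behaved row keeps all of its earlier tokens and only appends new ones.
        if list(after[: len(before)]) != list(before):
            rewritten.append(i)
    return rewritten


# Token IDs for both rows at each call, copied from the log above.
steps = [
    [[17, 18, 19], [17, 18, 19]],
    [[17, 18, 19, 491], [17, 18, 19, 320]],
    [[17, 18, 19, 491, 32], [17, 18, 19, 320, 32]],
    [[17, 18, 19, 491, 32, 2], [17, 18, 19, 491, 32, 1]],
]

for step in range(1, len(steps)):
    rows = find_rewritten_rows(steps[step - 1], steps[step])
    if rows:
        print(f"step {step + 1}: rows {rows} no longer extend their previous contents")
```

On the log above this reports only the second row (index 1) at the 4th step, which matches the 320 to 491 change described in the expected result.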

Outlines/Python version information:

Current main branch (latest commit 9ce0df3)

Context for the issue:

No response
