Skip to content

Commit 7a54afa

Browse files
committed
restore line
1 parent c0566ee commit 7a54afa

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/lighteval/models/transformers/transformers_model.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -684,12 +684,18 @@ def greedy_until_multi_turn( # noqa: C901
684684
**model_inputs, stopping_criteria=stopping_criteria, **generation_config
685685
)
686686
model_outputs = model_outputs.sequences[0, model_inputs["input_ids"].size(1) :]
687+
688+
# We manage stop tokens in an extra step in case they were incorrectly detected earlier
689+
# (which can happen for multi-token stop sequences)
690+
decoded_generation = self.tokenizer.decode(model_outputs) # should we skip_special_tokens=True here?
691+
for term in stop_tokens:
692+
decoded_generation = decoded_generation.split(term)[0]
687693
model_generations = [model_outputs]
688694

689695
input_tokens = [model_inputs["input_ids"]]
690696

691697
for i, multi_turn_context in enumerate(request.context[1:]):
692-
multi_turn_context = multi_turn_context.format(model_response=model_generations[-1])
698+
multi_turn_context = multi_turn_context.format(model_response=decoded_generation)
693699

694700
model_inputs = self.tokenizer(
695701
multi_turn_context,
@@ -731,9 +737,9 @@ def greedy_until_multi_turn( # noqa: C901
731737
)
732738
model_outputs = model_outputs.sequences[0, model_inputs["input_ids"].size(1) :]
733739
model_generations.append(model_outputs)
734-
decoded_generation = self.tokenizer.decode(model_outputs, skip_special_tokens=True)
735740
input_tokens.append(model_inputs["input_ids"])
736741

742+
decoded_generation = self.tokenizer.decode(model_outputs, skip_special_tokens=True)
737743
for term in stop_tokens:
738744
decoded_generation = decoded_generation.split(term)[0]
739745

0 commit comments

Comments
 (0)