Commit 149fb63

Author: Dmitrii Tarasov
fix: Transformers Model no template cast stop_sequences to list
1 parent: 06aee5b

File tree: 2 files changed (+46, -1 lines)


src/lighteval/models/transformers/transformers_model.py

Lines changed: 2 additions & 1 deletion
@@ -58,6 +58,7 @@
     is_package_available,
 )
 from lighteval.utils.parallelism import find_executable_batch_size
+from lighteval.utils.utils import as_list
 
 
 logger = logging.getLogger(__name__)

@@ -689,7 +690,7 @@ def _padded_greedy_until(
         # NOTE: we are assuming all items in a batch behave similarly (same
         # stop_tokens and max_tokens generated) which is not necessarily
         # the case! Because of that we only use batch size of 1
-        stop_tokens = [self.tokenizer.eos_token] + batch[0].stop_sequences
+        stop_tokens = [self.tokenizer.eos_token] + as_list(batch[0].stop_sequences)
 
         max_new_tokens = batch[0].generation_size
         num_samples = batch[0].num_samples
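Why the cast is needed: Doc.stop_sequences may arrive as a tuple (the new test below passes an empty tuple), and Python cannot concatenate a list with a tuple, so the old line raised a TypeError. Below is a minimal repro of the failure and the fix; the as_list body shown here is a rough sketch of what a normalizer like this plausibly does; the real helper lives in lighteval.utils.utils and may differ in detail.

```python
# Minimal repro of the bug fixed above; eos_token and stop_sequences are
# illustrative stand-ins for the tokenizer's EOS token and Doc.stop_sequences.
eos_token = "</s>"
stop_sequences = ("\n\n",)  # stop sequences stored as a tuple

# Before the fix: concatenating a list with a tuple raises TypeError.
try:
    stop_tokens = [eos_token] + stop_sequences
except TypeError as err:
    print(err)  # can only concatenate list (not "tuple") to list

# Sketch of a normalizer in the spirit of as_list (assumed behavior; the
# real implementation is in lighteval.utils.utils).
def as_list(item):
    if isinstance(item, (list, tuple)):
        return list(item)
    return [item]

# After the fix: tuples, including the empty tuple, concatenate cleanly.
stop_tokens = [eos_token] + as_list(stop_sequences)
print(stop_tokens)  # ['</s>', '\n\n']
```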

tests/unit/models/test_transformers_model.py

Lines changed: 44 additions & 0 deletions
@@ -426,5 +426,49 @@ def test_transformers_model_use_chat_template_with_different_model_names(
         self.assertEqual(model.use_chat_template, model._tokenizer.chat_template is not None)
 
 
+class TestTransformersModelNoChatTemplate(unittest.TestCase):
+    """Tests for stop_tokens assignment in _padded_greedy_until (lines 684-692)."""
+
+    @patch("lighteval.models.transformers.transformers_model.Accelerator")
+    @patch("lighteval.models.transformers.transformers_model.TransformersModel._generate")
+    @patch("lighteval.models.transformers.transformers_model.DataLoader")
+    def test_stop_tokens_without_chat_template_empty_stop_sequences(
+        self, mock_dataloader, mock_generate, mock_accelerator
+    ):
+        """When use_chat_template is False and stop_sequences is empty, stop_tokens is [eos_token] only."""
+        mock_accelerator_instance = Mock()
+        mock_accelerator_instance.device = torch.device("cpu")
+        mock_accelerator_instance.prepare = lambda x: x
+        mock_accelerator.return_value = mock_accelerator_instance
+
+        config = TransformersModelConfig(model_name="gpt2")
+        model = TransformersModel(config)
+        model.use_chat_template = False
+
+        doc = Doc(
+            query="Say hello.",
+            choices=[],
+            gold_index=0,
+            generation_size=5,
+            stop_sequences=(),  # empty tuple
+        )
+        batch_from_dataloader = [doc]
+        mock_dataloader.return_value = iter([batch_from_dataloader])
+
+        captured_stop_tokens = None
+
+        def capture_stop_tokens(*args, **kwargs):
+            nonlocal captured_stop_tokens
+            captured_stop_tokens = kwargs.get("stop_tokens")
+            return [ModelResponse(text=[""], logprobs=[], output_tokens=[], input_tokens=[])]
+
+        mock_generate.side_effect = capture_stop_tokens
+
+        model._padded_greedy_until([doc])
+
+        self.assertIsNotNone(captured_stop_tokens)
+        self.assertEqual(captured_stop_tokens, [model.tokenizer.eos_token])
+
+
 if __name__ == "__main__":
     unittest.main()
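The test hinges on one mocking idiom: giving the patched _generate a side_effect that records the keyword arguments it receives while returning a stub response, so _padded_greedy_until keeps running and the assertion can inspect what was passed as stop_tokens. A self-contained sketch of that pattern, with illustrative names that are not from the repo:

```python
from unittest.mock import Mock

captured = {}

def capture(*args, **kwargs):
    # Record every keyword argument the caller passed in...
    captured.update(kwargs)
    # ...and return a stub so the calling code proceeds normally.
    return "stub"

mocked = Mock(side_effect=capture)
mocked(stop_tokens=["</s>"], max_new_tokens=5)

assert captured["stop_tokens"] == ["</s>"]
```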
