Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/lighteval/models/transformers/transformers_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
is_package_available,
)
from lighteval.utils.parallelism import find_executable_batch_size
from lighteval.utils.utils import as_list


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -689,7 +690,7 @@ def _padded_greedy_until(
# NOTE: we are assuming all items in a batch behave similarly (same
# stop_tokens and max_tokens generated) which is not necessarily
# the case! Because of that we only use batch size of 1
stop_tokens = [self.tokenizer.eos_token] + batch[0].stop_sequences
stop_tokens = [self.tokenizer.eos_token] + as_list(batch[0].stop_sequences)

max_new_tokens = batch[0].generation_size
num_samples = batch[0].num_samples
Expand Down
44 changes: 44 additions & 0 deletions tests/unit/models/test_transformers_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,5 +426,49 @@ def test_transformers_model_use_chat_template_with_different_model_names(
self.assertEqual(model.use_chat_template, model._tokenizer.chat_template is not None)


class TestTransformersModelNoChatTemplate(unittest.TestCase):
    """Checks how _padded_greedy_until builds its stop_tokens list (lines 684-692)."""

    @patch("lighteval.models.transformers.transformers_model.Accelerator")
    @patch("lighteval.models.transformers.transformers_model.TransformersModel._generate")
    @patch("lighteval.models.transformers.transformers_model.DataLoader")
    def test_stop_tokens_without_chat_template_empty_stop_sequences(
        self, mock_dataloader, mock_generate, mock_accelerator
    ):
        """With use_chat_template False and an empty stop_sequences, stop_tokens holds only the eos token."""
        # Stub out the accelerator so no real device placement happens.
        accelerator = Mock()
        accelerator.device = torch.device("cpu")
        accelerator.prepare = lambda x: x
        mock_accelerator.return_value = accelerator

        model = TransformersModel(TransformersModelConfig(model_name="gpt2"))
        model.use_chat_template = False

        doc = Doc(
            query="Say hello.",
            choices=[],
            gold_index=0,
            generation_size=5,
            stop_sequences=(),  # empty tuple
        )
        # The patched DataLoader yields exactly one batch containing our doc.
        mock_dataloader.return_value = iter([[doc]])

        # Record the stop_tokens keyword argument that _generate receives.
        recorded = {}

        def fake_generate(*args, **kwargs):
            recorded["stop_tokens"] = kwargs.get("stop_tokens")
            return [ModelResponse(text=[""], logprobs=[], output_tokens=[], input_tokens=[])]

        mock_generate.side_effect = fake_generate

        model._padded_greedy_until([doc])

        self.assertIsNotNone(recorded.get("stop_tokens"))
        self.assertEqual(recorded["stop_tokens"], [model.tokenizer.eos_token])


# Allow running this test module directly (e.g. `python test_transformers_model.py`).
if __name__ == "__main__":
    unittest.main()