@@ -426,5 +426,49 @@ def test_transformers_model_use_chat_template_with_different_model_names(
426426 self .assertEqual (model .use_chat_template , model ._tokenizer .chat_template is not None )
427427
428428
429+ class TestTransformersModelNoChatTemplate (unittest .TestCase ):
430+ """Tests for stop_tokens assignment in _padded_greedy_until (lines 684-692)."""
431+
432+ @patch ("lighteval.models.transformers.transformers_model.Accelerator" )
433+ @patch ("lighteval.models.transformers.transformers_model.TransformersModel._generate" )
434+ @patch ("lighteval.models.transformers.transformers_model.DataLoader" )
435+ def test_stop_tokens_without_chat_template_empty_stop_sequences (
436+ self , mock_dataloader , mock_generate , mock_accelerator
437+ ):
438+ """When use_chat_template is False and stop_sequences is empty, stop_tokens is [eos_token] only."""
439+ mock_accelerator_instance = Mock ()
440+ mock_accelerator_instance .device = torch .device ("cpu" )
441+ mock_accelerator_instance .prepare = lambda x : x
442+ mock_accelerator .return_value = mock_accelerator_instance
443+
444+ config = TransformersModelConfig (model_name = "gpt2" )
445+ model = TransformersModel (config )
446+ model .use_chat_template = False
447+
448+ doc = Doc (
449+ query = "Say hello." ,
450+ choices = [],
451+ gold_index = 0 ,
452+ generation_size = 5 ,
453+ stop_sequences = (), # empty tuple
454+ )
455+ batch_from_dataloader = [doc ]
456+ mock_dataloader .return_value = iter ([batch_from_dataloader ])
457+
458+ captured_stop_tokens = None
459+
460+ def capture_stop_tokens (* args , ** kwargs ):
461+ nonlocal captured_stop_tokens
462+ captured_stop_tokens = kwargs .get ("stop_tokens" )
463+ return [ModelResponse (text = ["" ], logprobs = [], output_tokens = [], input_tokens = [])]
464+
465+ mock_generate .side_effect = capture_stop_tokens
466+
467+ model ._padded_greedy_until ([doc ])
468+
469+ self .assertIsNotNone (captured_stop_tokens )
470+ self .assertEqual (captured_stop_tokens , [model .tokenizer .eos_token ])
471+
472+
429473if __name__ == "__main__" :
430474 unittest .main ()
0 commit comments