diff --git a/src/lighteval/models/model_input.py b/src/lighteval/models/model_input.py index d4e3d2bd2..16ca69961 100644 --- a/src/lighteval/models/model_input.py +++ b/src/lighteval/models/model_input.py @@ -173,7 +173,7 @@ def to_transformers_dict(self) -> dict: Returns: dict: The parameters to create a transformers.GenerationConfig in the model config. """ - # Task specific sampling params to set in model: do_sample, num_return_sequences, num_beans + # Task specific sampling params to set in model: num_return_sequences, num_beams args = { "max_new_tokens": self.max_new_tokens, "min_new_tokens": self.min_new_tokens, @@ -187,6 +187,7 @@ def to_transformers_dict(self) -> dict: "length_penalty": self.length_penalty, "output_scores": True, "return_dict_in_generate": True, + "do_sample": self.temperature > 0.0, } return {k: v for k, v in args.items() if v is not None} @@ -197,7 +198,7 @@ def to_tgi_ie_dict(self) -> dict: Returns: dict: The parameters to create a huggingface_hub.TextGenerationInputGenerateParameters in the model config. """ - # Task specific sampling params to set in model: best_of, do_sample + # Task specific sampling params to set in model: best_of args = { "decoder_input_details": True, "details": True, @@ -210,6 +211,7 @@ def to_tgi_ie_dict(self) -> dict: "top_k": self.top_k, "top_p": self.top_p, "truncate": self.truncate_prompt, + "do_sample": self.temperature > 0.0, } return {k: v for k, v in args.items() if v is not None}