Commit 327071f

set default temperature to 0 in generation config (#814)
* set default temperature to 0 in generation config
* issue warning when temperature == 0 with multiple samples
* fix test
1 parent 3d44897 commit 327071f

File tree

6 files changed (+18, -4 lines)

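Each backend diff below adds the same guard: greedy decoding (temperature == 0) is deterministic, so requesting several samples would only repeat one completion. A minimal self-contained sketch of the shared pattern (the standalone function and logger setup are illustrative, not part of the diff):

import logging

logger = logging.getLogger(__name__)

def warn_on_degenerate_sampling(num_samples: int, temperature: float) -> None:
    # Greedy decoding is deterministic: n samples at temperature 0 are n identical copies.
    if num_samples > 1 and temperature == 0:
        logger.warning(
            "num_samples > 1 but temperature is set to 0, this will not sample different outputs."
        )

warn_on_degenerate_sampling(num_samples=4, temperature=0)  # emits the warning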

src/lighteval/models/litellm_model.py

Lines changed: 7 additions & 1 deletion

@@ -114,7 +114,7 @@ def _prepare_max_new_tokens(self, max_new_tokens):
         max_new_tokens = min(max_new_tokens * 10, 32000)
         return max_new_tokens
 
-    def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence):
+    def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence):  # noqa: C901
         """Make API call with retries."""
         response = ModelResponse()
         for attempt in range(self.API_MAX_RETRY):
@@ -135,6 +135,12 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_se
             "caching": True,
             "api_key": self.api_key,
         }
+
+        if num_samples > 1 and self.generation_parameters.temperature == 0:
+            logger.warning(
+                "num_samples > 1 but temperature is set to 0, this will not sample different outputs."
+            )
+
         if "o1" in self.model:
             logger.warning("O1 models do not support temperature, top_p, stop sequence. Disabling.")
         else:

src/lighteval/models/model_input.py

Lines changed: 3 additions & 1 deletion

@@ -36,7 +36,9 @@ class GenerationParameters(BaseModel, extra="forbid"):
 
     seed: NonNegativeInt | None = None  # vllm, tgi, litellm
     stop_tokens: list[str] | None = None  # vllm, transformers, tgi, litellm, sglang
-    temperature: NonNegativeFloat | None = None  # vllm, transformers, tgi, litellm, sglang
+    temperature: NonNegativeFloat = (
+        0  # vllm, transformers, tgi, litellm, sglang  # if not set, defaults to greedy decoding
+    )
     top_k: NonNegativeInt | None = None  # vllm, transformers, tgi, sglang
     min_p: NonNegativeFloat | None = None  # vllm, transformers, sglang
     top_p: NonNegativeFloat | None = None  # vllm, transformers, tgi, litellm, sglang
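A minimal sketch of what the new default means for callers, trimmed to the fields visible in this hunk (pydantic v2 assumed):

from pydantic import BaseModel, NonNegativeFloat, NonNegativeInt

class GenerationParameters(BaseModel, extra="forbid"):
    seed: NonNegativeInt | None = None
    stop_tokens: list[str] | None = None
    temperature: NonNegativeFloat = 0  # greedy decoding unless explicitly overridden
    top_k: NonNegativeInt | None = None

print(GenerationParameters().temperature)                 # 0 (was None before this commit)
print(GenerationParameters(temperature=0.7).temperature)  # 0.7, opts back into sampling

An unset temperature now unambiguously means greedy decoding instead of deferring to each backend's own default.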

src/lighteval/models/sglang/sglang_model.py

Lines changed: 2 additions & 0 deletions

@@ -258,6 +258,8 @@ def _generate(
             self.sampling_params["max_new_tokens"] = max_new_tokens
             self.sampling_params["stop"] = stop_tokens
             self.sampling_params["n"] = num_samples
+            if num_samples > 1 and self.sampling_params["temperature"] == 0:
+                logger.warning("num_samples > 1 but temperature is set to 0, this will not sample different outputs.")
         else:
             self.sampling_params["max_new_tokens"] = 1
             self.sampling_params["temperature"] = 0

src/lighteval/models/transformers/transformers_model.py

Lines changed: 3 additions & 1 deletion

@@ -636,7 +636,7 @@ def _generate(
         max_new_tokens: int,
         stop_tokens: list[str],
         returns_logits: Optional[bool] = False,
-        num_samples: Optional[int] = 1,
+        num_samples: int = 1,
         do_sample: Optional[bool] = False,
     ) -> list[GenerativeResponse]:
         """Contains the actual logic of the generation.
@@ -655,6 +655,8 @@ def _generate(
             output_logits=returns_logits,
             renormalize_logits=True,
         )
+        if num_samples > 1 and generation_config["temperature"] == 0:
+            logger.warning("num_samples > 1 but temperature is set to 0, this will not sample different outputs.")
 
         # Compute model generation
         outputs: GenerateOutput = self.model.generate(

src/lighteval/models/vllm/vllm_model.py

Lines changed: 2 additions & 0 deletions

@@ -336,6 +336,8 @@ def _generate(
             sampling_params.stop = stop_tokens
             sampling_params.logprobs = 1 if returns_logits else 0
 
+            if num_samples > 1 and sampling_params.temperature == 0:
+                logger.warning("num_samples > 1 but temperature is set to 0, this will not sample different outputs.")
         else:
             sampling_params.temperature = 0
             sampling_params.prompt_logprobs = 1
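For context, a hypothetical reproduction of the situation this vLLM warning targets (the model name is a placeholder and the heavy calls are commented out):

from vllm import LLM, SamplingParams

# temperature=0 selects greedy decoding, so all n completions come out identical.
params = SamplingParams(temperature=0, n=4, max_tokens=32)
# llm = LLM(model="<some-small-model>")
# out = llm.generate(["The capital of France is"], params)[0]
# assert len({c.text for c in out.outputs}) == 1  # every sample is the same string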

tests/models/endpoints/test_tgi_model.py

Lines changed: 1 addition & 1 deletion

@@ -47,7 +47,7 @@ class TestTGIModelConfig:
             "repetition_penalty": None,
             "seed": None,
             "stop_tokens": None,
-            "temperature": None,
+            "temperature": 0,
             "top_k": None,
             "top_p": None,
             "truncate_prompt": None,
