
Commit 20461a6: Add support for chat template parameters

1 parent 64f93b0

15 files changed: +157, -15 lines

src/lighteval/logging/info_loggers.py

Lines changed: 5 additions & 1 deletion

@@ -87,6 +87,7 @@ class GeneralConfigLogger:
     model_size: str = None

     generation_parameters: dict | None = None
+    chat_template_parameters: dict | None = None

     # Nanotron config
     config: "Config" = None
@@ -129,7 +130,9 @@ def log_args_info(
         self.job_id = job_id
         self.config = config

-    def log_model_info(self, generation_parameters: dict, model_info: ModelInfo) -> None:
+    def log_model_info(
+        self, generation_parameters: dict, model_info: ModelInfo, chat_template_parameters: dict
+    ) -> None:
         """
         Logs the model information.

@@ -139,6 +142,7 @@ def log_model_info(self, generation_parameters: dict, model_info: ModelInfo) ->

         """
         self.generation_parameters = generation_parameters
+        self.chat_template_parameters = chat_template_parameters
         self.model_name = model_info.model_name
         self.model_sha = model_info.model_sha
         self.model_dtype = model_info.model_dtype
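Every caller of `log_model_info` must now pass the chat template parameters as a third argument, as the `main_baseline.py` change below shows. A minimal sketch of the new call shape (the tracker attribute name is inferred from the `evaluation_tracker.task_config_logger` usage visible below; argument values are illustrative):

    # Illustrative only: an empty dict opts out of chat template parameters.
    evaluation_tracker.general_config_logger.log_model_info(
        generation_parameters={"temperature": 0.7},
        model_info=model_info,
        chat_template_parameters={},
    )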

src/lighteval/main_baseline.py

Lines changed: 1 addition & 0 deletions

@@ -89,6 +89,7 @@ def baseline(
             model_dtype=None,
             model_size=None,
         ),
+        {},
     )
     evaluation_tracker.task_config_logger.log(tasks_dict)


src/lighteval/models/custom/custom_model.py

Lines changed: 0 additions & 1 deletion

@@ -70,5 +70,4 @@ def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]:
     An example of a custom model can be found in `examples/custom_models/google_translate_model.py`.
     """

-    model_name: str
     model_definition_file_path: str

src/lighteval/models/endpoints/endpoint_model.py

Lines changed: 0 additions & 1 deletion

@@ -95,7 +95,6 @@ class ServerlessEndpointModelConfig(ModelConfig):
     ```
     """

-    model_name: str
     add_special_tokens: bool = True
     batch_size: int = 1


src/lighteval/models/litellm_model.py

Lines changed: 0 additions & 1 deletion

@@ -94,7 +94,6 @@ class LiteLLMModelConfig(ModelConfig):
     ```
     """

-    model_name: str
     provider: str | None = None
     base_url: str | None = None
     api_key: str | None = None

src/lighteval/models/model_input.py

Lines changed: 15 additions & 0 deletions

@@ -232,3 +232,18 @@ def to_sglang_dict(self) -> dict:
             "min_new_tokens": self.min_new_tokens,
         }
         return {k: v for k, v in args.items() if v is not None}
+
+
+class ChatTemplateParameters(BaseModel):
+    reasoning_effort: str = None
+
+    def to_transformers_dict(self) -> dict:
+        """Selects relevant chat template parameters for transformers models.
+
+        Returns:
+            dict: Valid parameters for the chat template
+        """
+        args = {
+            "reasoning_effort": self.reasoning_effort,
+        }
+        return {k: v for k, v in args.items() if v is not None}
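A quick usage sketch of the new class (behavior read directly off the diff above; like the other parameter classes in this file, `ChatTemplateParameters` is a pydantic `BaseModel`):

    from lighteval.models.model_input import ChatTemplateParameters

    params = ChatTemplateParameters(reasoning_effort="low")
    print(params.to_transformers_dict())  # {'reasoning_effort': 'low'}

    # Fields left at None are filtered out, so a default instance yields {}.
    print(ChatTemplateParameters().to_transformers_dict())  # {}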

src/lighteval/models/sglang/sglang_model.py

Lines changed: 0 additions & 1 deletion

@@ -109,7 +109,6 @@ class SGLangModelConfig(ModelConfig):
     ```
     """

-    model_name: str
     load_format: str = "auto"
     dtype: str = "auto"
     tp_size: PositiveInt = 1  # how many GPUs to use for tensor parallelism

src/lighteval/models/transformers/transformers_model.py

Lines changed: 4 additions & 2 deletions

@@ -133,7 +133,6 @@ class TransformersModelConfig(ModelConfig):
         (bitsandbytes for 4-bit/8-bit quantization).
     """

-    model_name: str
     tokenizer: str | None = None
     subfolder: str | None = None
     revision: str = "main"
@@ -230,7 +229,10 @@ def __init__(
         )

         self.prompt_manager = PromptManager(
-            use_chat_template=self.use_chat_template, tokenizer=self.tokenizer, system_prompt=config.system_prompt
+            use_chat_template=self.use_chat_template,
+            tokenizer=self.tokenizer,
+            system_prompt=config.system_prompt,
+            chat_template_parameters=config.chat_template_parameters,
         )

     def cleanup(self):
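The PromptManager side is not part of this commit, but the intent is presumably to forward these values to the tokenizer's chat template. A hedged sketch of that idea (the model name and the forwarding call are illustrative assumptions, not code from this diff; `transformers` does pass extra `apply_chat_template` keyword arguments through to the Jinja template renderer):

    from transformers import AutoTokenizer

    from lighteval.models.model_input import ChatTemplateParameters

    # Example model whose bundled chat template reads a reasoning_effort variable.
    tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": "Hello!"}],
        tokenize=False,
        add_generation_prompt=True,
        **ChatTemplateParameters(reasoning_effort="high").to_transformers_dict(),
    )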

src/lighteval/models/transformers/vlm_transformers_model.py

Lines changed: 0 additions & 1 deletion

@@ -104,7 +104,6 @@ class VLMTransformersModelConfig(ModelConfig):
         loading.
     """

-    model_name: str
     processor: str | None = None
     use_fast_image_processor: bool | None = None
     subfolder: str | None = None

src/lighteval/models/utils.py

Lines changed: 17 additions & 4 deletions

@@ -34,7 +34,7 @@
 from transformers import AutoTokenizer
 from transformers.models.auto.configuration_auto import AutoConfig

-from lighteval.models.model_input import GenerationParameters
+from lighteval.models.model_input import ChatTemplateParameters, GenerationParameters


 logger = logging.getLogger(__name__)
@@ -70,7 +70,7 @@ class ModelConfig(BaseModel, extra="forbid"):
     config = ModelConfig.from_path("model_config.yaml")

     # Load from command line arguments
-    config = ModelConfig.from_args("model_name=meta-llama/Llama-3.1-8B-Instruct,system_prompt='You are a helpful assistant.',generation_parameters={temperature=0.7}")
+    config = ModelConfig.from_args("model_name=meta-llama/Llama-3.1-8B-Instruct,system_prompt='You are a helpful assistant.',generation_parameters={temperature:0.7}")

     # Direct instantiation
     config = ModelConfig(
@@ -81,7 +81,9 @@ class ModelConfig(BaseModel, extra="forbid"):
     ```
     """

+    model_name: str
     generation_parameters: GenerationParameters = GenerationParameters()
+    chat_template_parameters: ChatTemplateParameters = ChatTemplateParameters()
     system_prompt: str | None = None

     @classmethod
@@ -131,20 +133,31 @@ def _parse_args(args: str) -> dict:
         """
         # Looking for generation_parameters in the model_args
         generation_parameters_dict = None
-        pattern = re.compile(r"(\w+)=(\{.*\}|[^,]+)")
+        chat_template_parameters_dict = None
+        pattern = re.compile(r"(\w+)\s*=\s*(\{[^{}]*\}|[^,]+?)(?=,|$)")
         matches = pattern.findall(args)
         for key, value in matches:
             key = key.strip()
             if key == "generation_parameters":
                 gen_params = re.sub(r"(\w+):", r'"\1":', value)
                 generation_parameters_dict = json.loads(gen_params)
+            if key == "chat_template_parameters":
+                # Chat template parameters have strings as values that also need to be quoted
+                chat_template_params = re.sub(r"(\w+)\s*:\s*([A-Za-z_][\w.-]*)\s*(?=[,}])", r'"\1":"\2"', value)
+                chat_template_parameters_dict = json.loads(chat_template_params)

         args = re.sub(r"generation_parameters=\{.*?\},?", "", args).strip(",")
-        model_config = {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in args.split(",")}
+        args = re.sub(r"chat_template_parameters=\{.*?\},?", "", args).strip(",")
+        model_config = (
+            {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in args.split(",")} if args else {}
+        )

         if generation_parameters_dict is not None:
             model_config["generation_parameters"] = generation_parameters_dict

+        if chat_template_parameters_dict is not None:
+            model_config["chat_template_parameters"] = chat_template_parameters_dict
+
         return model_config
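Two details worth noting: the old pattern's greedy `\{.*\}` could swallow everything from the first `{` to the last `}` once two brace groups appear in the args string, which the tightened `\{[^{}]*\}` alternative avoids; and unlike `generation_parameters`, whose values are numbers, `chat_template_parameters` values are bare words, so the extra `re.sub` quotes them before `json.loads`. A small illustrative run (expected output reasoned from the diff above; `_parse_args` is internal, so this is a sketch rather than a supported API):

    from lighteval.models.utils import ModelConfig

    parsed = ModelConfig._parse_args(
        "model_name=meta-llama/Llama-3.1-8B-Instruct,"
        "generation_parameters={temperature:0.7,top_p:0.95},"
        "chat_template_parameters={reasoning_effort:low}"
    )
    # Keys inside both brace groups get quoted; the bare string value "low"
    # is additionally quoted so json.loads accepts it:
    # {'model_name': 'meta-llama/Llama-3.1-8B-Instruct',
    #  'generation_parameters': {'temperature': 0.7, 'top_p': 0.95},
    #  'chat_template_parameters': {'reasoning_effort': 'low'}}
    print(parsed)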
