Skip to content

Commit ad43991

Browse files
committed
Initial implementation for chat template parameters
1 parent 64f93b0 commit ad43991

File tree

4 files changed

+40
-5
lines changed

4 files changed

+40
-5
lines changed

src/lighteval/models/model_input.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,17 @@ def to_sglang_dict(self) -> dict:
232232
"min_new_tokens": self.min_new_tokens,
233233
}
234234
return {k: v for k, v in args.items() if v is not None}
235+
236+
class ChatTemplateParameters(BaseModel):
    """Parameters forwarded to a tokenizer's chat template.

    Attributes:
        reasoning_effort: Optional reasoning-effort hint passed through to chat
            templates that support it (e.g. via ``apply_chat_template`` kwargs).
    """

    # FIX: annotated `str | None` — the original `reasoning_effort: str = None`
    # contradicts its own annotation and is rejected by strict pydantic
    # validation, since `None` is not a valid `str`.
    reasoning_effort: str | None = None

    def to_transformers_dict(self) -> dict:
        """Selects relevant chat template parameters for transformers models.

        Returns:
            dict: Valid parameters for the chat template; keys whose value is
                None are dropped so they do not override template defaults.
        """
        args = {
            "reasoning_effort": self.reasoning_effort,
        }
        return {k: v for k, v in args.items() if v is not None}

src/lighteval/models/transformers/transformers_model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,10 @@ def __init__(
230230
)
231231

232232
self.prompt_manager = PromptManager(
233-
use_chat_template=self.use_chat_template, tokenizer=self.tokenizer, system_prompt=config.system_prompt
233+
use_chat_template=self.use_chat_template,
234+
tokenizer=self.tokenizer,
235+
system_prompt=config.system_prompt,
236+
chat_template_parameters=config.chat_template_parameters
234237
)
235238

236239
def cleanup(self):

src/lighteval/models/utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@
3434
from transformers import AutoTokenizer
3535
from transformers.models.auto.configuration_auto import AutoConfig
3636

37-
from lighteval.models.model_input import GenerationParameters
38-
37+
from lighteval.models.model_input import GenerationParameters, ChatTemplateParameters
3938

4039
logger = logging.getLogger(__name__)
4140

@@ -82,6 +81,7 @@ class ModelConfig(BaseModel, extra="forbid"):
8281
"""
8382

8483
generation_parameters: GenerationParameters = GenerationParameters()
84+
chat_template_parameters: ChatTemplateParameters = ChatTemplateParameters()
8585
system_prompt: str | None = None
8686

8787
@classmethod
@@ -131,20 +131,29 @@ def _parse_args(args: str) -> dict:
131131
"""
132132
# Looking for generation_parameters in the model_args
133133
generation_parameters_dict = None
134-
pattern = re.compile(r"(\w+)=(\{.*\}|[^,]+)")
134+
chat_template_parameters_dict = None
135+
pattern = re.compile(r'(\w+)\s*=\s*(\{[^{}]*\}|[^,]+?)(?=,|$)')
135136
matches = pattern.findall(args)
136137
for key, value in matches:
137138
key = key.strip()
138139
if key == "generation_parameters":
139140
gen_params = re.sub(r"(\w+):", r'"\1":', value)
140141
generation_parameters_dict = json.loads(gen_params)
142+
if key == "chat_template_parameters":
143+
# Chat template parameters have strings as values that also need to be quoted
144+
chat_template_params = re.sub(r'(\w+)\s*:\s*([A-Za-z_][\w.-]*)\s*(?=[,}])', r'"\1":"\2"', value)
145+
chat_template_parameters_dict = json.loads(chat_template_params)
141146

142147
args = re.sub(r"generation_parameters=\{.*?\},?", "", args).strip(",")
148+
args = re.sub(r"chat_template_parameters=\{.*?\},?", "", args).strip(",")
143149
model_config = {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in args.split(",")}
144150

145151
if generation_parameters_dict is not None:
146152
model_config["generation_parameters"] = generation_parameters_dict
147153

154+
if chat_template_parameters_dict is not None:
155+
model_config["chat_template_parameters"] = chat_template_parameters_dict
156+
148157
return model_config
149158

150159

src/lighteval/tasks/prompt_manager.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from itertools import cycle
2929
from typing import TYPE_CHECKING
3030

31+
from lighteval.models.model_input import ChatTemplateParameters
3132
from lighteval.tasks.requests import Doc
3233
from lighteval.utils.utils import as_list
3334

@@ -40,10 +41,17 @@
4041

4142

4243
class PromptManager:
43-
def __init__(self, use_chat_template: bool = False, tokenizer=None, system_prompt: str | None = None):
44+
def __init__(
    self,
    use_chat_template: bool = False,
    tokenizer=None,
    system_prompt: str | None = None,
    chat_template_parameters: ChatTemplateParameters | None = None,
):
    """Build a PromptManager.

    Args:
        use_chat_template: Whether prompts are rendered through the tokenizer's
            chat template instead of plain-text formatting.
        tokenizer: Tokenizer providing ``apply_chat_template`` when chat
            templating is enabled.
        system_prompt: Optional system prompt to be used in chat templates.
        chat_template_parameters: Extra parameters forwarded to the chat
            template; defaults to an empty ``ChatTemplateParameters``.
    """
    self.use_chat_template = use_chat_template
    self.tokenizer = tokenizer
    self.system_prompt = system_prompt  # System prompt to be used in chat templates
    # BUG FIX: fall back to ChatTemplateParameters(), not {} — downstream code
    # calls self.chat_template_parameters.to_transformers_dict(), which a plain
    # dict does not provide, so the previous default crashed every caller that
    # omitted the argument.
    self.chat_template_parameters = (
        chat_template_parameters if chat_template_parameters is not None else ChatTemplateParameters()
    )
4755

4856
def prepare_prompt(self, doc: Doc) -> str:
4957
"""Prepare a prompt from a document, either using chat template or plain text format."""
@@ -123,6 +131,7 @@ def _prepare_chat_template(self, doc: Doc, tokenize: bool = True) -> str:
123131
messages,
124132
tokenize=False,
125133
add_generation_prompt=True,
134+
**self.chat_template_parameters.to_transformers_dict()
126135
)
127136

128137
else: # for apis

0 commit comments

Comments
 (0)