
Initial implementation for chat template parameters #904

Draft: wants to merge 1 commit into main
6 changes: 5 additions & 1 deletion src/lighteval/logging/info_loggers.py
@@ -87,6 +87,7 @@ class GeneralConfigLogger:
model_size: str = None

generation_parameters: dict | None = None
chat_template_parameters: dict | None = None

# Nanotron config
config: "Config" = None
@@ -129,7 +130,9 @@ def log_args_info(
self.job_id = job_id
self.config = config

def log_model_info(self, generation_parameters: dict, model_info: ModelInfo) -> None:
def log_model_info(
self, generation_parameters: dict, model_info: ModelInfo, chat_template_parameters: dict
) -> None:
"""
Logs the model information.

@@ -139,6 +142,7 @@ def log_model_info(self, generation_parameters: dict, model_info: ModelInfo) ->

"""
self.generation_parameters = generation_parameters
self.chat_template_parameters = chat_template_parameters
self.model_name = model_info.model_name
self.model_sha = model_info.model_sha
self.model_dtype = model_info.model_dtype
1 change: 1 addition & 0 deletions src/lighteval/main_baseline.py
@@ -89,6 +89,7 @@ def baseline(
model_dtype=None,
model_size=None,
),
{},
)
evaluation_tracker.task_config_logger.log(tasks_dict)

1 change: 0 additions & 1 deletion src/lighteval/models/custom/custom_model.py
@@ -70,5 +70,4 @@ def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]:
An example of a custom model can be found in `examples/custom_models/google_translate_model.py`.
"""

model_name: str
model_definition_file_path: str
1 change: 0 additions & 1 deletion src/lighteval/models/endpoints/endpoint_model.py
@@ -95,7 +95,6 @@ class ServerlessEndpointModelConfig(ModelConfig):
```
"""

model_name: str
add_special_tokens: bool = True
batch_size: int = 1

1 change: 0 additions & 1 deletion src/lighteval/models/litellm_model.py
@@ -94,7 +94,6 @@ class LiteLLMModelConfig(ModelConfig):
```
"""

model_name: str
provider: str | None = None
base_url: str | None = None
api_key: str | None = None
15 changes: 15 additions & 0 deletions src/lighteval/models/model_input.py
@@ -232,3 +232,18 @@ def to_sglang_dict(self) -> dict:
"min_new_tokens": self.min_new_tokens,
}
return {k: v for k, v in args.items() if v is not None}


class ChatTemplateParameters(BaseModel):
reasoning_effort: str | None = None

def to_transformers_dict(self) -> dict:
"""Selects relevant chat template parameters for transformers models.

Returns:
dict: Valid parameters for the chat template
"""
args = {
"reasoning_effort": self.reasoning_effort,
}
return {k: v for k, v in args.items() if v is not None}
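A minimal usage sketch (not part of the diff) of the new model: only fields that were explicitly set survive the `None` filtering, so a default instance adds no extra kwargs to the chat template call.

```python
from lighteval.models.model_input import ChatTemplateParameters

# Explicitly set fields are kept by to_transformers_dict() ...
params = ChatTemplateParameters(reasoning_effort="low")
assert params.to_transformers_dict() == {"reasoning_effort": "low"}

# ... while unset fields are filtered out, so a default instance contributes nothing.
assert ChatTemplateParameters().to_transformers_dict() == {}
```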
1 change: 0 additions & 1 deletion src/lighteval/models/sglang/sglang_model.py
@@ -109,7 +109,6 @@ class SGLangModelConfig(ModelConfig):
```
"""

model_name: str
load_format: str = "auto"
dtype: str = "auto"
tp_size: PositiveInt = 1 # how many GPUs to use for tensor parallelism
6 changes: 4 additions & 2 deletions src/lighteval/models/transformers/transformers_model.py
@@ -133,7 +133,6 @@ class TransformersModelConfig(ModelConfig):
(bitsandbytes for 4-bit/8-bit quantization).
"""

model_name: str
tokenizer: str | None = None
subfolder: str | None = None
revision: str = "main"
@@ -230,7 +229,10 @@ def __init__(
)

self.prompt_manager = PromptManager(
use_chat_template=self.use_chat_template, tokenizer=self.tokenizer, system_prompt=config.system_prompt
use_chat_template=self.use_chat_template,
tokenizer=self.tokenizer,
system_prompt=config.system_prompt,
chat_template_parameters=config.chat_template_parameters,
)

def cleanup(self):
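As a usage sketch (illustrative only, not part of the diff; the model name is hypothetical), this is how a transformers config carrying chat template parameters reaches the prompt manager:

```python
from lighteval.models.model_input import ChatTemplateParameters
from lighteval.models.transformers.transformers_model import TransformersModelConfig

# Hypothetical model name; any chat model whose template reads `reasoning_effort` applies.
config = TransformersModelConfig(
    model_name="my-org/my-reasoning-model",
    chat_template_parameters=ChatTemplateParameters(reasoning_effort="medium"),
)

# TransformersModel.__init__ (diff above) passes config.chat_template_parameters to the
# PromptManager, which expands it into the tokenizer.apply_chat_template(...) call.
```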
1 change: 0 additions & 1 deletion src/lighteval/models/transformers/vlm_transformers_model.py
@@ -104,7 +104,6 @@ class VLMTransformersModelConfig(ModelConfig):
loading.
"""

model_name: str
processor: str | None = None
use_fast_image_processor: bool | None = None
subfolder: str | None = None
21 changes: 17 additions & 4 deletions src/lighteval/models/utils.py
@@ -34,7 +34,7 @@
from transformers import AutoTokenizer
from transformers.models.auto.configuration_auto import AutoConfig

from lighteval.models.model_input import GenerationParameters
from lighteval.models.model_input import ChatTemplateParameters, GenerationParameters


logger = logging.getLogger(__name__)
@@ -70,7 +70,7 @@ class ModelConfig(BaseModel, extra="forbid"):
config = ModelConfig.from_path("model_config.yaml")

# Load from command line arguments
config = ModelConfig.from_args("model_name=meta-llama/Llama-3.1-8B-Instruct,system_prompt='You are a helpful assistant.',generation_parameters={temperature=0.7}")
config = ModelConfig.from_args("model_name=meta-llama/Llama-3.1-8B-Instruct,system_prompt='You are a helpful assistant.',generation_parameters={temperature:0.7}")

# Direct instantiation
config = ModelConfig(
@@ -81,7 +81,9 @@ class ModelConfig(BaseModel, extra="forbid"):
```
"""

model_name: str
generation_parameters: GenerationParameters = GenerationParameters()
chat_template_parameters: ChatTemplateParameters = ChatTemplateParameters()
system_prompt: str | None = None

@classmethod
@@ -131,20 +133,31 @@ def _parse_args(args: str) -> dict:
"""
# Looking for generation_parameters in the model_args
generation_parameters_dict = None
pattern = re.compile(r"(\w+)=(\{.*\}|[^,]+)")
chat_template_parameters_dict = None
pattern = re.compile(r"(\w+)\s*=\s*(\{[^{}]*\}|[^,]+?)(?=,|$)")
matches = pattern.findall(args)
for key, value in matches:
key = key.strip()
if key == "generation_parameters":
gen_params = re.sub(r"(\w+):", r'"\1":', value)
generation_parameters_dict = json.loads(gen_params)
if key == "chat_template_parameters":
# Chat template parameters have strings as values that also need to be quoted

Review comment (Member): would be relevant to have tests for parsing edge cases

chat_template_params = re.sub(r"(\w+)\s*:\s*([A-Za-z_][\w.-]*)\s*(?=[,}])", r'"\1":"\2"', value)
chat_template_parameters_dict = json.loads(chat_template_params)

args = re.sub(r"generation_parameters=\{.*?\},?", "", args).strip(",")
model_config = {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in args.split(",")}
args = re.sub(r"chat_template_parameters=\{.*?\},?", "", args).strip(",")
model_config = (
{k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in args.split(",")} if args else {}
)

if generation_parameters_dict is not None:
model_config["generation_parameters"] = generation_parameters_dict

if chat_template_parameters_dict is not None:
model_config["chat_template_parameters"] = chat_template_parameters_dict

return model_config
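To illustrate the quoting step in isolation (a standalone sketch, not the PR's code path): bare-word values such as `low` are not valid JSON, so both keys and values are quoted before `json.loads` runs.

```python
import json
import re

value = "{reasoning_effort:low}"
# Mirrors the substitution applied above to chat_template_parameters values.
quoted = re.sub(r"(\w+)\s*:\s*([A-Za-z_][\w.-]*)\s*(?=[,}])", r'"\1":"\2"', value)
print(quoted)              # {"reasoning_effort":"low"}
print(json.loads(quoted))  # {'reasoning_effort': 'low'}
```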


1 change: 0 additions & 1 deletion src/lighteval/models/vllm/vllm_model.py
@@ -140,7 +140,6 @@ class VLLMModelConfig(ModelConfig):
```
"""

model_name: str
revision: str = "main" # revision of the model
dtype: str = "bfloat16"
tensor_parallel_size: PositiveInt = 1 # how many GPUs to use for tensor parallelism
5 changes: 4 additions & 1 deletion src/lighteval/pipeline.py
@@ -177,8 +177,11 @@ def __init__(
self.model = self._init_model(model_config, model)

generation_parameters = model_config.generation_parameters.model_dump() if model_config else {}
chat_template_parameters = model_config.chat_template_parameters.model_dump() if model_config else {}

self.evaluation_tracker.general_config_logger.log_model_info(generation_parameters, self.model.model_info)
self.evaluation_tracker.general_config_logger.log_model_info(
generation_parameters, self.model.model_info, chat_template_parameters
)

self._init_random_seeds()
self._init_tasks_and_requests(tasks=tasks)
11 changes: 10 additions & 1 deletion src/lighteval/tasks/prompt_manager.py
@@ -28,6 +28,7 @@
from itertools import cycle
from typing import TYPE_CHECKING

from lighteval.models.model_input import ChatTemplateParameters
from lighteval.tasks.requests import Doc
from lighteval.utils.utils import as_list

@@ -40,10 +41,17 @@


class PromptManager:
def __init__(self, use_chat_template: bool = False, tokenizer=None, system_prompt: str | None = None):
def __init__(
self,
use_chat_template: bool = False,
tokenizer=None,
system_prompt: str | None = None,
chat_template_parameters: ChatTemplateParameters | None = None,
):
self.use_chat_template = use_chat_template
self.tokenizer = tokenizer
self.system_prompt = system_prompt # System prompt to be used in chat templates
self.chat_template_parameters = chat_template_parameters if chat_template_parameters is not None else ChatTemplateParameters()

def prepare_prompt(self, doc: Doc) -> str:
"""Prepare a prompt from a document, either using chat template or plain text format."""
@@ -123,6 +131,7 @@ def _prepare_chat_template(self, doc: Doc, tokenize: bool = True) -> str:
messages,
tokenize=False,
add_generation_prompt=True,
**self.chat_template_parameters.to_transformers_dict(),
)

else: # for apis
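Outside lighteval, the effect of the forwarded kwargs can be reproduced directly with transformers: extra keywords passed to `apply_chat_template` are exposed to the template renderer, and templates that do not reference them simply ignore them. A small sketch, with an arbitrary instruct model standing in for the real one:

```python
from transformers import AutoTokenizer

# Any chat model works here; its template sees `reasoning_effort` as a template variable.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]

# Mirrors the call in PromptManager._prepare_chat_template above.
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    reasoning_effort="high",  # injected via ChatTemplateParameters.to_transformers_dict()
)
print(prompt)
```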
17 changes: 17 additions & 0 deletions tests/test_prompt_manager_class.py
@@ -24,6 +24,7 @@

import pytest

from lighteval.models.model_input import ChatTemplateParameters
from lighteval.tasks.prompt_manager import PromptManager
from lighteval.tasks.requests import Doc

@@ -47,6 +48,22 @@ def test_init_with_chat_template(self):
assert pm.tokenizer == tokenizer
assert pm.system_prompt == system_prompt

def test_init_with_chat_template_and_chat_template_parameters(self):
"""Test PromptManager initialization with chat template enabled and chat template parameters."""
tokenizer = Mock()
system_prompt = "You are a helpful assistant."
pm = PromptManager(
use_chat_template=True,
tokenizer=tokenizer,
system_prompt=system_prompt,
chat_template_parameters=ChatTemplateParameters(reasoning_effort="medium"),
)
assert pm.use_chat_template is True
assert pm.tokenizer == tokenizer
assert pm.system_prompt == system_prompt
assert pm.chat_template_parameters is not None
assert pm.chat_template_parameters.reasoning_effort == "medium"

def test_prepare_prompt_plain_text_basic(self):
"""Test prepare_prompt with plain text format and basic document."""
pm = PromptManager()
84 changes: 84 additions & 0 deletions tests/utils/test_model_config.py
@@ -0,0 +1,84 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import unittest

from lighteval.models.model_input import ChatTemplateParameters, GenerationParameters
from lighteval.models.utils import ModelConfig


class TestModelConfig(unittest.TestCase):
def test_model_config_init(self):
config = ModelConfig(
model_name="meta-llama/Llama-3.1-8B-Instruct",
generation_parameters=GenerationParameters(temperature=0.7),
system_prompt="You are a helpful assistant.",
chat_template_parameters=ChatTemplateParameters(reasoning_effort="low"),
)

self.assertEqual(config.model_name, "meta-llama/Llama-3.1-8B-Instruct")
self.assertEqual(config.generation_parameters.temperature, 0.7)
self.assertEqual(config.system_prompt, "You are a helpful assistant.")
self.assertEqual(config.chat_template_parameters.reasoning_effort, "low")

def test_model_config_init_command_line(self):
config = ModelConfig.from_args(
'model_name=meta-llama/Llama-3.1-8B-Instruct,system_prompt="You are a helpful assistant.",generation_parameters={temperature:0.7},chat_template_parameters={reasoning_effort:low}'
)

self.assertEqual(config.model_name, "meta-llama/Llama-3.1-8B-Instruct")
self.assertEqual(config.generation_parameters.temperature, 0.7)
self.assertEqual(config.system_prompt, '"You are a helpful assistant."') # is this what we want?
self.assertEqual(config.chat_template_parameters.reasoning_effort, "low")

def test_model_config_generation_parameters_parse_single_int(self):
config = ModelConfig.from_args(
"model_name=meta-llama/Llama-3.1-8B-Instruct,generation_parameters={temperature:0.7}"
)
self.assertEqual(config.generation_parameters.temperature, 0.7)

def test_model_config_generation_parameters_parse_multiple_int(self):
config = ModelConfig.from_args(
"model_name=meta-llama/Llama-3.1-8B-Instruct,generation_parameters={temperature:0.7,top_k:42}"
)
self.assertEqual(config.generation_parameters.temperature, 0.7)
self.assertEqual(config.generation_parameters.top_k, 42)

@unittest.skip("This is not working at this time")
def test_model_config_generation_parameters_parse_string(self):
config = ModelConfig.from_args(
'model_name=meta-llama/Llama-3.1-8B-Instruct,generation_parameters={response_format:{"type":"json_object"}}'
)
self.assertEqual(config.generation_parameters.temperature, 0.7)

@unittest.skip("This is not working at this time")
def test_model_config_chat_template_parameters_parse_single_int(self):
config = ModelConfig.from_args(
"model_name=meta-llama/Llama-3.1-8B-Instruct,chat_template_parameters={temperature:0.7}"
)
self.assertEqual(config.chat_template_parameters.temperature, 0.7)

def test_model_config_chat_template_parameters_parse_string(self):
config = ModelConfig.from_args(
"model_name=meta-llama/Llama-3.1-8B-Instruct,chat_template_parameters={reasoning_effort:low}"
)
self.assertEqual(config.chat_template_parameters.reasoning_effort, "low")