
Commit 02bf135

Adds the option to override the chat template (#914)

1 parent af21193 commit 02bf135
5 files changed: +45 −5 lines changed
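
For context, a minimal sketch of how the new option is meant to be used from Python, based on the config field this commit adds to each backend. This is an assumption-laden illustration, not code from the commit: the model id is a placeholder, and `model_name` is assumed to be the only required field.

```python
from lighteval.models.transformers.transformers_model import TransformersModelConfig

# Force chat-template formatting even if tokenizer detection would say otherwise.
config = TransformersModelConfig(
    model_name="org/model",  # placeholder model id
    override_chat_template=True,
)

# override_chat_template=False disables the chat template outright;
# leaving it as None keeps the old auto-detection behavior.
```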

src/lighteval/models/sglang/sglang_model.py

Lines changed: 7 additions & 1 deletion

````diff
@@ -95,6 +95,9 @@ class SGLangModelConfig(ModelConfig):
             Fraction of GPU memory to use for static allocation. Defaults to 0.8.
         chunked_prefill_size (PositiveInt):
             Size of chunks for prefill operations. Defaults to 4096.
+        override_chat_template (bool):
+            If True, we force the model to use a chat template. If False, we prevent the model from using
+            a chat template. If None, we use the default (True if present in the tokenizer, False otherwise).

     Example:
         ```python
@@ -127,6 +130,7 @@ class SGLangModelConfig(ModelConfig):
     attention_backend: str | None = None
     mem_fraction_static: PositiveFloat = 0.8
     chunked_prefill_size: PositiveInt = 4096
+    override_chat_template: bool = None


 class SGLangModel(LightevalModel):
@@ -136,7 +140,9 @@ def __init__(
     ):
         """Initializes an SGLang model."""
         self.config = config
-        self.use_chat_template = uses_chat_template(model_name=self.config.model_name)
+        self.use_chat_template = uses_chat_template(
+            model_name=self.config.model_name, override_chat_template=config.override_chat_template
+        )
         self.data_parallel_size = config.dp_size
         self.tensor_parallel_size = config.tp_size
         self._add_special_tokens = config.add_special_tokens
````

src/lighteval/models/transformers/transformers_model.py

Lines changed: 10 additions & 2 deletions

````diff
@@ -114,6 +114,9 @@ class TransformersModelConfig(ModelConfig):
             Whether to tokenize context and continuation separately or together. Defaults to False.
         continuous_batching (bool):
             Whether to use continuous batching for generation. Defaults to False.
+        override_chat_template (bool):
+            If True, we force the model to use a chat template. If False, we prevent the model from using
+            a chat template. If None, we use the default (True if present in the tokenizer, False otherwise).

     Example:
         ```python
@@ -151,6 +154,7 @@ class TransformersModelConfig(ModelConfig):
     multichoice_continuations_start_space: bool | None = None
     pairwise_tokenization: bool = False
     continuous_batching: bool = False
+    override_chat_template: bool = None

     def model_post_init(self, __context):
         if self.multichoice_continuations_start_space is True:
@@ -201,7 +205,9 @@ def __init__(
         self.model_sha = config.get_model_sha()
         self._max_length = self._init_max_length()
         self._tokenizer = self._create_auto_tokenizer()
-        self.use_chat_template = uses_chat_template(tokenizer=self._tokenizer)
+        self.use_chat_template = uses_chat_template(
+            tokenizer=self._tokenizer, override_chat_template=config.override_chat_template
+        )
         self.model = self._create_auto_model()

         # We are in DP (and launch the script with `accelerate launch`)
@@ -285,7 +291,9 @@ def from_model(
         else:
             self._device = self.config.device

-        self.use_chat_template = uses_chat_template(self._tokenizer)
+        self.use_chat_template = uses_chat_template(
+            tokenizer=self._tokenizer, override_chat_template=config.override_chat_template
+        )
         self._add_special_tokens = add_special_tokens if add_special_tokens is not None else False
         self.skip_special_tokens = skip_special_tokens if skip_special_tokens is not None else True
         self.pairwise_tokenization = pairwise_tokenization
````

src/lighteval/models/utils.py

Lines changed: 5 additions & 1 deletion

````diff
@@ -109,7 +109,9 @@ def batched(iterable, n):
         yield batch


-def uses_chat_template(model_name: str = None, tokenizer: AutoTokenizer = None) -> bool:
+def uses_chat_template(
+    model_name: str = None, tokenizer: AutoTokenizer = None, override_chat_template: bool = None
+) -> bool:
     """Returns a boolean depending on whether the Transformers AutoTokenizer contains
     a chat template or not

@@ -119,6 +121,8 @@ def uses_chat_template(model_name: str = None, tokenizer: AutoTokenizer = None)
     Returns:
         bool: True if Tokenizer config contains a chat template, False otherwise
     """
+    if override_chat_template is not None:
+        return override_chat_template
     if model_name is None and tokenizer is None:
         raise Exception("`uses_chat_template` requires either a tokenizer or model name as input")
     try:
````
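
Taken together, the new parameter simply short-circuits the detection logic. Below is a minimal sketch of the resulting precedence, assuming the function only inspects the tokenizer's `chat_template` attribute, as the existing tests suggest; the `SimpleNamespace` objects are stand-ins mirroring the `Mock` objects used in the test file.

```python
from types import SimpleNamespace

from lighteval.models.utils import uses_chat_template

tok_with_template = SimpleNamespace(chat_template="{% for message in messages %}...")
tok_without_template = SimpleNamespace(chat_template=None)

# Default behavior: the tokenizer decides.
assert uses_chat_template(tokenizer=tok_with_template)
assert not uses_chat_template(tokenizer=tok_without_template)

# The override wins, regardless of what the tokenizer says.
assert not uses_chat_template(tokenizer=tok_with_template, override_chat_template=False)
assert uses_chat_template(tokenizer=tok_without_template, override_chat_template=True)
```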

src/lighteval/models/vllm/vllm_model.py

Lines changed: 7 additions & 1 deletion

````diff
@@ -125,6 +125,9 @@ class VLLMModelConfig(ModelConfig):
             Subfolder within the model repository. Defaults to None.
         is_async (bool):
             Whether to use the async version of VLLM. Defaults to False.
+        override_chat_template (bool):
+            If True, we force the model to use a chat template. If False, we prevent the model from using
+            a chat template. If None, we use the default (True if present in the tokenizer, False otherwise).

     Example:
         ```python
@@ -165,6 +168,7 @@ class VLLMModelConfig(ModelConfig):
     max_num_batched_tokens: PositiveInt = 2048  # maximum number of tokens per batch
     subfolder: str | None = None
     is_async: bool = False  # Whether to use the async version or sync version of the model
+    override_chat_template: bool = None


 class VLLMModel(LightevalModel):
@@ -174,7 +178,9 @@ def __init__(
     ):
         """Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation."""
         self.config = config
-        self.use_chat_template = uses_chat_template(model_name=config.model_name)
+        self.use_chat_template = uses_chat_template(
+            model_name=config.model_name, override_chat_template=config.override_chat_template
+        )
         self.data_parallel_size = config.data_parallel_size
         self.tensor_parallel_size = config.tensor_parallel_size
         self._add_special_tokens = config.add_special_tokens if config.add_special_tokens is not None else False
````
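
The vLLM backend follows the same pattern; here is a hedged sketch of the opposite direction, disabling the chat template. The model id is a placeholder and the other config fields are assumed to be optional.

```python
from lighteval.models.vllm.vllm_model import VLLMModelConfig

# Evaluate with plain completion-style prompts even if the tokenizer ships a chat template.
config = VLLMModelConfig(
    model_name="org/model",  # placeholder model id
    override_chat_template=False,
)
```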

tests/models/test_model_utils.py

Lines changed: 16 additions & 0 deletions

````diff
@@ -42,3 +42,19 @@ def test_uses_chat_template_with_no_chat_template(self):

         result = uses_chat_template(tokenizer=mock_tokenizer)
         self.assertFalse(result)
+
+    def test_uses_chat_template_with_chat_template_present_override(self):
+        """Test that the override forces False even when the tokenizer has a chat template."""
+        mock_tokenizer = Mock()
+        mock_tokenizer.chat_template = "{% for message in messages %}..."
+
+        result = uses_chat_template(tokenizer=mock_tokenizer, override_chat_template=False)
+        self.assertFalse(result)
+
+    def test_uses_chat_template_with_no_chat_template_override(self):
+        """Test that the override forces True even when the tokenizer has no chat template."""
+        mock_tokenizer = Mock()
+        mock_tokenizer.chat_template = None
+
+        result = uses_chat_template(tokenizer=mock_tokenizer, override_chat_template=True)
+        self.assertTrue(result)
````
