Skip to content

Commit ce2d121

Browse files
authored
fix: fix chat_template in eval (#210)
Signed-off-by: Yuki Huang <[email protected]>
1 parent f8b6ba9 commit ce2d121

File tree

4 files changed

+66
-4
lines changed

4 files changed

+66
-4
lines changed

examples/configs/eval.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ generation:
1414
gpu_memory_utilization: 0.9
1515
max_model_len: 2048
1616

17+
tokenizer:
18+
name: ${generation.model_name} ## specify if you'd like to use a tokenizer different from the model's default
19+
chat_template: "default"
20+
1721
data:
1822
max_input_seq_length: ${generation.vllm_cfg.max_model_len} # useless since we directly use prompts in evaluation
1923
prompt_file: null

examples/run_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def main():
114114
init_ray()
115115

116116
# Setup tokenizer
117-
tokenizer = get_tokenizer(config["generation"]["model_name"])
117+
tokenizer = get_tokenizer(config["tokenizer"])
118118
config["generation"] = configure_generation_config(
119119
config["generation"], tokenizer, is_eval=True
120120
)

examples/run_grpo_math.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,13 @@ def math_data_processor(
114114
solution = str(datum_dict["expected_answer"])
115115
extra_env_info = {"ground_truth": solution}
116116

117-
template = task_data_spec.custom_template
118117
message_log: LLMMessageLogType = []
119118

120119
# system prompt
121120
if task_data_spec.system_prompt:
122121
sys_message = {"role": "system", "content": task_data_spec.system_prompt}
123122
message = tokenizer.apply_chat_template(
124123
[sys_message],
125-
chat_template=template,
126124
tokenize=False,
127125
add_generation_prompt=False,
128126
add_special_tokens=False,
@@ -138,7 +136,6 @@ def math_data_processor(
138136
user_message = {"role": "user", "content": problem}
139137
message = tokenizer.apply_chat_template(
140138
[user_message],
141-
chat_template=template,
142139
tokenize=False,
143140
add_generation_prompt=True,
144141
add_special_tokens=False,
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the math data processor used by the GRPO math example."""

import os
import sys

from datasets import Dataset

# Make the repository root importable so `examples.*` resolves when pytest
# is run from an arbitrary working directory (this file sits four levels
# below the repo root — TODO confirm if the test tree is ever reorganized).
abspath = os.path.abspath(__file__)
sys.path.append("/".join(abspath.split("/")[:-4]))

from examples.run_grpo_math import math_data_processor
from nemo_reinforcer.algorithms.utils import get_tokenizer
from nemo_reinforcer.data.datasets import AllTaskProcessedDataset
from nemo_reinforcer.data.interfaces import TaskDataSpec
from nemo_reinforcer.models.policy import TokenizerConfig

# Tokenizer settings shared by tests in this module. "default" selects the
# tokenizer's own built-in chat template rather than a custom override.
basic_tokenizer_test_config: TokenizerConfig = {
    "name": "Qwen/Qwen2.5-Math-1.5B-Instruct",
    "chat_template": "default",
}


def test_math_data_processor():
    """Processing math examples must carry the expected answer through as ground truth.

    Builds a two-row raw dataset, runs it through ``AllTaskProcessedDataset``
    with ``math_data_processor``, and checks that each processed item's
    ``extra_env_info["ground_truth"]`` matches the raw ``expected_answer``.
    """
    raw_dataset = Dataset.from_list(
        [
            {"problem": "problem1", "expected_answer": "answer1"},
            {"problem": "problem2", "expected_answer": "answer2"},
        ]
    )

    tokenizer = get_tokenizer(basic_tokenizer_test_config)

    # Minimal task spec: no prompt or system-prompt files, so the processor
    # exercises its default message construction path.
    math_task_spec = TaskDataSpec(
        task_name="math",
        prompt_file=None,
        system_prompt_file=None,
    )

    dataset = AllTaskProcessedDataset(
        dataset=raw_dataset,
        tokenizer=tokenizer,
        default_task_data_spec=math_task_spec,
        task_data_processors=math_data_processor,
        max_seq_length=128,
    )

    assert dataset[0]["extra_env_info"]["ground_truth"] == "answer1"
    assert dataset[1]["extra_env_info"]["ground_truth"] == "answer2"

0 commit comments

Comments
 (0)