
Commit 2b55598

fix: address double bos in eval task (#962)
Signed-off-by: Zhiyu Li <[email protected]>
Co-authored-by: Yuki Huang <[email protected]>
Parent: acabc79

File tree: 2 files changed (+89, -11 lines)

nemo_rl/data/processors.py
Lines changed: 9 additions & 3 deletions

@@ -51,7 +51,9 @@ def math_data_processor(
         add_generation_prompt=False,
         add_special_tokens=False,
     )
-    sys_prompt["token_ids"] = tokenizer(sys, return_tensors="pt")["input_ids"][0]
+    sys_prompt["token_ids"] = tokenizer(
+        sys, return_tensors="pt", add_special_tokens=False
+    )["input_ids"][0]
     message_log.append(sys_prompt)
 
     # user prompt
@@ -138,7 +140,9 @@ def multichoice_qa_processor(
         add_generation_prompt=False,
         add_special_tokens=False,
     )
-    sys_prompt["token_ids"] = tokenizer(sys, return_tensors="pt")["input_ids"][0]
+    sys_prompt["token_ids"] = tokenizer(
+        sys, return_tensors="pt", add_special_tokens=False
+    )["input_ids"][0]
     message_log.append(sys_prompt)
 
     # user prompt
@@ -153,7 +157,9 @@ def multichoice_qa_processor(
         add_generation_prompt=True,
         add_special_tokens=False,
     )
-    user_message["token_ids"] = tokenizer(message, return_tensors="pt")["input_ids"][0]
+    user_message["token_ids"] = tokenizer(
+        message, return_tensors="pt", add_special_tokens=False
+    )["input_ids"][0]
     user_message["content"] = message
     message_log.append(user_message)

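All three hunks make the same fix: the text returned by apply_chat_template already contains the BOS token whenever the chat template emits one, so re-tokenizing that text with the tokenizer defaults prepends a second BOS. A minimal sketch of the failure mode and the fix, assuming a Hugging Face AutoTokenizer whose template emits BOS (the checkpoint name below is only illustrative, taken from the models exercised by the new test):

# Sketch of the double-BOS issue fixed above. Assumes a chat template that
# already bakes a BOS token into the rendered text.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

# Render the system turn to text; the template inserts one BOS token.
sys = tokenizer.apply_chat_template(
    [{"role": "system", "content": "You are a helpful assistant."}],
    tokenize=False,
    add_generation_prompt=False,
)

# Default tokenization prepends a second BOS on top of the templated one.
with_default = tokenizer(sys, return_tensors="pt")["input_ids"][0]
# The fix: skip special tokens so only the BOS from the template remains.
with_fix = tokenizer(sys, return_tensors="pt", add_special_tokens=False)["input_ids"][0]

print((with_default == tokenizer.bos_token_id).sum().item())  # expected: 2
print((with_fix == tokenizer.bos_token_id).sum().item())      # expected: 1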
tests/unit/data/test_data_processor.py
Lines changed: 80 additions & 8 deletions

@@ -14,6 +14,7 @@
 
 import os
 import sys
+import tempfile
 from collections import defaultdict
 
 import pytest
@@ -25,6 +26,13 @@
 from examples.run_grpo_math import hf_data_processor
 from nemo_rl.algorithms.utils import get_tokenizer
 from nemo_rl.data.datasets import AllTaskProcessedDataset
+from nemo_rl.data.eval_datasets import (
+    AIME2024Dataset,
+    AIME2025Dataset,
+    GPQADataset,
+    MathDataset,
+    MMLUDataset,
+)
 from nemo_rl.data.hf_datasets.deepscaler import DeepScalerDataset
 from nemo_rl.data.hf_datasets.openmathinstruct2 import OpenMathInstruct2Dataset
 from nemo_rl.data.interfaces import TaskDataProcessFnCallable, TaskDataSpec
@@ -78,18 +86,15 @@ def test_math_data_processor():
     ],
 )
 @pytest.mark.parametrize(
-    "dataset_name",
+    "dataset_cls",
     [
-        "openmathinstruct2",
-        "deepscaler",
+        OpenMathInstruct2Dataset,
+        DeepScalerDataset,
     ],
 )
-def test_math_hf_data_processor(tokenizer_name, dataset_name):
+def test_math_hf_data_processor(tokenizer_name, dataset_cls):
     # Initialize dataset
-    if dataset_name == "openmathinstruct2":
-        data = OpenMathInstruct2Dataset()
-    elif dataset_name == "deepscaler":
-        data = DeepScalerDataset()
+    data = dataset_cls()
 
     # Setup tokenizer
     tokenizer = get_tokenizer(
@@ -124,3 +129,70 @@ def test_math_hf_data_processor(
     assert first_item is not None
     assert "message_log" in first_item
     assert len(first_item["message_log"]) > 0
+
+
+@pytest.fixture
+def system_prompt_file(request):
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as file:
+        file.write("You are a helpful assistant.\n{}")
+
+    return file.name
+
+
+@pytest.mark.hf_gated
+@pytest.mark.parametrize(
+    "tokenizer_name",
+    [
+        "meta-llama/Llama-3.2-1B-Instruct",
+        "Qwen/Qwen2.5-1.5B-Instruct",  # no bos token
+        "google/gemma-3-1b-it",
+        "Qwen/Qwen3-0.6B",  # no bos token
+        "deepseek-ai/DeepSeek-V3",
+        "moonshotai/Moonlight-16B-A3B-Instruct",
+    ],
+)
+@pytest.mark.parametrize(
+    "dataset_cls",
+    [
+        MMLUDataset,
+        GPQADataset,
+        MathDataset,
+        AIME2024Dataset,
+        AIME2025Dataset,
+    ],
+)
+@pytest.mark.parametrize(
+    "system_prompt_file", [system_prompt_file, None], indirect=True
+)
+def test_eval_math_hf_data_processor(tokenizer_name, dataset_cls, system_prompt_file):
+    # Initialize dataset
+    data = dataset_cls()
+
+    # Setup tokenizer
+    tokenizer = get_tokenizer(
+        TokenizerConfig(
+            name=tokenizer_name,
+            chat_template="default",
+        )
+    )
+
+    # Configure task specification
+    math_task_spec = TaskDataSpec(
+        task_name="math",
+        prompt_file=f"{os.path.dirname(abspath)}/../../../examples/prompts/cot.txt",
+        system_prompt_file=system_prompt_file,
+    )
+
+    dataset = AllTaskProcessedDataset(
+        dataset=data.rekeyed_ds,
+        tokenizer=tokenizer,
+        default_task_data_spec=math_task_spec,
+        task_data_processors=data.processor,
+        max_seq_length=128,
+    )
+
+    # Test that the first item can be retrieved when the BOS token assertion passes
+    first_item = dataset[0]
+    assert first_item is not None
+    assert "message_log" in first_item
+    assert len(first_item["message_log"]) > 0
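The new test only asserts that the first item can be built once the processors' BOS handling is correct. A stand-alone check in the same spirit, using a hypothetical helper (assert_single_bos is not part of the repository) and assuming each message_log entry carries a token_ids tensor as produced by the processors above:

import torch

def assert_single_bos(message_log, bos_token_id):
    # Hypothetical helper: verify the prompt built by the processors does not
    # carry a duplicated BOS token, which is exactly what passing
    # add_special_tokens=False to the second tokenizer call prevents.
    if bos_token_id is None:
        # e.g. Qwen2.5 / Qwen3 tokenizers ship without a BOS token.
        return
    token_ids = torch.cat([msg["token_ids"] for msg in message_log])
    assert int((token_ids == bos_token_id).sum()) <= 1, "double BOS in prompt"

# Usage sketch against the dataset built in the test above:
# assert_single_bos(dataset[0]["message_log"], tokenizer.bos_token_id)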
