Skip to content

Commit 4528931

Browse files
authored
fix: Handle missing prompts in math HF data processor and add regression test (#1219)
Signed-off-by: Zhaopeng Qiu <qiuzhaopeng@foxmail.com>
1 parent 1b96b45 commit 4528931

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

nemo_rl/data/processors.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,12 @@ def math_hf_data_processor(
108108
extra_env_info = {"ground_truth": user_message[1]["content"]}
109109

110110
message_log: LLMMessageLogType = []
111+
formatted_content = (
112+
task_data_spec.prompt.format(problem) if task_data_spec.prompt else problem
113+
)
111114
user_message = {
112115
"role": "user",
113-
"content": task_data_spec.prompt.format(problem),
116+
"content": formatted_content,
114117
}
115118
message: list[str] = tokenizer.apply_chat_template( # type: ignore
116119
[user_message],

tests/unit/data/test_data_processor.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from collections import defaultdict
1919

2020
import pytest
21+
import torch
2122
from datasets import Dataset
2223

2324
abspath = os.path.abspath(__file__)
@@ -40,6 +41,30 @@
4041
from nemo_rl.models.policy import TokenizerConfig
4142

4243

44+
class DummyTokenizer:
    """Minimal stand-in for a HuggingFace tokenizer used by data-processor tests.

    Renders chat messages as plain "role: content" lines and "encodes" text by
    emitting one integer id per character — just enough surface area for
    `math_hf_data_processor` to run without a real tokenizer.
    """

    def apply_chat_template(
        self,
        messages,
        tokenize=False,
        add_generation_prompt=True,
        add_special_tokens=False,
    ):
        # Render each message as "role: content\n"; messages without an
        # explicit role default to "user".
        parts = [f"{msg.get('role', 'user')}: {msg['content']}\n" for msg in messages]
        rendered = "".join(parts)
        if add_generation_prompt:
            # Mimic a generation prompt by appending the assistant header.
            rendered = rendered + "assistant:"
        return rendered

    def __call__(self, text, return_tensors=None, add_special_tokens=False):
        # A list of strings is flattened into one string before "encoding".
        if isinstance(text, list):
            text = "".join(text)
        # Fake token ids: one sequential id per character of the input.
        token_ids = list(range(len(text)))
        if return_tensors == "pt":
            return {"input_ids": torch.tensor([token_ids], dtype=torch.long)}
        return {"input_ids": token_ids}
4368
def test_math_data_processor():
4469
raw_dataset = Dataset.from_list(
4570
[
@@ -131,6 +156,37 @@ def test_math_hf_data_processor(tokenizer_name, dataset_cls):
131156
assert len(first_item["message_log"]) > 0
132157

133158

159+
def test_math_hf_data_processor_without_prompt():
    """Regression test: the processor must fall back to the raw problem text
    when the task spec supplies no prompt template (see fix for missing-prompt
    handling in `math_hf_data_processor`)."""
    sample = {
        "messages": [
            {"role": "user", "content": "Solve 1+1."},
            {"role": "assistant", "content": "2"},
        ],
        "task_name": "math",
    }
    dummy_tokenizer = DummyTokenizer()

    # No prompt/system-prompt files -> task_data_spec.prompt is unset.
    spec = TaskDataSpec(
        task_name="math",
        prompt_file=None,
        system_prompt_file=None,
    )

    processed = math_hf_data_processor(
        datum_dict=sample,
        task_data_spec=spec,
        tokenizer=dummy_tokenizer,
        max_seq_length=128,
        idx=0,
    )

    # Ground truth comes from the assistant turn; the datum is kept (loss on).
    assert processed["extra_env_info"]["ground_truth"] == "2"
    assert processed["loss_multiplier"] == 1.0
    # Exactly one user message, containing the unformatted problem text.
    assert len(processed["message_log"]) == 1
    assert processed["message_log"][0]["role"] == "user"
    assert "Solve 1+1." in processed["message_log"][0]["content"]
134190
@pytest.fixture
135191
def system_prompt_file(request):
136192
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as file:

0 commit comments

Comments
 (0)