|
18 | 18 | from collections import defaultdict |
19 | 19 |
|
20 | 20 | import pytest |
| 21 | +import torch |
21 | 22 | from datasets import Dataset |
22 | 23 |
|
23 | 24 | abspath = os.path.abspath(__file__) |
|
40 | 41 | from nemo_rl.models.policy import TokenizerConfig |
41 | 42 |
|
42 | 43 |
|
| 44 | +class DummyTokenizer: |
| 45 | +    def apply_chat_template( |
| 46 | +        self, |
| 47 | +        messages, |
| 48 | +        tokenize=False, |
| 49 | +        add_generation_prompt=True, |
| 50 | +        add_special_tokens=False, |
| 51 | +    ): |
| 52 | +        content = "".join( |
| 53 | +            f"{m.get('role', 'user')}: {m['content']}\n" for m in messages |
| 54 | +        ) |
| 55 | +        if add_generation_prompt: |
| 56 | +            content += "assistant:" |
| 57 | +        return content |
| 58 | +
| 59 | +    def __call__(self, text, return_tensors=None, add_special_tokens=False): |
| 60 | +        if isinstance(text, list): |
| 61 | +            text = "".join(text) |
| 62 | +        encoded = list(range(len(text))) |
| 63 | +        if return_tensors == "pt": |
| 64 | +            return {"input_ids": torch.tensor([encoded], dtype=torch.long)} |
| 65 | +        return {"input_ids": encoded} |
| 66 | +
| 67 | +
43 | 68 | def test_math_data_processor(): |
44 | 69 |     raw_dataset = Dataset.from_list( |
45 | 70 |         [ |
@@ -131,6 +156,37 @@ def test_math_hf_data_processor(tokenizer_name, dataset_cls): |
131 | 156 |     assert len(first_item["message_log"]) > 0 |
132 | 157 |
|
133 | 158 |
|
| 159 | +def test_math_hf_data_processor_without_prompt(): |
| 160 | +    datum_dict = { |
| 161 | +        "messages": [ |
| 162 | +            {"role": "user", "content": "Solve 1+1."}, |
| 163 | +            {"role": "assistant", "content": "2"}, |
| 164 | +        ], |
| 165 | +        "task_name": "math", |
| 166 | +    } |
| 167 | +    tokenizer = DummyTokenizer() |
| 168 | +
| 169 | +    math_task_spec = TaskDataSpec( |
| 170 | +        task_name="math", |
| 171 | +        prompt_file=None, |
| 172 | +        system_prompt_file=None, |
| 173 | +    ) |
| 174 | +
| 175 | +    result = math_hf_data_processor( |
| 176 | +        datum_dict=datum_dict, |
| 177 | +        task_data_spec=math_task_spec, |
| 178 | +        tokenizer=tokenizer, |
| 179 | +        max_seq_length=128, |
| 180 | +        idx=0, |
| 181 | +    ) |
| 182 | +
| 183 | +    assert result["extra_env_info"]["ground_truth"] == "2" |
| 184 | +    assert result["loss_multiplier"] == 1.0 |
| 185 | +    assert len(result["message_log"]) == 1 |
| 186 | +    assert result["message_log"][0]["role"] == "user" |
| 187 | +    assert "Solve 1+1." in result["message_log"][0]["content"] |
| 188 | +
| 189 | +
134 | 190 | @pytest.fixture |
135 | 191 | def system_prompt_file(request): |
136 | 192 |     with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as file: |
|