Skip to content

Commit 8416915

Browse files
authored
feat: evaluation implement (#16)
Signed-off-by: Yuki Huang <[email protected]>
1 parent 076274c commit 8416915

File tree

12 files changed

+602
-24
lines changed

12 files changed

+602
-24
lines changed

docs/guides/eval.md

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Evaluation
2+
3+
## Start Evaluation
4+
5+
### Start Script
6+
```sh
7+
# To run the evaluation with default config (examples/configs/eval.yaml)
8+
uv run python examples/run_eval.py
9+
10+
# Specify a custom config file
11+
uv run python examples/run_eval.py --config path/to/custom_config.yaml
12+
13+
# Override specific config values via command line
14+
uv run python examples/run_eval.py generation.model_name="Qwen/Qwen2.5-Math-7B-Instruct"
15+
```
16+
17+
### Example Output
18+
19+
```
20+
============================================================
21+
model_name='Qwen2.5-Math-1.5B-Instruct' dataset_name='aime_2024'
22+
score=0.10 (3.0/30)
23+
============================================================
24+
```
25+
26+
## Configuration
27+
28+
An example Evaluation configuration file can be found [here](../../examples/configs/eval.yaml).
29+
30+
### Prompt Template Configuration
31+
Always remember to use the same `prompt_file` and `system_prompt_file` that were used during training.
32+
33+
For open-source models, we recommend setting `prompt_file=null` and `system_prompt_file=null` to allow them to use their native chat templates.

docs/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ cluster.md
1818
adding_new_models.md
1919
guides/sft.md
2020
guides/grpo.md
21+
guides/eval.md
2122
```
2223

2324
```{toctree}

examples/configs/eval.yaml

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Evaluation Configuration
2+
generation:
3+
backend: "vllm" # only vllm is supported for evaluation
4+
max_new_tokens: ${generation.vllm_cfg.max_model_len}
5+
temperature: 0.0
6+
top_p: 1.0
7+
top_k: -1 # disable
8+
num_prompts_per_step: -1 # -1 means pass all prompts at once
9+
model_name: "Qwen/Qwen2.5-Math-1.5B-Instruct"
10+
vllm_cfg:
11+
tensor_parallel_size: 1
12+
gpu_memory_utilization: 0.9
13+
max_model_len: 2048
14+
15+
data:
16+
max_input_seq_length: ${generation.vllm_cfg.max_model_len} # has no practical effect here: evaluation passes the prompts through directly
17+
prompt_file: null
18+
system_prompt_file: null
19+
dataset_name: "HuggingFaceH4/aime_2024"
20+
dataset_key: "train"
21+
problem_key: "problem"
22+
solution_key: "answer"
23+
24+
env:
25+
math:
26+
num_workers: 8
27+
28+
cluster:
29+
gpus_per_node: 1
30+
num_nodes: 1

examples/run_eval.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
import os
17+
import pprint
18+
import sys
19+
20+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
21+
22+
from datasets import load_dataset
23+
from omegaconf import OmegaConf
24+
from transformers import AutoTokenizer
25+
26+
from examples.run_grpo_math import math_data_processor
27+
from nemo_reinforcer.data import MathDataConfig
28+
from nemo_reinforcer.data.datasets import AllTaskProcessedDataset
29+
from nemo_reinforcer.data.interfaces import TaskDataSpec
30+
from nemo_reinforcer.data.llm_message_utils import remap_dataset_keys
31+
from nemo_reinforcer.distributed.virtual_cluster import init_ray
32+
from nemo_reinforcer.environments.math_environment import MathEnvironment
33+
from nemo_reinforcer.evals.eval import MasterConfig, run_env_eval, setup
34+
from nemo_reinforcer.models.generation.interfaces import GenerationConfig
35+
36+
37+
def parse_args():
    """Parse the --config flag and collect the remaining CLI overrides.

    Returns:
        Tuple of (argparse namespace with the --config path, OmegaConf
        config built from the leftover dotlist-style arguments, e.g.
        ``generation.model_name=...``).
    """
    parser = argparse.ArgumentParser(description="Run Evaluation with configuration")
    parser.add_argument(
        "--config", type=str, default=None, help="Path to YAML config file"
    )

    # Anything argparse does not recognize is treated as a dotlist override.
    known_args, extra_cli = parser.parse_known_args()
    dotlist_overrides = OmegaConf.from_dotlist(extra_cli)

    return known_args, dotlist_overrides
51+
52+
53+
def setup_data(
    data_config: MathDataConfig, generation_config: GenerationConfig, env_configs
):
    """Build the evaluation dataset, math environment actor, and tokenizer.

    Args:
        data_config: Dataset name/split, column names, and prompt file paths.
        generation_config: Generation settings; only ``model_name`` is used
            here, to load the matching tokenizer.
        env_configs: Environment configs; ``env_configs["math"]`` is forwarded
            to the MathEnvironment actor.

    Returns:
        Tuple of (processed dataset, MathEnvironment actor handle, tokenizer).
    """
    print("\n▶ Setting up data...")

    task_spec = TaskDataSpec(
        task_name="math",
        prompt_file=data_config["prompt_file"],
        system_prompt_file=data_config["system_prompt_file"],
    )

    # Load the raw dataset, optionally narrowing to a single split.
    raw_dataset = load_dataset(data_config["dataset_name"])
    split = data_config["dataset_key"]
    if split is not None:
        raw_dataset = raw_dataset[split]

    # Normalize column names to the problem/expected_answer schema that the
    # math data processor expects.
    renamed_dataset = remap_dataset_keys(
        raw_dataset,
        mapping_dict={
            data_config["problem_key"]: "problem",
            data_config["solution_key"]: "expected_answer",
        },
    )

    tokenizer = AutoTokenizer.from_pretrained(generation_config["model_name"])
    if tokenizer.pad_token is None:
        # Fall back to EOS so a pad token is always defined.
        tokenizer.pad_token = tokenizer.eos_token

    math_env = MathEnvironment.options(
        runtime_env={"py_executable": MathEnvironment.DEFAULT_PY_EXECUTABLE}
    ).remote(env_configs["math"])

    processed_dataset = AllTaskProcessedDataset(
        dataset=renamed_dataset,
        tokenizer=tokenizer,
        default_task_data_spec=task_spec,
        task_data_processors=math_data_processor,
        max_seq_length=data_config["max_input_seq_length"],
    )

    return processed_dataset, math_env, tokenizer
93+
94+
95+
def main():
    """Main entry point.

    Loads the YAML config (default: examples/configs/eval.yaml), applies any
    dotlist overrides collected from the command line, then sets up the data,
    environment, and generation backend and runs the evaluation.
    """
    # Parse arguments
    args, overrides = parse_args()

    if not args.config:
        args.config = os.path.join(os.path.dirname(__file__), "configs", "eval.yaml")

    config = OmegaConf.load(args.config)
    print(f"Loaded configuration from: {args.config}")

    # Merge the dotlist overrides that parse_args() already extracted.
    # BUG FIX: re-parsing the raw CLI here (OmegaConf.from_cli) would also
    # pick up the --config flag itself and inject a bogus "--config" key
    # into the merged config; the pre-filtered `overrides` avoids that.
    if overrides:
        print(f"Overrides: {overrides}")
        config = OmegaConf.merge(config, overrides)
        print("Applied CLI overrides")

    # Resolve interpolations (e.g. ${generation.vllm_cfg.max_model_len})
    # and convert to a plain container.
    config: MasterConfig = OmegaConf.to_container(config, resolve=True)

    # Print config
    print("Final config:")
    pprint.pprint(config)

    # Init ray
    init_ray()

    # Setup data
    (
        dataset,
        math_env,
        tokenizer,
    ) = setup_data(config["data"], config["generation"], config["env"])

    # Setup
    (
        vllm_generation,
        dataloader,
        master_config,
    ) = setup(config, tokenizer, dataset)

    # Run evaluation
    run_env_eval(
        vllm_generation,
        dataloader,
        math_env,
        master_config,
    )


if __name__ == "__main__":
    main()

examples/run_grpo_math.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ def math_data_processor(
122122

123123
template = task_data_spec.custom_template
124124
message_log: LLMMessageLogType = []
125+
126+
# system prompt
125127
if task_data_spec.system_prompt:
126128
sys_message = {"role": "system", "content": task_data_spec.system_prompt}
127129
message = tokenizer.apply_chat_template(
@@ -135,10 +137,11 @@ def math_data_processor(
135137
0
136138
]
137139
message_log.append(sys_message)
138-
user_message = {
139-
"role": "user",
140-
"content": task_data_spec.prompt.format(problem),
141-
}
140+
141+
# user prompt
142+
if task_data_spec.prompt:
143+
problem = task_data_spec.prompt.format(problem)
144+
user_message = {"role": "user", "content": problem}
142145
message = tokenizer.apply_chat_template(
143146
[user_message],
144147
chat_template=template,
@@ -167,8 +170,9 @@ def math_data_processor(
167170
"extra_env_info": extra_env_info,
168171
"loss_multiplier": loss_multiplier,
169172
"idx": idx,
170-
"task_name": datum_dict["task_name"],
171173
}
174+
if "task_name" in datum_dict:
175+
output["task_name"] = datum_dict["task_name"]
172176
return output
173177

174178

nemo_reinforcer/data/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,8 @@ class DataConfig(TypedDict):
2121
system_prompt_file: Optional[str]
2222
dataset_name: str
2323
val_dataset_name: Optional[str]
24+
25+
26+
class MathDataConfig(DataConfig):
    """Data config for math evaluation datasets.

    Extends DataConfig with the dataset column names that hold the problem
    statement and its reference answer.
    """

    # Column name containing the problem statement.
    problem_key: str
    # Column name containing the reference (ground-truth) answer.
    solution_key: str

nemo_reinforcer/data/datasets.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,54 @@ def rl_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict:
130130
batch_max_length=batch_max_length,
131131
)
132132
return output
133+
134+
135+
def eval_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict:
    """Collate evaluation samples into a single batched dictionary.

    Gathers the per-sample fields into index-aligned lists — no padding or
    tensor stacking is performed.

    Args:
        data_batch: Data samples, each carrying message_log, extra_env_info,
            and idx fields.

    Returns:
        BatchedDataDict with message_log, extra_env_info, and idx lists,
        where position i of each list corresponds to the same input sample.
    """
    logs = []
    infos = []
    indices = []
    for sample in data_batch:
        logs.append(sample["message_log"])
        infos.append(sample["extra_env_info"])
        indices.append(sample["idx"])

    return BatchedDataDict(
        message_log=logs,
        extra_env_info=infos,
        idx=indices,
    )

nemo_reinforcer/data/llm_message_utils.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Dict, List, Union
15-
14+
from typing import Dict, List
1615

1716
import torch
17+
from datasets import Dataset
1818

1919
from nemo_reinforcer.data.interfaces import (
2020
LLMMessageLogType,
@@ -390,3 +390,27 @@ def get_formatted_message_log(
390390
prev_formatted_message = formatted_message
391391

392392
return message_log
393+
394+
395+
def remap_dataset_keys(
    dataset: Dataset,
    mapping_dict: Dict[str, str],
) -> Dataset:
    """Rename dataset columns according to mapping_dict.

    Args:
        dataset: The input dataset whose columns should be renamed
        mapping_dict: Maps existing column names to their new names

    Returns:
        Dataset: A new dataset with the renamed columns; the original
        columns listed in mapping_dict are dropped
    """
    # Identity mapping — skip the map() pass entirely.
    if all(src == dst for src, dst in mapping_dict.items()):
        return dataset

    def _rename(example):
        return {dst: example[src] for src, dst in mapping_dict.items()}

    return dataset.map(_rename, remove_columns=list(mapping_dict.keys()))

0 commit comments

Comments
 (0)