Commit 790888f

feat: improve eval (#325)

Authored by yuki-97 and terrykong.

Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: Terry Kong <terryk@nvidia.com>
Co-authored-by: Terry Kong <terryk@nvidia.com>

1 parent bc8cb65 · commit 790888f

File tree: 8 files changed, +182 −54 lines changed

README.md — 50 additions, 1 deletion

@@ -14,6 +14,9 @@
 - [DPO](#dpo)
   - [DPO Single Node](#dpo-single-node)
   - [DPO Multi-node](#dpo-multi-node)
+- [Evaluation](#evaluation)
+  - [Convert Model Format (Optional)](#convert-model-format-optional)
+  - [Run Evaluation](#run-evaluation)
 - [Set Up Clusters](#set-up-clusters)
 - [Citation](#citation)
 - [Contributing](#contributing)

@@ -241,7 +244,7 @@ uv run python examples/run_dpo.py \
     logger.wandb.name="llama-dpo-sft"
 ```

-Refer to [dpo.yaml](../examples/configs/dpo.yaml) for a full list of parameters that can be overridden. For an in-depth explanation of how to add your own DPO dataset, refer to the [DPO documentation](docs/guides/dpo.md).
+Refer to `examples/configs/dpo.yaml` for a full list of parameters that can be overridden. For an in-depth explanation of how to add your own DPO dataset, refer to the [DPO documentation](docs/guides/dpo.md).

 ### DPO Multi-node

@@ -266,6 +269,52 @@ sbatch \
     ray.sub
 ```

+## Evaluation
+
+We provide evaluation tools to assess model capabilities.
+
+### Convert Model Format (Optional)
+
+If you have trained a model and saved the checkpoint in the PyTorch DCP format, you first need to convert it to the Hugging Face format before running evaluation:
+
+```sh
+# Example for a GRPO checkpoint at step 170
+uv run python examples/convert_dcp_to_hf.py \
+    --config results/grpo/step_170/config.yaml \
+    --dcp-ckpt-path results/grpo/step_170/policy/weights/ \
+    --hf-ckpt-path results/grpo/hf
+```
+> **Note:** Adjust the paths according to your training output directory structure.
+
+For an in-depth explanation of checkpointing, refer to the [Checkpointing documentation](docs/design-docs/checkpointing.md).
+
+### Run Evaluation
+
+Run the evaluation script with the converted model:
+
+```sh
+uv run python examples/run_eval.py generation.model_name=$PWD/results/grpo/hf
+```
+
+Run the evaluation script with custom settings:
+
+```sh
+# Example: Evaluation of DeepScaleR-1.5B-Preview on MATH-500 using 8 GPUs
+# Pass@1 accuracy averaged over 16 samples for each problem
+uv run python examples/run_eval.py \
+    generation.model_name=agentica-org/DeepScaleR-1.5B-Preview \
+    generation.temperature=0.6 \
+    generation.top_p=0.95 \
+    generation.vllm_cfg.max_model_len=32768 \
+    data.dataset_name=HuggingFaceH4/MATH-500 \
+    data.dataset_key=test \
+    eval.num_tests_per_prompt=16 \
+    cluster.gpus_per_node=8
+```
+> **Note:** Evaluation results may vary slightly due to factors such as sampling parameters, random seed, inference engine version, and inference engine settings.
+
+Refer to `examples/configs/eval.yaml` for a full list of parameters that can be overridden. For an in-depth explanation of evaluation, refer to the [Evaluation documentation](docs/guides/eval.md).
+
 ## Set Up Clusters

 For detailed instructions on how to set up and launch NeMo RL on Slurm or Kubernetes clusters, please refer to the dedicated [Cluster Start](docs/cluster.md) documentation.

docs/guides/dpo.md — 4 additions, 0 deletions

@@ -167,3 +167,7 @@ The DPO implementation in NeMo RL supports several key parameters that can be ad
 - `dpo.sft_average_log_probs`: Whether to average log probabilities over tokens in the SFT loss term

 These parameters can be adjusted in the config file or via command-line overrides to optimize training for your specific use case.
+
+## Evaluate the Trained Model
+
+Upon completion of the training process, you can refer to our [evaluation guide](eval.md) to assess model capabilities.

docs/guides/eval.md — 57 additions, 42 deletions

@@ -2,66 +2,81 @@

 This document explains how to use an evaluation script for assessing model capabilities.

-## Start Evaluation
+## Prepare for Evaluation

-To run the evaluation, you can use the default configuration file or specify a custom one.
+To prepare for evaluation, first ensure your model is in the correct format, which may involve an optional conversion of PyTorch DCP checkpoints to the Hugging Face format. After that, prepare the evaluation configuration, including the prompt templates and any custom settings required to run the evaluation.

-### Start Script
+### Convert DCP to HF (Optional)
+If you have trained a model and saved the checkpoint in the PyTorch DCP format, you first need to convert it to the Hugging Face format before running evaluation.

-**Evaluate Standard Models:**
-
-To run evaluation using a model directly from Hugging Face Hub or a local path already in HF format, use the `run_eval.py` script.
+Use the `examples/convert_dcp_to_hf.py` script. You'll need the path to the training configuration file (`config.yaml`) and the DCP checkpoint directory, and you must specify an output path for the HF-format model.

 ```sh
-# To run the evaluation with default config (examples/configs/eval.yaml)
-uv run python examples/run_eval.py
+# Example for a GRPO checkpoint at step 170
+uv run python examples/convert_dcp_to_hf.py \
+    --config results/grpo/step_170/config.yaml \
+    --dcp-ckpt-path results/grpo/step_170/policy/weights/ \
+    --hf-ckpt-path results/grpo/hf
+```
+> **Note:** Adjust the paths according to your training output directory structure.

-# Specify a custom config file
-uv run python examples/run_eval.py --config path/to/custom_config.yaml
+Once the conversion is complete, you can override `generation.model_name` to point to the directory containing the converted HF model, as described in [this section](#run-the-evaluation-script).

-# Override specific config values via command line (e.g., model name)
-uv run python examples/run_eval.py generation.model_name="Qwen/Qwen2.5-Math-7B-Instruct"
-```
+### Prepare the Evaluation Configuration
+**Override with Custom Settings**

-**Evaluate Models Trained with DCP Checkpoints (GRPO/SFT):**
+To run the evaluation, you can use the [default configuration file](../../examples/configs/eval.yaml). Alternatively, you can specify a custom one or override individual settings via the command line.

-If you have trained a model using GRPO or SFT and saved the checkpoint in the Pytorch DCP format, you first need to convert it to the Hugging Face format before running evaluation.
+The default configuration employs greedy sampling to evaluate Qwen2.5-Math-1.5B-Instruct on AIME-2024.

-1. **Convert DCP to HF:**
-Use the `examples/convert_dcp_to_hf.py` script. You'll need the path to the training configuration file (`config.yaml`), the DCP checkpoint directory, and specify an output path for the HF format model.
+**Prompt Template Configuration**

-```sh
-# Example for a GRPO checkpoint at step 170
-uv run python examples/convert_dcp_to_hf.py \
-    --config results/grpo/step_170/config.yaml \
-    --dcp-ckpt-path results/grpo/step_170/policy/weights/ \
-    --hf-ckpt-path results/grpo/hf
-```
-*Note: Adjust the paths according to your training output directory structure.*
+Always remember to use the same prompt and chat_template that were used during training.

-2. **Run Evaluation on Converted Model:**
-Once the conversion is complete, run the evaluation script, overriding the `generation.model_name` to point to the directory containing the converted HF model.
+For open-source models, we recommend setting `tokenizer.chat_template=default`, `data.prompt_file=null`, and `data.system_prompt_file=null` to allow them to use their native chat templates.

-```sh
-# Example using the converted HF model from the previous step
-uv run python examples/run_eval.py generation.model_name=$PWD/results/grpo/hf
-```
+## Run the Evaluation Script

-### Example Output
+Use the `run_eval.py` script to run an evaluation with a model directly from the Hugging Face Hub or from a local path that is already in Hugging Face format.

+Note that the evaluation script only supports models in the Hugging Face format. If you haven't converted your DCP-format model yet, go back to [Convert DCP to HF](#convert-dcp-to-hf-optional) and follow the guide to convert it.
+
+```sh
+# Run evaluation script with default config (examples/configs/eval.yaml)
+uv run python examples/run_eval.py
+
+# Run evaluation script with converted model
+uv run python examples/run_eval.py generation.model_name=$PWD/results/grpo/hf
+
+# Run evaluation script with custom config file
+uv run python examples/run_eval.py --config path/to/custom_config.yaml
+
+# Override specific config values via command line
+# Example: Evaluation of DeepScaleR-1.5B-Preview on MATH-500 using 8 GPUs
+# Pass@1 accuracy averaged over 16 samples for each problem
+uv run python examples/run_eval.py \
+    generation.model_name=agentica-org/DeepScaleR-1.5B-Preview \
+    generation.temperature=0.6 \
+    generation.top_p=0.95 \
+    generation.vllm_cfg.max_model_len=32768 \
+    data.dataset_name=HuggingFaceH4/MATH-500 \
+    data.dataset_key=test \
+    eval.num_tests_per_prompt=16 \
+    cluster.gpus_per_node=8
 ```
-============================================================
-model_name='Qwen2.5-Math-1.5B-Instruct' dataset_name='aime_2024'
-score=0.10 (3.0/30)
-============================================================
-```
+> **Note:** Evaluation results may vary slightly due to factors such as sampling parameters, random seed, inference engine version, and inference engine settings.

-## Example Configuration File
+## Example Evaluation Output

-You can find an example evaluation configuration file [here](../../examples/configs/eval.yaml).
+When the evaluation completes, you will see a summary similar to the following.

-### Prompt Template Configuration
+```
+============================================================
+model_name='Qwen2.5-Math-1.5B-Instruct' dataset_name='aime_2024'
+max_new_tokens=2048 temperature=0.0 top_p=1.0 top_k=-1

-Always remember to use the same `prompt_file` and `system_prompt_file` that were used during training.
+metric='pass@1' num_tests_per_prompt=1

-For open-source models, we recommend setting `prompt_file=null` and `system_prompt_file=null` to allow them to use their native chat templates.
+score=0.1000 (3.0/30)
+============================================================
+```
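The prompt-template advice above is easy to verify in isolation. The following sketch uses the standard Hugging Face `transformers` API (not repo-specific code) to show how a tokenizer's chat template turns a message list into the final prompt string, which is why the eval-time template must match the training-time one:

```python
# Minimal sketch using the standard transformers API (not repo-specific code).
# The same message list renders to different prompt strings under different
# chat templates, so evaluation must reuse the training-time template.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Math-1.5B-Instruct")
messages = [{"role": "user", "content": "What is 2 + 2?"}]

prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)  # the exact string the model is evaluated on
```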

docs/guides/grpo.md — 4 additions, 0 deletions

@@ -181,3 +181,7 @@ $$
 By multiplying the first term of the loss function by the importance weights $\frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)}$, we can correct for the distribution mismatch between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ while still sampling from $\pi_{\text{inference}}$.

 To enable the importance sampling correction, set the config `use_importance_sampling_correction=True` in the `ClippedPGLossConfig`. By default, we set this config to False to align with standard GRPO.
+
+## Evaluate the Trained Model
+
+Upon completion of the training process, you can refer to our [evaluation guide](eval.md) to assess model capabilities.
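To make the correction above concrete, here is a small illustrative sketch of reweighting the first loss term by $\frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)}$ computed from log-probabilities. This is not NeMo RL's actual `ClippedPGLoss` implementation, and all tensor values are made up:

```python
# Illustrative sketch of the importance sampling correction; not the repo's
# ClippedPGLoss code. All values are hypothetical.
import torch

logp_train = torch.tensor([-1.2, -0.8, -2.0])  # log pi_training(x) per sample
logp_infer = torch.tensor([-1.0, -0.9, -1.8])  # log pi_inference(x) per sample
advantages = torch.tensor([0.5, -0.2, 1.0])

# Importance weights pi_training(x) / pi_inference(x), computed in log space.
weights = torch.exp(logp_train - logp_infer)

# First term of the policy-gradient loss, reweighted to correct the
# training/inference mismatch while still sampling from pi_inference.
corrected = weights * advantages
print(corrected)
```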

docs/guides/sft.md — 5 additions, 1 deletion

@@ -75,4 +75,8 @@ NeMo RL SFT uses Hugging Face chat templates to format the individual examples.
 By default, NeMo RL has support for `Squad` and `OpenAssistant` datasets. Both of these datasets are downloaded from Hugging Face and preprocessed on-the-fly, so there's no need to provide a path to any datasets on disk.

 Adding a new dataset is a straightforward process.
-As long as your custom dataset has the `formatted_ds` and `task_spec` attributes described above, it can serve as a drop-in replacement for Squad and OpenAssistant.
+As long as your custom dataset has the `formatted_ds` and `task_spec` attributes described above, it can serve as a drop-in replacement for Squad and OpenAssistant.
+
+## Evaluate the Trained Model
+
+Upon completion of the training process, you can refer to our [evaluation guide](eval.md) to assess model capabilities.

examples/configs/eval.yaml — 6 additions, 1 deletion

@@ -1,10 +1,15 @@
 # Evaluation Configuration
+eval:
+  metric: "pass@1" # only pass@1 is supported now
+  num_tests_per_prompt: 1 # each prompt is tested num_tests_per_prompt times; the average score is used as the final score
+  seed: 42
+
 generation:
   backend: "vllm" # only vllm is supported for evaluation
   max_new_tokens: ${generation.vllm_cfg.max_model_len}
   temperature: 0.0
   top_p: 1.0
-  top_k: -1 # disable
+  top_k: -1 # -1 means disable
   num_prompts_per_step: -1 # -1 means pass all prompts at once
   model_name: "Qwen/Qwen2.5-Math-1.5B-Instruct"
   stop_token_ids: null
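The `${generation.vllm_cfg.max_model_len}` value above is a config interpolation: `max_new_tokens` tracks whatever `max_model_len` is set to. A small sketch of how such references resolve, assuming OmegaConf-style interpolation (an assumption based on the `${...}` syntax; the repo's actual config loader may differ):

```python
# Sketch of ${...} interpolation, assuming OmegaConf semantics (an assumption
# about this repo's config system, inferred from the syntax above).
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    generation:
      vllm_cfg:
        max_model_len: 32768
      max_new_tokens: ${generation.vllm_cfg.max_model_len}
    """
)
print(cfg.generation.max_new_tokens)  # 32768: overriding max_model_len moves it too
```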

nemo_rl/evals/eval.py — 51 additions, 4 deletions

@@ -19,6 +19,7 @@
 from torch.utils.data import DataLoader
 from transformers import AutoTokenizer

+from nemo_rl.algorithms.utils import set_seed
 from nemo_rl.data import MathDataConfig
 from nemo_rl.data.datasets import AllTaskProcessedDataset, eval_collate_fn
 from nemo_rl.data.llm_message_utils import get_keys_from_message_log

@@ -33,7 +34,14 @@
 # ===============================================================================


+class EvalConfig(TypedDict):
+    metric: str
+    num_tests_per_prompt: int
+    seed: int
+
+
 class MasterConfig(TypedDict):
+    eval: EvalConfig
     generate: GenerationConfig
     data: MathDataConfig
     env: MathEnvConfig

@@ -66,9 +74,25 @@ def setup(
         VLLM model, data loader, and config.
     """
     # Extract individual configs for easier access
+    eval_config = master_config["eval"]
     generation_config = master_config["generation"]
     cluster_config = master_config["cluster"]

+    # Set seed for reproducibility
+    set_seed(eval_config["seed"])
+
+    # Check settings
+    metric = eval_config["metric"]
+    num_tests_per_prompt = eval_config["num_tests_per_prompt"]
+    temperature = generation_config["temperature"]
+    top_k = generation_config["top_k"]
+    # TODO @yukih: support pass@k and cons@k
+    assert metric in ["pass@1"], f"Invalid metric: {metric}"
+    if num_tests_per_prompt > 1:
+        assert temperature > 0 and top_k != 1, (
+            "temperature > 0 and top_k != 1 are required for multiple samples"
+        )
+
     # ==========================
     # Data
     # ==========================

@@ -137,15 +161,29 @@ def run_env_eval(vllm_generation, dataloader, env, master_config):
         env: Environment that scores responses.
         master_config: Configuration settings.
     """
+    # Extract for easier access
+    generation_config = master_config["generation"]
+    eval_config = master_config["eval"]
+    metric = eval_config["metric"]
+    num_tests_per_prompt = eval_config["num_tests_per_prompt"]
+
     # Run evaluation loop
     score, count = 0.0, 0
     for batch in dataloader:
+        # update stats
+        count += batch.size * num_tests_per_prompt
+
+        # measure multiple samples
+        if num_tests_per_prompt > 1:
+            batch = batch.repeat_interleave(num_tests_per_prompt)
+
         # get input prompt from message_log
         prompts = []
         for message_log in batch["message_log"]:
             content = [message["content"] for message in message_log]
             content = "\n".join(content)
             prompts.append(content)
+
         # generate by vllm
         inputs = BatchedDataDict({"prompts": prompts})
         outputs = vllm_generation.generate_text(inputs)["texts"]

@@ -166,19 +204,28 @@ def run_env_eval(vllm_generation, dataloader, env, master_config):
         ]
         env_return = ray.get(env.step.remote(to_env, batch["extra_env_info"]))

-        score += env_return.rewards.sum().item()
-        count += len(env_return.rewards)
+        # update stats
+        if metric == "pass@1":
+            score += env_return.rewards.sum().item()
+        else:
+            raise ValueError(f"Invalid metric: {metric}")

     # Cleanup before printing results
     ray.get(env.shutdown.remote())
     vllm_generation.shutdown()

     # Print results
     dataset_name = os.path.basename(master_config["data"]["dataset_name"])
-    model_name = os.path.basename(master_config["generation"]["model_name"])
+    model_name = os.path.basename(generation_config["model_name"])
+    max_new_tokens = generation_config["vllm_cfg"]["max_model_len"]
+    temperature = generation_config["temperature"]
+    top_p = generation_config["top_p"]
+    top_k = generation_config["top_k"]
     average_score = score / count

     print("\n" + "=" * 60)
     print(f"{model_name=} {dataset_name=}")
-    print(f"score={average_score:.2f} ({score}/{count})")
+    print(f"{max_new_tokens=} {temperature=} {top_p=} {top_k=}\n")
+    print(f"{metric=} {num_tests_per_prompt=}\n")
+    print(f"score={average_score:.4f} ({score}/{count})")
     print("=" * 60 + "\n")

tests/unit/models/generation/test_vllm_generation.py — 5 additions, 5 deletions

@@ -800,14 +800,14 @@ def test_vllm_weight_update_memory(cluster, tokenizer, enable_dtensor):
     # Check memory stats
     assert current_allocated == 0.0, "Memory should be 0 after refit completed"
     assert current_reserved == 0.0, "Memory should be 0 after refit completed"
-    # memory threshold: memory during non-streaming weight update on 1B model on 2 GPUs
+    # memory threshold: memory during non-streaming weight update on 0.6B model on 2 GPUs
     # memory during streaming weight update should be less than this baseline threshold
     if enable_dtensor:
-        assert peak_allocated < 8074, "Peak allocated memory should < 8074 MB"
-        assert peak_reserved < 8088, "Peak reserved memory should < 8088 MB"
+        assert peak_allocated < 4005, "Peak allocated memory should < 4005 MB"
+        assert peak_reserved < 4016, "Peak reserved memory should < 4016 MB"
     else:
-        assert peak_allocated < 11286, "Peak allocated memory should < 11286 MB"
-        assert peak_reserved < 11298, "Peak reserved memory should < 11298 MB"
+        assert peak_allocated < 5736, "Peak allocated memory should < 5736 MB"
+        assert peak_reserved < 5748, "Peak reserved memory should < 5748 MB"

     # Clean up
     vllm_policy.shutdown()
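For context, peak-memory assertions like these are usually driven by PyTorch's CUDA memory statistics. A hedged sketch of that measurement pattern follows, using only the standard `torch.cuda` API; how the actual test collects its stats may differ:

```python
# Hedged sketch of the usual peak-memory measurement pattern with the
# standard torch.cuda API; the actual test's stat collection may differ.
import torch

torch.cuda.reset_peak_memory_stats()

# ... run the weight update being measured ...

peak_allocated = torch.cuda.max_memory_allocated() / 1024**2  # MB
peak_reserved = torch.cuda.max_memory_reserved() / 1024**2    # MB
assert peak_allocated < 4005, "Peak allocated memory should < 4005 MB"
assert peak_reserved < 4016, "Peak reserved memory should < 4016 MB"
```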
