
Commit bbdd671

ashors1 and yfw authored
feat: DPO (#180)
Signed-off-by: ashors1 <ashors@nvidia.com>
Signed-off-by: Anna Shors <ashors@nvidia.com>
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Co-authored-by: Yi-Fu Wu <yifu.wu@gmail.com>
1 parent 88bc0fd commit bbdd671

27 files changed: +2537 / -247 lines changed

.github/workflows/cicd-main.yml

Lines changed: 1 addition & 0 deletions
@@ -170,6 +170,7 @@ jobs:
 if [[ "${{ needs.pre-flight.outputs.test_level }}" =~ ^(L1|L2)$ ]]; then
   uv run --no-sync bash ./tests/functional/sft.sh
   uv run --no-sync bash ./tests/functional/grpo.sh
+  uv run --no-sync bash ./tests/functional/dpo.sh
 else
   echo Skipping functional tests for level ${{ needs.pre-flight.outputs.test_level }}
 fi

README.md

Lines changed: 84 additions & 22 deletions
@@ -5,12 +5,15 @@
 - [Features](#features)
 - [Prerequisuites](#prerequisuites)
 - [Quick start](#quick-start)
-- [SFT](#sft)
+- [GRPO](#grpo)
 - [Single Node](#single-node)
 - [Multi-node](#multi-node)
-- [GRPO](#grpo)
+- [SFT](#sft)
 - [Single Node](#single-node-1)
 - [Multi-node](#multi-node-1)
+- [DPO](#dpo)
+- [Single Node](#single-node-2)
+- [Multi-node](#multi-node-2)
 - [Cluster Start](#cluster-start)

 **Nemo-Reinforcer** is a scalable and efficient post-training library designed for models ranging from 1 GPU to thousands, and from tiny to over 100 billion parameters.
@@ -33,10 +36,10 @@ What you can expect:
 - **Environment Support** - Support for multi-environment training.
 - **Learning Algorithms** - GRPO (Group Relative Policy Optimization) and SFT (Supervised Fine-Tuning)
 - **Worker Isolation** - Process isolation between RL Actors (no worries about global state)
+- **DPO Algorithm** - Direct Preference Optimization for alignment
 - 🔜 **Larger Model Support** - Native PyTorch support for models up to 70B parameters
 - 🔜 **Advanced Parallelism** - FSDP2, TP, SP, and sequence packing for efficient training
 - 🔜 **Environment Isolation** - Dependency isolation between components
-- 🔜 **DPO Algorithm** - Direct Preference Optimization for alignment

 ## Prerequisuites

@@ -59,6 +62,61 @@ pip install uv

 **Reminder**: Don't forget to set your `HF_HOME`, `WANDB_API_KEY`, and `HF_DATASETS_CACHE` (if needed). You'll need to do a `huggingface-cli login` as well for Llama models.

+### GRPO
+
+We have a reference GRPO experiment config set up for math benchmarks using the [OpenMathInstruct-2](https://huggingface.co/datasets/nvidia/OpenMathInstruct-2) dataset.
+
+#### Single Node
+
+To run GRPO on a single GPU for `Llama-3.2-1B-Instruct`:
+
+```sh
+# Run the GRPO math example using a 1B parameter model
+uv run python examples/run_grpo_math.py
+```
+
+By default, this uses the configuration in `examples/configs/grpo_math_1B.yaml`. You can customize parameters with command-line overrides. For example, to run on 8 GPUs:
+
+```sh
+# Run the GRPO math example using a 1B parameter model on 8 GPUs
+uv run python examples/run_grpo_math.py \
+cluster.gpus_per_node=8
+```
+
+You can override any of the parameters listed in the yaml configuration file. For example:
+
+```sh
+uv run python examples/run_grpo_math.py \
+policy.model_name="Qwen/Qwen2-1.5B" \
+checkpointing.checkpoint_dir="results/qwen1_5b_math" \
+logger.wandb_enabled=True \
+logger.wandb.name="grpo-qwen1_5b_math" \
+logger.num_val_samples_to_print=10
+```
+
+#### Multi-node
+
+```sh
+# Run from the root of NeMo-Reinforcer repo
+NUM_ACTOR_NODES=2
+# Add a timestamp to make each job name unique
+TIMESTAMP=$(date +%Y%m%d_%H%M%S)
+
+# grpo_math_8b uses Llama-3.1-8B-Instruct model
+COMMAND="uv run ./examples/run_grpo_math.py --config examples/configs/grpo_math_8B.yaml cluster.num_nodes=2 checkpointing.checkpoint_dir='results/llama8b_2nodes' logger.wandb_enabled=True logger.wandb.name='grpo-llama8b_math'" \
+UV_CACHE_DIR=YOUR_UV_CACHE_DIR \
+CONTAINER=YOUR_CONTAINER \
+MOUNTS="$PWD:$PWD" \
+sbatch \
+--nodes=${NUM_ACTOR_NODES} \
+--account=YOUR_ACCOUNT \
+--job-name=YOUR_JOBNAME \
+--partition=YOUR_PARTITION \
+--time=4:0:0 \
+--gres=gpu:8 \
+ray.sub
+```
+
 ### SFT

 We provide a sample SFT experiment that uses the [SQuAD dataset](https://rajpurkar.github.io/SQuAD-explorer/).
@@ -87,15 +145,12 @@ Refer to `examples/configs/sft.yaml` for a full list of parameters that can be o

 #### Multi-node

-For distributed training across multiple nodes:
-
 ```sh
 # Run from the root of NeMo-Reinforcer repo
 NUM_ACTOR_NODES=2
 # Add a timestamp to make each job name unique
 TIMESTAMP=$(date +%Y%m%d_%H%M%S)

-# SFT experiment uses Llama-3.1-8B model
 COMMAND="uv run ./examples/run_sft.py --config examples/configs/sft.yaml cluster.num_nodes=2 cluster.gpus_per_node=8 checkpointing.checkpoint_dir='results/sft_llama8b_2nodes' logger.wandb_enabled=True logger.wandb.name='sft-llama8b'" \
 CONTAINER=YOUR_CONTAINER \
 MOUNTS="$PWD:$PWD" \
@@ -109,48 +164,55 @@ sbatch \
 ray.sub
 ```

-### GRPO
+### DPO

-We have a reference GRPO experiment config set up trained for math benchmarks using the [OpenInstructMath2](https://huggingface.co/datasets/nvidia/OpenMathInstruct-2) dataset.
+We provide a sample DPO experiment that uses the [HelpSteer3 dataset](https://huggingface.co/datasets/nvidia/HelpSteer3) for preference-based training.

 #### Single Node

-To run GRPO on a single GPU for `Llama-3.2-1B-Instruct`:
+The default DPO experiment is configured to run on a single GPU. To launch the experiment:

 ```sh
-# Run the GRPO math example using a 1B parameter model
-uv run python examples/run_grpo_math.py
+uv run python examples/run_dpo.py
 ```

-By default, this uses the configuration in `examples/configs/grpo_math_1B.yaml`. You can customize parameters with command-line overrides. For example, to run on 8 gpus,
+This trains `Llama3.2-1B-Instruct` on one GPU.
+
+If you have access to more GPUs, you can update the experiment accordingly. To run on 8 GPUs, we update the cluster configuration and switch to an 8B Llama3.1 Instruct model:

 ```sh
-# Run the GRPO math example using a 1B parameter model using 8 GPUs
-uv run python examples/run_grpo_math.py \
+uv run python examples/run_dpo.py \
+policy.model_name="meta-llama/Llama-3.1-8B-Instruct" \
+policy.train_global_batch_size=256 \
 cluster.gpus_per_node=8
 ```

-You can override any of the parameters listed in the yaml configuration file. For example,
+Any of the DPO parameters can be customized from the command line. For example:

 ```sh
-uv run python examples/run_grpo_math.py \
-policy.model_name="Qwen/Qwen2-1.5B" \
-checkpointing.checkpoint_dir="results/qwen1_5b_math" \
+uv run python examples/run_dpo.py \
+dpo.sft_loss_weight=0.1 \
+dpo.preference_average_log_probs=True \
+checkpointing.checkpoint_dir="results/llama_dpo_sft" \
 logger.wandb_enabled=True \
-logger.wandb.name="grpo-qwen1_5b_math" \
-logger.num_val_samples_to_print=10 \
+logger.wandb.name="llama-dpo-sft"
 ```

+Refer to [dpo.yaml](examples/configs/dpo.yaml) for a full list of parameters that can be overridden. For an in-depth explanation of how to add your own DPO dataset, refer to the [DPO documentation](docs/guides/dpo.md).
+
 #### Multi-node

+For distributed DPO training across multiple nodes, modify the following script for your use case:
+
 ```sh
 # Run from the root of NeMo-Reinforcer repo
+# Number of nodes to use for your job
 NUM_ACTOR_NODES=2
 # Add a timestamp to make each job name unique
 TIMESTAMP=$(date +%Y%m%d_%H%M%S)

-# grpo_math_8b uses Llama-3.1-8B-Instruct model
-COMMAND="uv run ./examples/run_grpo_math.py --config examples/configs/grpo_math_8B.yaml cluster.num_nodes=2 checkpointing.checkpoint_dir='results/llama8b_2nodes' logger.wandb_enabled=True logger.wandb.name='grpo-llama8b_math'" \
+COMMAND="uv run ./examples/run_dpo.py --config examples/configs/dpo.yaml cluster.num_nodes=2 cluster.gpus_per_node=8 dpo.val_global_batch_size=32 checkpointing.checkpoint_dir='results/dpo_llama81_2nodes' logger.wandb_enabled=True logger.wandb.name='dpo-llama1b'" \
+RAY_DEDUP_LOGS=0 \
 CONTAINER=YOUR_CONTAINER \
 MOUNTS="$PWD:$PWD" \
 sbatch \

docs/guides/dpo.md

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
# Direct Preference Optimization in Reinforcer

[Direct Preference Optimization (DPO)](https://arxiv.org/pdf/2305.18290) is an RL-free alignment algorithm that operates on preference data. Given a prompt and a pair of chosen and rejected responses, DPO aims to increase the probability of the chosen response and decrease the probability of the rejected response relative to a frozen reference model. The actor is initialized using the reference model. For more details, refer to the [DPO paper](https://arxiv.org/pdf/2305.18290).

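For orientation, the standard per-example DPO objective from the paper is reproduced below. This is the textbook formulation, not an excerpt of Reinforcer's code; the `dpo.reference_policy_kl_penalty` parameter described later in this guide plays the role of $\beta$:

```{math}
\mathcal{L}_{\text{DPO}} = -\log \sigma\!\left(\beta \left[ \log\frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \log\frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)} \right]\right)
```

where $x$ is the prompt, $y_w$ and $y_l$ are the chosen and rejected responses, $\pi_\theta$ is the policy being trained, and $\pi_{\text{ref}}$ is the frozen reference model.
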
## Launch a DPO Run

The script [examples/run_dpo.py](../../examples/run_dpo.py) can be used to launch a DPO experiment. This script can either be launched locally or via Slurm. For details on how to set up Ray and launch a job using Slurm, refer to the [cluster documentation](../cluster.md).

Be sure to launch the job using `uv`. The command to launch a DPO job is as follows:
```bash
uv run examples/run_dpo.py --config <PATH TO YAML CONFIG> <OVERRIDES>
```
If not specified, `config` will default to [examples/configs/dpo.yaml](../../examples/configs/dpo.yaml).

## Configuration

Reinforcer allows users to configure DPO experiments using `yaml` config files. An example DPO configuration file can be found [here](../../examples/configs/dpo.yaml).

To override a value in the config, either update the value in the `yaml` file directly, or pass the override via the command line. For example:

```bash
uv run examples/run_dpo.py \
  cluster.gpus_per_node=8 \
  dpo.sft_loss_weight=0.1 \
  dpo.preference_average_log_probs=True \
  logger.wandb.name="dpo-dev-8-gpu"
```

**Reminder**: Don't forget to set your `HF_HOME`, `WANDB_API_KEY`, and `HF_DATASETS_CACHE` (if needed). You'll need to do a `huggingface-cli login` as well for Llama models.

## Datasets

Each class representing a Reinforcer DPO dataset is expected to have the following attributes:
1. `formatted_ds`: The dictionary of formatted datasets. This dictionary should contain `train` and `validation` splits, and each split should conform to the format described below.
2. `task_spec`: The `TaskDataSpec` for this dataset. This should specify the name you choose for this dataset.

DPO datasets are expected to follow a specific format with three key fields:
- `prompt`: The input prompt/context
- `chosen_response`: The preferred/winning response
- `rejected_response`: The non-preferred/losing response

[data/hf_datasets/helpsteer3.py](../../nemo_reinforcer/data/hf_datasets/helpsteer3.py) provides an example of how to format data for DPO:

```python
def format_helpsteer3(data):
    response_1 = data["response1"]
    response_2 = data["response2"]
    overall_preference = data["overall_preference"]

    if overall_preference < 0:
        chosen = response_1
        rejected = response_2
    elif overall_preference == 0:
        chosen = response_1
        rejected = response_1
    else:
        chosen = response_2
        rejected = response_1

    return {
        "prompt": data["context"],
        "chosen_response": chosen,
        "rejected_response": rejected,
    }
```

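As a rough sketch of how a formatting function like this could be applied to the raw data (the dataset name, default config, and split below are assumptions for illustration; the code in `helpsteer3.py` handles loading and splitting itself):

```python
from datasets import load_dataset

# Illustrative only: load the raw preference data and re-key it for DPO.
raw = load_dataset("nvidia/HelpSteer3", split="train")
formatted = raw.map(format_helpsteer3)

example = formatted[0]
print(example["prompt"], example["chosen_response"], example["rejected_response"])
```
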
We also provide a [DPODataset](../../nemo_reinforcer/data/hf_datasets/dpo.py) class that is compatible with jsonl-formatted preference datasets. This class assumes the train and validation datasets have already been split and processed into the expected format offline. The jsonl files should consist of examples with `prompt`, `chosen_response`, and `rejected_response` keys.

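Concretely, each line of such a jsonl file is one JSON object with exactly those keys. A minimal sketch of producing one record (values are illustrative):

```python
import json

# One preference example per line in the train/validation jsonl files.
record = {
    "prompt": "What is 2+2?",
    "chosen_response": "4",
    "rejected_response": "5",
}
print(json.dumps(record))
```
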
## Adding Custom DPO Datasets

Adding a new DPO dataset is straightforward. Your custom dataset class should:
1. Implement the required format conversion in the constructor
2. Set up the appropriate `task_spec`

Here's a minimal example which simply re-keys an existing jsonl dataset:

```{testcode}
from datasets import load_dataset
from nemo_reinforcer.data.interfaces import TaskDataSpec
from docs.helpers import make_dpo_dataset

class CustomDPODataset:
    def preprocess_dataset(
        self,
        data,
        prompt_key: str = "context",
        chosen_key: str = "chosen",
        rejected_key: str = "rejected"
    ):
        return {
            "prompt": data[prompt_key],
            "chosen_response": data[chosen_key],
            "rejected_response": data[rejected_key],
        }

    def __init__(
        self,
        train_data_path: str,
        val_data_path: str,
        prompt_key: str,
        chosen_key: str,
        rejected_key: str,
    ):
        # Load and format your dataset
        fn_kwargs = {
            "prompt_key": prompt_key,
            "chosen_key": chosen_key,
            "rejected_key": rejected_key,
        }
        formatted_ds = {
            "train": load_dataset("json", data_files=train_data_path, split="train").map(
                self.preprocess_dataset,
                fn_kwargs=fn_kwargs,
            ),
            "validation": load_dataset("json", data_files=val_data_path, split="train").map(
                self.preprocess_dataset,
                fn_kwargs=fn_kwargs,
            ),
        }

        # Initialize task spec with dataset name
        self.task_spec = TaskDataSpec(
            task_name="custom_dpo",
        )
        self.formatted_ds = formatted_ds

# Create temporary files using helper function
train_file, val_file = make_dpo_dataset()

# Initialize dataset
dataset = CustomDPODataset(
    train_data_path=train_file.name,
    val_data_path=val_file.name,
    prompt_key="context",
    chosen_key="chosen",
    rejected_key="rejected",
)

# Test dataset properties
print(f"Task name: {dataset.task_spec.task_name}")
print(f"Train examples: {len(dataset.formatted_ds['train'])}")
print(f"Validation examples: {len(dataset.formatted_ds['validation'])}")
print(f"First train example prompt: {dataset.formatted_ds['train'][0]['prompt']}")
print(f"First train example chosen response: {dataset.formatted_ds['train'][0]['chosen_response']}")
print(f"First train example rejected response: {dataset.formatted_ds['train'][0]['rejected_response']}")
```

```{testoutput}
Task name: custom_dpo
Train examples: 2
Validation examples: 2
First train example prompt: What is 2+2?
First train example chosen response: 4
First train example rejected response: 5
```

## DPO-Specific Parameters

The DPO implementation in Reinforcer supports several key parameters that can be adjusted:

- `dpo.reference_policy_kl_penalty`: Controls the strength of the KL penalty term
- `dpo.preference_loss_weight`: Weight for the preference loss
- `dpo.sft_loss_weight`: Weight for the auxiliary SFT loss
- `dpo.preference_average_log_probs`: Whether to average log probabilities over tokens in the preference loss term
- `dpo.sft_average_log_probs`: Whether to average log probabilities over tokens in the SFT loss term

These parameters can be adjusted in the config file or via command-line overrides to optimize training for your specific use case.

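As a rough sketch only (not Reinforcer's actual implementation), assuming the common convention in which the preference and SFT terms are weighted and summed and `reference_policy_kl_penalty` acts as the DPO beta, the loss for a batch of examples could look like this:

```python
import torch
import torch.nn.functional as F


def dpo_loss_sketch(
    policy_chosen_logps,    # log-prob of each chosen response under the policy
    policy_rejected_logps,  # log-prob of each rejected response under the policy
    ref_chosen_logps,       # same quantities under the frozen reference model
    ref_rejected_logps,
    beta=0.05,              # cf. dpo.reference_policy_kl_penalty (assumed role)
    preference_loss_weight=1.0,
    sft_loss_weight=0.0,
):
    # The per-response log-probs may be token sums or token means; that choice
    # is what the *_average_log_probs flags control.
    margin = (policy_chosen_logps - ref_chosen_logps) - (
        policy_rejected_logps - ref_rejected_logps
    )
    preference_loss = -F.logsigmoid(beta * margin).mean()

    # Auxiliary SFT term: negative log-likelihood of the chosen responses.
    sft_loss = -policy_chosen_logps.mean()

    return preference_loss_weight * preference_loss + sft_loss_weight * sft_loss


# Dummy per-example log-probs, just to show the call shape.
loss = dpo_loss_sketch(
    torch.tensor([-12.0, -9.5]),
    torch.tensor([-15.0, -11.0]),
    torch.tensor([-13.0, -9.0]),
    torch.tensor([-14.0, -10.5]),
    beta=0.1,
    sft_loss_weight=0.1,
)
print(loss)
```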

docs/helpers.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import json


def make_dpo_dataset():
    train_file = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False)
    val_file = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False)

    # Write train data
    train_data = [
        {"context": "What is 2+2?", "chosen": "4", "rejected": "5"},
        {"context": "What is 3*3?", "chosen": "9", "rejected": "6"},
    ]
    for item in train_data:
        lines = train_file.write(json.dumps(item) + "\n")
    train_file.flush()

    # Write validation data
    val_data = [
        {"context": "What is 4+4?", "chosen": "8", "rejected": "7"},
        {"context": "What is 5*5?", "chosen": "25", "rejected": "20"},
    ]
    for item in val_data:
        lines = val_file.write(json.dumps(item) + "\n")
    val_file.flush()

    return train_file, val_file
