
Commit cbd4b93

jveronvialard, odelalleau, and terrykong authored
feat: preference datasets (#673)
Signed-off-by: Julien Veron Vialard <jveronvialar@nvidia.com>
Signed-off-by: Olivier Delalleau <507137+odelalleau@users.noreply.github.com>
Signed-off-by: Terry Kong <terryk@nvidia.com>
Co-authored-by: Olivier Delalleau <507137+odelalleau@users.noreply.github.com>
Co-authored-by: Terry Kong <terrycurtiskong@gmail.com>
Co-authored-by: Terry Kong <terryk@nvidia.com>
1 parent c4fd5d3 commit cbd4b93

File tree

21 files changed: 1021 additions & 456 deletions

docs/guides/dpo.md

Lines changed: 75 additions & 115 deletions
````diff
@@ -32,129 +32,89 @@ uv run examples/run_dpo.py \

 ## Datasets

-Each class representing a NeMo RL DPO dataset is expected to have the following attributes:
-1. `formatted_ds`: The dictionary of formatted datasets. This dictionary should contain `train` and `validation` splits, and each split should conform to the format described below.
-2. `task_spec`: The `TaskDataSpec` for this dataset. This should specify the name you choose for this dataset.
-
-DPO datasets are expected to follow a specific format with three key fields:
-- `prompt`: The input prompt/context
-- `chosen_response`: The preferred/winning response
-- `rejected_response`: The non-preferred/losing response
-
-[data/hf_datasets/helpsteer3.py](../../nemo_rl/data/hf_datasets/helpsteer3.py) provides an example of how to format data for DPO:
-
-```python
-def format_helpsteer3(data):
-    response_1 = data["response1"]
-    response_2 = data["response2"]
-    overall_preference = data["overall_preference"]
-
-    if overall_preference < 0:
-        chosen = response_1
-        rejected = response_2
-    elif overall_preference == 0:
-        chosen = response_1
-        rejected = response_1
-    else:
-        chosen = response_2
-        rejected = response_1
-
-    return {
-        "prompt": data["context"],
-        "chosen_response": chosen,
-        "rejected_response": rejected,
+Each DPO dataset class is expected to have the following attributes:
+1. `formatted_ds`: The dictionary of formatted datasets, where each dataset should be formatted like
+```json
+{
+  "context": [],       // list of dicts - The prompt message (including previous turns, if any)
+  "completions": [     // list of dicts - The list of completions
+    {
+      "rank": 0,       // int - The rank of the completion (lower rank is preferred)
+      "completion": [] // list of dicts - The completion message(s)
+    },
+    {
+      "rank": 1,       // int - The rank of the completion (lower rank is preferred)
+      "completion": [] // list of dicts - The completion message(s)
     }
+  ]
+}
 ```
+2. `task_spec`: The `TaskDataSpec` for this dataset. This should specify the name you choose for this dataset.

-We also provide a [DPODataset](../../nemo_rl/data/hf_datasets/dpo.py) class that is compatible with jsonl-formatted preference datsets. This class assumes train and validation datasets have been split and processed into the expected format offline. The jsonl files should consist of examples with `prompt`, `chosen_response`, and `rejected_response` keys.
-
-## Adding Custom DPO Datasets
-
-Adding a new DPO dataset is straightforward. Your custom dataset class should:
-1. Implement the required format conversion in the constructor
-2. Set up the appropriate `task_spec`
-
-Here's a minimal example which simply re-keys an existing jsonl dataset:
-
-```{testcode}
-from datasets import load_dataset
-from nemo_rl.data.interfaces import TaskDataSpec
-from docs.helpers import make_dpo_dataset
-
-class CustomDPODataset:
-    def preprocess_dataset(
-        self,
-        data,
-        prompt_key: str = "context",
-        chosen_key: str = "chosen",
-        rejected_key: str = "rejected"
-    ):
-        return {
-            "prompt": data[prompt_key],
-            "chosen_response": data[chosen_key],
-            "rejected_response": data[rejected_key],
+DPO training supports only two completions (where the lowest rank is preferred and the highest one is rejected), with each completion being a single response. For example:
+```json
+{
+  "context": [
+    {
+      "role": "user",
+      "content": "What's the capital of France?"
+    },
+    {
+      "role": "assistant",
+      "content": "The capital of France is Paris."
+    },
+    {
+      "role": "user",
+      "content": "Thanks! And what's the capital of Germany?"
     }
-
-    def __init__(
-        self,
-        train_data_path: str,
-        val_data_path: str,
-        prompt_key: str,
-        chosen_key: str,
-        rejected_key: str,
-    ):
-        # Load and format your dataset
-        fn_kwargs={
-            "prompt_key": prompt_key,
-            "chosen_key": chosen_key,
-            "rejected_key": rejected_key
-        }
-        formatted_ds = {
-            "train": load_dataset("json", data_files=train_data_path, split="train").map(
-                self.preprocess_dataset,
-                fn_kwargs=fn_kwargs,
-            ),
-            "validation": load_dataset("json", data_files=val_data_path, split="train").map(
-                self.preprocess_dataset,
-                fn_kwargs=fn_kwargs,
-            ),
+  ],
+  "completions": [
+    {
+      "rank": 0,
+      "completion": [
+        {
+          "role": "assistant",
+          "content": "The capital of Germany is Berlin."
+        }
+      ]
+    },
+    {
+      "rank": 1,
+      "completion": [
+        {
+          "role": "assistant",
+          "content": "The capital of Germany is Munich."
+        }
+      ]
     }
-
-        # Initialize task spec with dataset name
-        self.task_spec = TaskDataSpec(
-            task_name="custom_dpo",
-        )
-        self.formatted_ds = formatted_ds
-
-# Create temporary files using helper function
-train_file, val_file = make_dpo_dataset()
-
-# Initialize dataset
-dataset = CustomDPODataset(
-    train_data_path=train_file.name,
-    val_data_path=val_file.name,
-    prompt_key="context",
-    chosen_key="chosen",
-    rejected_key="rejected"
-)
-
-# Test dataset properties
-print(f"Task name: {dataset.task_spec.task_name}")
-print(f"Train examples: {len(dataset.formatted_ds['train'])}")
-print(f"Validation examples: {len(dataset.formatted_ds['validation'])}")
-print(f"First train example prompt: {dataset.formatted_ds['train'][0]['prompt']}")
-print(f"First train example chosen response: {dataset.formatted_ds['train'][0]['chosen_response']}")
-print(f"First train example rejected response: {dataset.formatted_ds['train'][0]['rejected_response']}")
+  ]
+}
 ```

-```{testoutput}
-Task name: custom_dpo
-Train examples: 2
-Validation examples: 2
-First train example prompt: What is 2+2?
-First train example chosen response: 4
-First train example rejected response: 5
+NeMo RL provides a DPO-compatible implementation of the [HelpSteer3](https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/data/hf_datasets/helpsteer3.py) dataset as an example. This dataset is downloaded from Hugging Face and preprocessed on-the-fly, so there's no need to provide a path to any datasets on disk.
+
+We also provide a [PreferenceDataset](../../nemo_rl/data/hf_datasets/preference_dataset.py) class that is compatible with JSONL-formatted preference datasets. You can modify your config as follows to use such a custom preference dataset:
+```yaml
+data:
+  dataset_name: PreferenceDataset
+  train_data_path: <LocalPathToTrainingDataset>
+  val_data_paths:
+    <NameOfValidationDataset>: <LocalPathToValidationDataset>
+```
+Multiple validation sets are supported as follows:
+```yaml
+data:
+  dataset_name: PreferenceDataset
+  train_data_path: <LocalPathToTrainingDataset>
+  val_data_paths:
+    <NameOfValidationDataset1>: <LocalPathToValidationDataset1>
+    <NameOfValidationDataset2>: <LocalPathToValidationDataset2>
 ```
+Please note:
+- If you are using a logger, the prefix used for each validation set will be `validation-<NameOfValidationDataset>`. The total validation time, summed across all validation sets, is reported under `timing/validation/total_validation_time`.
+- If you are doing checkpointing, the `metric_name` value in your `checkpointing` config should reflect the metric and validation set to be tracked. For example, `validation-<NameOfValidationDataset1>_loss`.
+
+The older [DPODataset](../../nemo_rl/data/hf_datasets/dpo.py) class is deprecated. This class is also compatible with JSONL-formatted preference datasets. It assumes train and validation datasets have been split and processed into the expected format offline. The JSONL files should consist of examples with `prompt`, `chosen_response`, and `rejected_response` keys.

 ## DPO-Specific Parameters

````
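As a usage note on the format introduced above, here is a minimal sketch (illustrative only, not part of this commit) of re-keying an old-style `prompt`/`chosen_response`/`rejected_response` record into the new `context`/`completions` format and writing it out as JSONL; the helper and file names are hypothetical.

```python
import json


def convert_old_record(record: dict) -> dict:
    """Re-key an old-style {prompt, chosen_response, rejected_response} example
    into the context/completions preference format described above."""
    return {
        # Single-turn prompt; a multi-turn context would list the earlier turns too.
        "context": [{"role": "user", "content": record["prompt"]}],
        "completions": [
            # Lower rank is preferred: rank 0 holds the chosen response.
            {"rank": 0, "completion": [{"role": "assistant", "content": record["chosen_response"]}]},
            {"rank": 1, "completion": [{"role": "assistant", "content": record["rejected_response"]}]},
        ],
    }


old_examples = [
    {"prompt": "What is 2+2?", "chosen_response": "4", "rejected_response": "5"},
]

# One JSON object per line, e.g. to point `train_data_path` at for PreferenceDataset.
with open("train_preferences.jsonl", "w") as f:
    for example in old_examples:
        f.write(json.dumps(convert_old_record(example)) + "\n")
```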
docs/guides/rm.md

Lines changed: 81 additions & 1 deletion
````diff
@@ -21,4 +21,84 @@ The default YAML config shares the same base template as the SFT config but incl

 ## Datasets

-By default, NeMo RL supports the `HelpSteer3` dataset. This dataset is downloaded from Hugging Face and preprocessed on-the-fly, so there's no need to provide a path to any datasets on disk.
+Each RM dataset class is expected to have the following attributes:
+1. `formatted_ds`: The dictionary of formatted datasets, where each dataset should be formatted like
+```json
+{
+  "context": [],       // list of dicts - The prompt message (including previous turns, if any)
+  "completions": [     // list of dicts - The list of completions
+    {
+      "rank": 0,       // int - The rank of the completion (lower rank is preferred)
+      "completion": [] // list of dicts - The completion message(s)
+    },
+    {
+      "rank": 1,       // int - The rank of the completion (lower rank is preferred)
+      "completion": [] // list of dicts - The completion message(s)
+    }
+  ]
+}
+```
+2. `task_spec`: The `TaskDataSpec` for this dataset. This should specify the name you choose for this dataset.
+
+Currently, RM training supports only two completions (where the lowest rank is preferred and the highest one is rejected), with each completion being a single response. For example:
+```json
+{
+  "context": [
+    {
+      "role": "user",
+      "content": "What's the capital of France?"
+    },
+    {
+      "role": "assistant",
+      "content": "The capital of France is Paris."
+    },
+    {
+      "role": "user",
+      "content": "Thanks! And what's the capital of Germany?"
+    }
+  ],
+  "completions": [
+    {
+      "rank": 0,
+      "completion": [
+        {
+          "role": "assistant",
+          "content": "The capital of Germany is Berlin."
+        }
+      ]
+    },
+    {
+      "rank": 1,
+      "completion": [
+        {
+          "role": "assistant",
+          "content": "The capital of Germany is Munich."
+        }
+      ]
+    }
+  ]
+}
+```
+
+NeMo RL provides an RM-compatible implementation of the [HelpSteer3](https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/data/hf_datasets/helpsteer3.py) dataset as an example. This dataset is downloaded from Hugging Face and preprocessed on-the-fly, so there's no need to provide a path to any datasets on disk.
+
+We also provide a [PreferenceDataset](../../nemo_rl/data/hf_datasets/preference_dataset.py) class that is compatible with JSONL-formatted preference datasets. You can modify your config as follows to use such a custom preference dataset:
+```yaml
+data:
+  dataset_name: PreferenceDataset
+  train_data_path: <LocalPathToTrainingDataset>
+  val_data_paths:
+    <NameOfValidationDataset>: <LocalPathToValidationDataset>
+```
+Multiple validation sets are supported as follows:
+```yaml
+data:
+  dataset_name: PreferenceDataset
+  train_data_path: <LocalPathToTrainingDataset>
+  val_data_paths:
+    <NameOfValidationDataset1>: <LocalPathToValidationDataset1>
+    <NameOfValidationDataset2>: <LocalPathToValidationDataset2>
+```
+Please note:
+- If you are using a logger, the prefix used for each validation set will be `validation-<NameOfValidationDataset>`. The total validation time, summed across all validation sets, is reported under `timing/validation/total_validation_time`.
+- If you are doing checkpointing, the `metric_name` value in your `checkpointing` config should reflect the metric and validation set to be tracked. For example, `validation-<NameOfValidationDataset1>_loss`.
````
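Since the RM guide shares the same preference format as the DPO guide above, a record can be sanity-checked against the stated constraints (exactly two completions, distinct ranks with the lower rank preferred, and a single assistant message per completion) before training. The sketch below is illustrative only and is not a helper shipped with NeMo RL; the file path is a placeholder.

```python
import json


def check_preference_record(record: dict) -> None:
    """Assert that one JSONL record matches the preference format documented above."""
    assert isinstance(record.get("context"), list) and record["context"], \
        "context must be a non-empty list of message dicts"
    completions = record.get("completions", [])
    assert len(completions) == 2, "only two completions are currently supported"
    ranks = sorted(c["rank"] for c in completions)
    assert ranks[0] < ranks[1], "ranks must be distinct; the lower rank is preferred"
    for c in completions:
        assert len(c["completion"]) == 1, "each completion should be a single response"
        assert c["completion"][0]["role"] == "assistant"


# Example usage against a local JSONL file (path is a placeholder).
with open("train_preferences.jsonl") as f:
    for line in f:
        check_preference_record(json.loads(line))
```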

examples/configs/dpo.yaml

Lines changed: 14 additions & 1 deletion
````diff
@@ -151,9 +151,22 @@ policy:
   data_parallel_sharding_strategy: "optim_grads_params"

 data:
-  dataset_name: "HelpSteer3"
   max_input_seq_length: ${policy.max_total_sequence_length}
   shuffle: true
+
+  dataset_name: HelpSteer3
+  # You can use custom preference datasets for training and validation. For example:
+  # data:
+  #   dataset_name: PreferenceDataset
+  #   train_data_path: <LocalPathToTrainingDataset>
+  #   val_data_paths:
+  #     <NameOfValidationDataset1>: <LocalPathToValidationDataset1>
+  #     ...
+  # If you are doing checkpointing, `metric_name` should reflect the metric and validation set to be tracked. For example:
+  # checkpointing:
+  #   metric_name: "validation-<NameOfValidationDataset1>_loss"
+  #   ...
+
 logger:
   log_dir: "logs" # Base directory for all logs
   wandb_enabled: false # Make sure you do a ``wandb login [Your API key]'' before running
````
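To make the commented-out example above concrete, here is an illustrative sketch (assuming PyYAML is available; all names and paths are placeholders) of how the checkpointing `metric_name` lines up with the `validation-<NameOfValidationDataset>_loss` convention described in the guides:

```python
import yaml

# A filled-in version of the commented-out example; names and paths are placeholders.
config_text = """
data:
  dataset_name: PreferenceDataset
  train_data_path: /data/preferences/train.jsonl
  val_data_paths:
    my_val_set: /data/preferences/val.jsonl
checkpointing:
  metric_name: "validation-my_val_set_loss"
"""

config = yaml.safe_load(config_text)

# The tracked metric should refer to one of the configured validation sets.
expected = {f"validation-{name}_loss" for name in config["data"]["val_data_paths"]}
assert config["checkpointing"]["metric_name"] in expected
print(config["checkpointing"]["metric_name"])  # validation-my_val_set_loss
```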

examples/configs/rm.yaml

Lines changed: 13 additions & 1 deletion
````diff
@@ -123,9 +123,21 @@ policy:

 data:
   max_input_seq_length: ${policy.max_total_sequence_length}
-  dataset_name: "HelpSteer3"
   shuffle: true

+  dataset_name: HelpSteer3
+  # You can use custom preference datasets for training and validation. For example:
+  # data:
+  #   dataset_name: PreferenceDataset
+  #   train_data_path: <LocalPathToTrainingDataset>
+  #   val_data_paths:
+  #     <NameOfValidationDataset1>: <LocalPathToValidationDataset1>
+  #     ...
+  # If you are doing checkpointing, `metric_name` should reflect the metric and validation set to be tracked. For example:
+  # checkpointing:
+  #   metric_name: "validation-<NameOfValidationDataset1>_loss"
+  #   ...
+
 logger:
   log_dir: "logs" # Base directory for all logs
   wandb_enabled: true # Make sure you do a ``wandb login [Your API key]'' before running
````
