
Commit bde1a68

mrm-196
authored
feat: Add recipe to reproduce Tulu-3 DPO model (#804)
1 parent eb50202 commit bde1a68

7 files changed: 181 additions, 9 deletions

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
defaults: "../../dpo.yaml"

cluster:
  num_nodes: 1
  gpus_per_node: 8

policy:
  model_name: "allenai/Llama-3.1-Tulu-3-8B-SFT"
  tokenizer:
    name: "allenai/Llama-3.1-Tulu-3-8B-SFT"
  train_micro_batch_size: 1
  train_global_batch_size: 128
  max_total_sequence_length: 2048
  optimizer:
    name: "torch.optim.AdamW"
    kwargs:
      lr: 5.0e-7
      weight_decay: 0.0
  scheduler:
    - name: "torch.optim.lr_scheduler.LinearLR"
      kwargs:
        start_factor: 1.0e-6
        end_factor: 1.0
        total_iters: 211
    - name: "torch.optim.lr_scheduler.LinearLR"
      kwargs:
        start_factor: 1.0
        end_factor: 0.0
        total_iters: 1899
    - milestones: [211]

data:
  dataset_name: "Tulu3Preference"

dpo:
  max_num_steps: 2110
  val_period: -1
  val_at_start: false
  preference_average_log_probs: True
  reference_policy_kl_penalty: 5
  val_micro_batch_size: ${policy.train_micro_batch_size}
  val_global_batch_size: ${policy.train_global_batch_size}

checkpointing:
  metric_name: null
  save_period: 250

logger:
  wandb_enabled: True
  wandb:
    name: "dpo-tulu3-8b"

examples/run_dpo.py

Lines changed: 16 additions & 9 deletions
@@ -176,13 +176,19 @@ def setup_data(data_config: DataConfig, policy_config: PolicyConfig):
 
     if data_config["dataset_name"] == "HelpSteer3":
         data = hf_datasets.HelpSteer3Dataset()
+        train_dataset = data.formatted_ds["train"]
+        val_dataset = data.formatted_ds["validation"]
+    elif data_config["dataset_name"] == "Tulu3Preference":
+        data = hf_datasets.Tulu3PreferenceDataset()
+        train_dataset = data.formatted_ds["train"]
+        val_dataset = None
     else:
         data = hf_datasets.DPODataset(
             train_data_path=data_config["train_data_path"],
             val_data_path=data_config["val_data_path"],
         )
-    train_dataset = data.formatted_ds["train"]
-    val_dataset = data.formatted_ds["validation"]
+        train_dataset = data.formatted_ds["train"]
+        val_dataset = data.formatted_ds["validation"]
 
     dpo_task_spec = data.task_spec
 
@@ -195,13 +201,14 @@ def setup_data(data_config: DataConfig, policy_config: PolicyConfig):
         max_seq_length=data_config["max_input_seq_length"],
     )
 
-    val_dataset = AllTaskProcessedDataset(
-        val_dataset,
-        tokenizer,
-        dpo_task_spec,
-        dpo_preprocessor,
-        max_seq_length=data_config["max_input_seq_length"],
-    )
+    if val_dataset:
+        val_dataset = AllTaskProcessedDataset(
+            val_dataset,
+            tokenizer,
+            dpo_task_spec,
+            dpo_preprocessor,
+            max_seq_length=data_config["max_input_seq_length"],
+        )
 
     return train_dataset, val_dataset, tokenizer, dpo_task_spec
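
Because the Tulu 3 preference mixture provides only a train split, val_dataset stays None for this dataset, and the recipe config above sets dpo.val_period: -1 and val_at_start: false, so no validation pass is requested. A minimal sketch of that pattern (an illustration, not NeMo RL's actual scheduling code):

# Hedged sketch: validation only runs when a split exists and a positive
# period is configured; with val_dataset=None and val_period=-1 it never fires.
def should_validate(step: int, val_period: int, val_dataset) -> bool:
    return val_dataset is not None and val_period > 0 and step % val_period == 0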

nemo_rl/data/hf_datasets/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
     PromptResponseDataset,
 )
 from nemo_rl.data.hf_datasets.squad import SquadDataset
+from nemo_rl.data.hf_datasets.tulu3 import Tulu3PreferenceDataset
 
 __all__ = [
     "DPODataset",
@@ -32,6 +33,7 @@
     "OpenMathInstruct2Dataset",
     "PromptResponseDataset",
     "SquadDataset",
+    "Tulu3PreferenceDataset",
     "COMMON_CHAT_TEMPLATES",
     "CLEVRCoGenTDataset",
 ]

nemo_rl/data/hf_datasets/tulu3.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from typing import Any

from datasets import load_dataset

from nemo_rl.data.interfaces import TaskDataSpec


def format_tulu3_preference(data: dict[str, Any]) -> dict[str, str | dict[str, str]]:
    chosen_conversation = data["chosen"]
    rejected_conversation = data["rejected"]

    context = chosen_conversation[:-1]

    # We assume that, apart from the final assistant response, the chosen and
    # rejected conversations share an identical message history. Validate this.
    assert json.dumps(context, ensure_ascii=False) == json.dumps(
        rejected_conversation[:-1], ensure_ascii=False
    ), (
        f"Context mismatch.\n\nchosen: {chosen_conversation}\n\nrejected: {rejected_conversation}"
    )

    # We assume that the last message is always from the assistant. Validate this.
    assert chosen_conversation[-1]["role"] == "assistant", (
        f"The last chosen response ({chosen_conversation[-1]}) is not from the assistant!"
    )
    assert rejected_conversation[-1]["role"] == "assistant", (
        f"The last rejected response ({rejected_conversation[-1]}) is not from the assistant!"
    )

    chosen_response = chosen_conversation[-1]["content"]
    rejected_response = rejected_conversation[-1]["content"]

    return {
        "prompt": context,
        "chosen_response": chosen_response,
        "rejected_response": rejected_response,
    }


class Tulu3PreferenceDataset:
    """Tulu3 preference dataset for DPO training."""

    def __init__(self) -> None:
        ds = load_dataset(
            path="allenai/llama-3.1-tulu-3-8b-preference-mixture",
            trust_remote_code=True,
        )
        self.formatted_ds = ds.map(format_tulu3_preference)

        self.task_spec = TaskDataSpec(
            task_name="Tulu3Preference",
        )
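
For reference, a small usage sketch of the new dataset class, assuming the Hugging Face download succeeds and the mixture exposes a "train" split (as run_dpo.py above expects):

# Hedged usage sketch of Tulu3PreferenceDataset; field names mirror the
# output of format_tulu3_preference above.
from nemo_rl.data.hf_datasets import Tulu3PreferenceDataset

ds = Tulu3PreferenceDataset()
example = ds.formatted_ds["train"][0]
print(example["prompt"])             # shared conversation context (list of messages)
print(example["chosen_response"])    # preferred assistant reply
print(example["rejected_response"])  # dispreferred assistant reply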

pyrefly.toml

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ project-includes = [
     "nemo_rl/data/hf_datasets/openmathinstruct2.py",
     "nemo_rl/data/hf_datasets/prompt_response_dataset.py",
     "nemo_rl/data/hf_datasets/squad.py",
+    "nemo_rl/data/hf_datasets/tulu3.py",
     "nemo_rl/data/multimodal_utils.py",
     "nemo_rl/data/interfaces.py",
     "nemo_rl/data/packing/__init__.py",
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
#!/bin/bash
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
source $SCRIPT_DIR/common.env

# ===== BEGIN CONFIG =====
NUM_NODES=1
STEPS_PER_RUN=150
MAX_STEPS=150
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))  # Round up
NUM_MINUTES=45
# ===== END CONFIG =====

exit_if_max_steps_reached

# Run the experiment
cd $PROJECT_ROOT
uv run examples/run_dpo.py \
    --config $CONFIG_PATH \
    dpo.max_num_steps=$MAX_STEPS \
    logger.log_dir=$LOG_DIR \
    logger.wandb_enabled=True \
    logger.wandb.project=nemo-rl \
    logger.wandb.name=$EXP_NAME \
    logger.monitor_gpus=True \
    logger.tensorboard_enabled=True \
    checkpointing.enabled=True \
    checkpointing.checkpoint_dir=$CKPT_DIR \
    $@ \
    2>&1 | tee $RUN_LOG

# Convert tensorboard logs to json
uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

# Only run metric checks if the target step was reached
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
    uv run tests/check_metrics.py $JSON_METRICS \
        'data["train/sft_loss"]["1"] < 0.00001' \
        'data["train/sft_loss"]["150"] < 0.00001' \
        'data["train/preference_loss"]["1"] > 0.6930' \
        'data["train/preference_loss"]["1"] < 0.6932' \
        'data["train/preference_loss"]["150"] < 0.68'
fi
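
The jq expression above pulls the highest step logged for train/loss and gates the metric checks on it. A rough Python equivalent of that gating logic (assuming, as the script implies, that the JSON file maps metric names to {step: value} dictionaries; "metrics.json" stands in for $JSON_METRICS):

# Hedged sketch of the gating logic above, in Python instead of jq/bash.
import json

with open("metrics.json") as f:  # stand-in for $JSON_METRICS
    data = json.load(f)

max_logged_step = max(int(step) for step in data["train/loss"])
if max_logged_step >= 150:  # MAX_STEPS in the script
    assert data["train/sft_loss"]["1"] < 0.00001
    assert data["train/sft_loss"]["150"] < 0.00001
    assert 0.6930 < data["train/preference_loss"]["1"] < 0.6932
    assert data["train/preference_loss"]["150"] < 0.68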

tests/test_suites/release.txt

Lines changed: 1 addition & 0 deletions
@@ -25,3 +25,4 @@ tests/test_suites/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.sh
 # Long 8b convergence
 tests/test_suites/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.sh
 tests/test_suites/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.sh
+tests/test_suites/llm/dpo-llama3.1-8b-tulu3-1n8g-fsdp2tp1.sh
