Merged
28 commits
a38c104
adding support for Bradley-Terry reward model training
jveronvialard Jul 3, 2025
ede515b
Merge branch 'main' of github.com:NVIDIA-NeMo/RL into jveronvialard/b…
jveronvialard Jul 15, 2025
5b9e976
update docs
jveronvialard Jul 15, 2025
68e96ea
add separate run_rm.py and unit tests
jveronvialard Jul 15, 2025
21d67a0
fix small typos and nit changes
jveronvialard Jul 15, 2025
8a28af7
rewards tensor shape
jveronvialard Jul 15, 2025
e914087
Merge branch 'main' of github.com:NVIDIA-NeMo/RL into jveronvialard/b…
jveronvialard Jul 16, 2025
8fb280b
update config and skip is_tied_lm_head for RM
jveronvialard Jul 16, 2025
3e3b03a
use tokenizer.pad_token_id if model.config.pad_token_id is not defined
jveronvialard Jul 16, 2025
ed24aea
nit
jveronvialard Jul 16, 2025
af17314
update functional test and cicd
jveronvialard Jul 16, 2025
1034634
nit docs
jveronvialard Jul 17, 2025
8788ec2
Merge branch 'main' of github.com:NVIDIA-NeMo/RL into jveronvialard/b…
jveronvialard Jul 21, 2025
24807c3
split sft.py and rm.py
jveronvialard Jul 21, 2025
d3b6272
Merge branch 'main' of github.com:NVIDIA-NeMo/RL into jveronvialard/b…
jveronvialard Jul 22, 2025
0aaf296
pull from main
jveronvialard Jul 22, 2025
5b3f1ad
Update docs/guides/rm.md
odelalleau Jul 23, 2025
6534c7c
Remove the `-RAY_DEDUP_LOGS=0` examples in the README
odelalleau Jul 23, 2025
b79d0ee
Refactor RM config to include a dedicated `reward_model_cfg` section
odelalleau Jul 23, 2025
51cc9f8
Provide user-friendly error message regarding unsupported RMs in mcore
odelalleau Jul 23, 2025
597d5eb
Simplify code and guard against enabling sequence packing in RMs
odelalleau Jul 23, 2025
ba2e4b6
Fix likely crash with Reward Models introduced in previous commit
odelalleau Jul 23, 2025
4733717
Fix linting issues
odelalleau Jul 23, 2025
3297cd1
Fix a typing issue
odelalleau Jul 23, 2025
179767e
Quick fix to typing issue (with TODO item for better fix)
odelalleau Jul 25, 2025
a86b6c7
Merge branch 'main' of github.com:NVIDIA-NeMo/RL into jveronvialard/b…
jveronvialard Jul 28, 2025
615aa98
Update docs/guides/rm.md
jveronvialard Jul 29, 2025
7530c84
Minor lint fix
odelalleau Jul 29, 2025
21 changes: 21 additions & 0 deletions docs/guides/rm.md
@@ -0,0 +1,21 @@
# Reward Model Training in NeMo RL

This document explains how to train reward models (RMs) in NeMo RL. Currently, only Bradley-Terry reward models are supported.
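A Bradley-Terry RM learns a scalar reward such that the probability of the chosen response beating the rejected one is the sigmoid of their reward difference; training minimizes the negative log-likelihood of the observed preference. A minimal numeric sketch of that objective (for intuition only, not NeMo RL's actual loss implementation):

```python
import math

def bradley_terry_loss(reward_chosen: float, reward_rejected: float) -> float:
    # P(chosen preferred) = sigmoid(r_chosen - r_rejected);
    # the loss is the negative log-likelihood of that preference.
    margin = reward_chosen - reward_rejected
    return -math.log(1.0 / (1.0 + math.exp(-margin)))

# A larger reward margin between chosen and rejected yields a smaller loss.
print(bradley_terry_loss(2.0, 0.0) < bradley_terry_loss(0.5, 0.0))  # prints True
```

At a margin of zero the loss is exactly log 2 (the model is indifferent), and it decreases monotonically as the reward model ranks the chosen response further above the rejected one.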

## Launch a Training Job

Bradley-Terry reward models are trained with the script [examples/run_sft.py](../../examples/run_sft.py), which can be launched either locally or via Slurm. For details on how to set up Ray and launch a job using Slurm, refer to the [cluster documentation](../cluster.md).

Be sure to launch the job using `uv`. The command to launch a training job is as follows:

```bash
uv run examples/run_sft.py --config <PATH TO YAML CONFIG> <OVERRIDES>
```

A YAML config must be specified. It uses the same base template as the SFT config but adds a `reward_model_type` key that triggers Reward Model training. An example RM config file is provided at [examples/configs/rm.yaml](../../examples/configs/rm.yaml).
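The structural difference from an SFT config is a single extra key under `policy` (fragment shown for illustration; see the full example config):

```yaml
policy:
  model_name: "meta-llama/Llama-3.2-1B-Instruct"
  # Presence of this key is what switches run_sft.py into RM training mode.
  reward_model_type: "bradley_terry"
```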

**Reminder**: Don't forget to set your `HF_HOME`, `WANDB_API_KEY`, and `HF_DATASETS_CACHE` (if needed). You'll also need to run `huggingface-cli login` to access gated Llama models.
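For example (hypothetical paths; adjust to your own storage layout):

```bash
# Hypothetical cache locations; point these at fast local storage.
export HF_HOME=/tmp/hf-cache
export HF_DATASETS_CACHE="$HF_HOME/datasets"
export WANDB_API_KEY="your-wandb-api-key"  # only needed if wandb logging is enabled
# For gated Llama checkpoints, authenticate once:
# huggingface-cli login
echo "$HF_DATASETS_CACHE"
```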

## Datasets

By default, NeMo RL supports the `HelpSteer3` dataset. This dataset is downloaded from Hugging Face and preprocessed on-the-fly, so there's no need to provide a path to any datasets on disk.
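Each preference record pairs one prompt with a chosen and a rejected response. A sketch of a hypothetical HelpSteer3-style record, shaped the way the RM preprocessor in `examples/run_sft.py` consumes it (field names from the preprocessor; the example content is made up):

```python
# "prompt" is an OpenAI-style message list; the two responses are plain strings.
datum = {
    "prompt": [{"role": "user", "content": "Summarize attention in one sentence."}],
    "chosen_response": "Attention weighs each input token by its relevance to the query.",
    "rejected_response": "Attention is a thing models do.",
}

# Each record yields two full conversations that are tokenized and scored separately.
messages_chosen = datum["prompt"] + [
    {"role": "assistant", "content": datum["chosen_response"]}
]
messages_rejected = datum["prompt"] + [
    {"role": "assistant", "content": datum["rejected_response"]}
]
print(len(messages_chosen), messages_chosen[-1]["role"])  # prints: 2 assistant
```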
140 changes: 140 additions & 0 deletions examples/configs/rm.yaml
@@ -0,0 +1,140 @@
# Bradley-Terry (BT) Reward Model Training Configuration
# (uses the same base template as the SFT config but includes a new `reward_model_type` key that triggers Reward Model training)
sft:
  ## total number of steps to train will equal
  ## min((max_num_epochs * len(train_dataloader)), max_num_steps)
  max_num_epochs: 1
  max_num_steps: -1  # by default, train for 1 epoch

  val_period: 16
  val_batches: -1
  val_global_batch_size: 32
  val_micro_batch_size: 1
  val_at_start: false
  seed: 42

checkpointing:
  enabled: true
  checkpoint_dir: "results/rm"
  metric_name: "val_loss"
  higher_is_better: false
  keep_top_k: 3
  save_period: ${sft.val_period}

policy:
  model_name: "meta-llama/Llama-3.2-1B-Instruct"
  tokenizer:
    name: ${policy.model_name}  ## specify if you'd like to use a tokenizer different from the model's default
    # We don't use the "default" chat template because the Llama tokenizer inserts the current
    # date in the system prompt, which could make the reward model's output date-dependent.
    chat_template: "{{- bos_token }}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = '' %}\n{%- endif %}\n\n{#- System message #}\n{{- '<|start_header_id|>system<|end_header_id|>\n\n' }}\n{{- system_message }}\n{{- '<|eot_id|>' }}\n\n{%- for message in messages %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}\n{%- endif %}"
  reward_model_type: "bradley_terry"
  train_global_batch_size: 128
  train_micro_batch_size: 1
  max_total_sequence_length: 8192
  precision: "bfloat16"
  fsdp_offload_enabled: false
  activation_checkpointing_enabled: false

  dtensor_cfg:
    enabled: true
    cpu_offload: false
    sequence_parallel: false
    activation_checkpointing: false
    tensor_parallel_size: 1
    context_parallel_size: 1
    custom_parallel_plan: null

  dynamic_batching:
    enabled: false

  # makes the training sequence length divisible by the tensor parallel size,
  # which is useful for sequence parallel training
  make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size}
  max_grad_norm: 1.0

  optimizer:
    name: "torch.optim.AdamW"
    kwargs:
      lr: 2.0e-6
      weight_decay: 0.1
      betas: [0.9, 0.98]
      eps: 1e-5
      # when using DTensor, `foreach` and `fused` must be set to false
      foreach: false
      fused: false

  ## ignored since enabled=false, but needed for testing purposes
  megatron_cfg:
    enabled: false
    empty_unused_memory_level: 1
    activation_checkpointing: false
    tensor_model_parallel_size: 2
    pipeline_model_parallel_size: 2
    context_parallel_size: 1
    pipeline_dtype: ${policy.precision}
    num_layers_in_first_pipeline_stage: null
    num_layers_in_last_pipeline_stage: null
    sequence_parallel: false

    optimizer:
      optimizer: "adam"
      lr: 2.0e-6
      min_lr: 1.9999e-6
      weight_decay: 0.1
      bf16: false
      fp16: false
      params_dtype: "float32"

      # adam
      adam_beta1: 0.9
      adam_beta2: 0.98
      adam_eps: 1e-5

      # sgd
      sgd_momentum: 0.9

      # distributed optimizer
      use_distributed_optimizer: true
      use_precision_aware_optimizer: true

      clip_grad: ${policy.max_grad_norm}

    scheduler:
      start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      weight_decay_incr_style: "constant"
      lr_decay_style: "constant"
      lr_decay_iters: null
      lr_warmup_iters: 50
      lr_warmup_init: 1.9999e-6

    distributed_data_parallel_config:
      grad_reduce_in_fp32: false
      overlap_grad_reduce: true
      overlap_param_gather: false
      average_in_collective: true
      data_parallel_sharding_strategy: "optim_grads_params"

data:
  max_input_seq_length: ${policy.max_total_sequence_length}
  dataset_name: "HelpSteer3"

logger:
  log_dir: "logs"  # Base directory for all logs
  wandb_enabled: true  # Make sure you run `wandb login` before enabling this
  tensorboard_enabled: true
  monitor_gpus: true  # If true, will monitor GPU usage and log to wandb and/or tensorboard
  wandb:
    project: "rm-dev"
    name: "rm-dev-${data.dataset_name}"
  tensorboard:
    log_dir: "tb_logs-rm-dev-${data.dataset_name}"
  gpu_monitoring:
    collection_interval: 10  # How often to collect GPU usage metrics (in seconds)
    flush_interval: 10  # How often to flush GPU usage metrics to the loggers (in seconds)

cluster:
  gpus_per_node: 1
  num_nodes: 1
153 changes: 114 additions & 39 deletions examples/run_sft.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import argparse
+import logging
 import os
 import pprint
 from functools import partial
@@ -89,36 +90,118 @@ def sft_preprocessor(
     return output


-def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
+def rm_preprocessor(
+    datum_dict: dict[str, Any],
+    task_data_spec: TaskDataSpec,
+    tokenizer,
+    max_seq_length: int,
+    idx: int,
+) -> DatumSpec:
+    """Process a datum dictionary for RM training."""
+    messages_chosen = datum_dict["prompt"] + [
+        {"role": "assistant", "content": datum_dict["chosen_response"]}
+    ]
+    messages_rejected = datum_dict["prompt"] + [
+        {"role": "assistant", "content": datum_dict["rejected_response"]}
+    ]
+
+    message_log_chosen = get_formatted_message_log(
+        messages_chosen, tokenizer, task_data_spec
+    )
+    message_log_rejected = get_formatted_message_log(
+        messages_rejected, tokenizer, task_data_spec
+    )
+
+    length_chosen = sum(len(m["token_ids"]) for m in message_log_chosen)
+    length_rejected = sum(len(m["token_ids"]) for m in message_log_rejected)
+
+    loss_multiplier = 1.0
+    if max(length_chosen, length_rejected) > max_seq_length:
+        # make smaller and mask out
+        logging.warning(
+            f"Truncating chosen and rejected messages to {max_seq_length} tokens"
+        )
+        for message in message_log_chosen:
+            message["token_ids"] = message["token_ids"][
+                : min(4, max_seq_length // len(message_log_chosen))
+            ]
+        for message in message_log_rejected:
+            message["token_ids"] = message["token_ids"][
+                : min(4, max_seq_length // len(message_log_rejected))
+            ]
+        loss_multiplier = 0.0
+
+        length_chosen = sum(len(m["token_ids"]) for m in message_log_chosen)
+        length_rejected = sum(len(m["token_ids"]) for m in message_log_rejected)
+
+    # safeguard against edge case where there are too many turns to fit within the max length
+    assert max(length_chosen, length_rejected) <= max_seq_length
+
+    output = {
+        "message_log_chosen": message_log_chosen,
+        "length_chosen": length_chosen,
+        "message_log_rejected": message_log_rejected,
+        "length_rejected": length_rejected,
+        "extra_env_info": None,
+        "loss_multiplier": loss_multiplier,
+        "idx": idx,
+    }
+    return output
+
+
+def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig, model_type: str):
     print("\n▶ Setting up data...")
     data_cls = data_config["dataset_name"]
-    if data_cls == "open_assistant":
-        data = hf_datasets.OasstDataset(output_dir="/tmp/open_assistant")
-    elif data_cls == "squad":
-        data = hf_datasets.SquadDataset()
-    elif data_cls == "prompt_response_dataset":
-        data = hf_datasets.PromptResponseDataset(
-            data_config["train_data_path"],
-            data_config["val_data_path"],
-            data_config["input_key"],
-            data_config["output_key"],
-        )
-    elif data_cls == "openmathinstruct2":
-        data = hf_datasets.OpenMathInstruct2Dataset(
-            split=data_config["split"],
-            output_key=data_config["output_key"],
-            prompt_file=data_config["prompt_file"],
-        )
-    elif data_cls == "openai_format":
-        data = hf_datasets.OpenAIFormatDataset(
-            data_config["train_data_path"],
-            data_config["val_data_path"],
-            data_config["chat_key"],
-            data_config["system_key"],
-            data_config["system_prompt"],
-        )
-    else:
-        raise ValueError(f"Unknown dataset class: {data_cls}")
+
+    if model_type == "lm":
+        data_preprocessor = partial(
+            sft_preprocessor,
+            add_bos=data_config["add_bos"],
+            add_eos=data_config["add_eos"],
+            add_generation_prompt=data_config["add_generation_prompt"],
+        )
+
+        if data_cls == "open_assistant":
+            data = hf_datasets.OasstDataset(output_dir="/tmp/open_assistant")
+        elif data_cls == "squad":
+            data = hf_datasets.SquadDataset()
+        elif data_cls == "prompt_response_dataset":
+            data = hf_datasets.PromptResponseDataset(
+                data_config["train_data_path"],
+                data_config["val_data_path"],
+                data_config["input_key"],
+                data_config["output_key"],
+            )
+        elif data_cls == "openmathinstruct2":
+            data = hf_datasets.OpenMathInstruct2Dataset(
+                split=data_config["split"],
+                output_key=data_config["output_key"],
+                prompt_file=data_config["prompt_file"],
+            )
+        elif data_cls == "openai_format":
+            data = hf_datasets.OpenAIFormatDataset(
+                data_config["train_data_path"],
+                data_config["val_data_path"],
+                data_config["chat_key"],
+                data_config["system_key"],
+                data_config["system_prompt"],
+            )
+        else:
+            raise ValueError(
+                f"Unknown dataset class: {data_cls} for model_type: {model_type}"
+            )
+    elif model_type == "reward":
+        data_preprocessor = rm_preprocessor
+
+        if data_cls == "HelpSteer3":
+            data = hf_datasets.HelpSteer3Dataset()
+        else:
+            raise ValueError(
+                f"Unknown dataset class: {data_cls} for model_type: {model_type}"
+            )
+    else:
+        raise ValueError(f"Unknown model type: {model_type}")

     print(
         f" ✓ Training and validation datasets loaded with {len(data.formatted_ds['train'])} and {len(data.formatted_ds['validation'])} samples, respectively."
     )
@@ -131,25 +214,15 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
         train_dataset,
         tokenizer,
         sft_task_spec,
-        partial(
-            sft_preprocessor,
-            add_bos=data_config["add_bos"],
-            add_eos=data_config["add_eos"],
-            add_generation_prompt=data_config["add_generation_prompt"],
-        ),
+        data_preprocessor,
         max_seq_length=data_config["max_input_seq_length"],
     )

     val_dataset = AllTaskProcessedDataset(
         val_dataset,
         tokenizer,
         sft_task_spec,
-        partial(
-            sft_preprocessor,
-            add_bos=data_config.get("add_bos", True),
-            add_eos=data_config.get("add_eos", True),
-            add_generation_prompt=data_config["add_generation_prompt"],
-        ),
+        data_preprocessor,
         max_seq_length=data_config["max_input_seq_length"],
     )

@@ -178,6 +251,8 @@ def main():
     print("Final config:")
     pprint.pprint(config)

+    model_type = "reward" if "reward_model_type" in config["policy"] else "lm"
+
     config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"])
     print(f"📊 Using log directory: {config['logger']['log_dir']}")
     if config["checkpointing"]["enabled"]:
@@ -195,7 +270,7 @@ def main():
         dataset,
         val_dataset,
         sft_task_spec,
-    ) = setup_data(tokenizer, config["data"])
+    ) = setup_data(tokenizer, config["data"], model_type)

     (
         policy,