feat: adding support for Bradley-Terry reward model training #609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
**Status:** Merged

### Commits (28)
- `a38c104` adding support for Bradley-Terry reward model training (jveronvialard)
- `ede515b` Merge branch 'main' of github.com:NVIDIA-NeMo/RL into jveronvialard/b… (jveronvialard)
- `5b9e976` update docs (jveronvialard)
- `68e96ea` add separate run_rm.py and unit tests (jveronvialard)
- `21d67a0` fix small typos and nit changes (jveronvialard)
- `8a28af7` rewards tensor shape (jveronvialard)
- `e914087` Merge branch 'main' of github.com:NVIDIA-NeMo/RL into jveronvialard/b… (jveronvialard)
- `8fb280b` update config and skip is_tied_lm_head for RM (jveronvialard)
- `3e3b03a` use tokenizer.pad_token_id if model.config.pad_token_id is not defined (jveronvialard)
- `ed24aea` nit (jveronvialard)
- `af17314` update functional test and cicd (jveronvialard)
- `1034634` nit docs (jveronvialard)
- `8788ec2` Merge branch 'main' of github.com:NVIDIA-NeMo/RL into jveronvialard/b… (jveronvialard)
- `24807c3` split sft.py and rm.py (jveronvialard)
- `d3b6272` Merge branch 'main' of github.com:NVIDIA-NeMo/RL into jveronvialard/b… (jveronvialard)
- `0aaf296` pull from main (jveronvialard)
- `5b3f1ad` Update docs/guides/rm.md (odelalleau)
- `6534c7c` Remove the `-RAY_DEDUP_LOGS=0` examples in the README (odelalleau)
- `b79d0ee` Refactor RM config to include a dedicated `reward_model_cfg` section (odelalleau)
- `51cc9f8` Provide user-friendly error message regarding unsupported RMs in mcore (odelalleau)
- `597d5eb` Simplify code and guard against enabling sequence packing in RMs (odelalleau)
- `ba2e4b6` Fix likely crash with Reward Models introduced in previous commit (odelalleau)
- `4733717` Fix linting issues (odelalleau)
- `3297cd1` Fix a typing issue (odelalleau)
- `179767e` Quick fix to typing issue (with TODO item for better fix) (odelalleau)
- `a86b6c7` Merge branch 'main' of github.com:NVIDIA-NeMo/RL into jveronvialard/b… (jveronvialard)
- `615aa98` Update docs/guides/rm.md (jveronvialard)
- `7530c84` Minor lint fix (odelalleau)
### File: docs/guides/rm.md (new file, +21 lines)

# Reward Model Training in NeMo RL

This document explains how to train reward models (RMs) within NeMo RL. Currently, only Bradley-Terry reward models are supported.

## Launch a Training Job

The [examples/run_sft.py](../../examples/run_sft.py) script is used to train a Bradley-Terry reward model. It can be launched either locally or via Slurm. For details on how to set up Ray and launch a job with Slurm, refer to the [cluster documentation](../cluster.md).

Be sure to launch the job using `uv`:

```bash
uv run examples/run_sft.py --config <PATH TO YAML CONFIG> <OVERRIDES>
```
The YAML config must be specified. It uses the same base template as the SFT config but includes a new `reward_model_type` key that triggers reward model training. An example RM config file can be found at [examples/configs/rm.yaml](../../examples/configs/rm.yaml).
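Conceptually, a single entry point can dispatch on that key: when `reward_model_type` is absent, training proceeds as plain SFT, while `"bradley_terry"` selects reward model training. The sketch below is purely illustrative of this dispatch idea (the function name and return values are ours, not NeMo RL's actual code):

```python
def select_training_path(config: dict) -> str:
    """Illustrative dispatch on the `reward_model_type` config key.

    Mirrors the behavior described above; the real NeMo RL entry
    point may differ in naming and structure.
    """
    rm_type = config.get("policy", {}).get("reward_model_type")
    if rm_type is None:
        return "sft"  # no key present -> ordinary SFT training
    if rm_type == "bradley_terry":
        return "bradley_terry_reward_model"
    raise ValueError(f"Unsupported reward_model_type: {rm_type!r}")

print(select_training_path({"policy": {"reward_model_type": "bradley_terry"}}))
# -> bradley_terry_reward_model
```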
**Reminder**: Don't forget to set your `HF_HOME`, `WANDB_API_KEY`, and `HF_DATASETS_CACHE` (if needed). You'll also need to run `huggingface-cli login` for Llama models.

## Datasets

By default, NeMo RL supports the `HelpSteer3` dataset. This dataset is downloaded from Hugging Face and preprocessed on the fly, so there's no need to provide a path to any datasets on disk.
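HelpSteer3 yields preference pairs: a chosen and a rejected response per prompt. A Bradley-Terry reward model assigns each response a scalar reward and is trained to rank the chosen response above the rejected one via a pairwise logistic loss. A minimal pure-Python sketch of that objective (illustrative only; the function names are ours, not NeMo RL's implementation):

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def bradley_terry_loss(chosen_rewards, rejected_rewards):
    """Mean of -log sigmoid(r_chosen - r_rejected) over preference pairs."""
    losses = [
        -math.log(sigmoid(c - r))
        for c, r in zip(chosen_rewards, rejected_rewards)
    ]
    return sum(losses) / len(losses)

# With no margin between the two responses the loss is log(2) ~ 0.693;
# it shrinks as the model scores chosen responses higher than rejected ones.
print(bradley_terry_loss([0.0], [0.0]))
print(bradley_terry_loss([2.0], [-1.0]))
```

In practice the scalar rewards come from a value head on top of the language model; the loss above is what makes the margin between chosen and rejected scores grow during training.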
### File: examples/configs/rm.yaml (new file, +140 lines)

```yaml
# Bradley-Terry (BT) Reward Model Training Configuration
# (uses the same base template as the SFT config but includes a new `reward_model_type` key that triggers Reward Model training)
sft:
  ## total number of steps to train will equal
  ## min((max_num_epochs * len(train_dataloader)), max_num_steps)
  max_num_epochs: 1
  max_num_steps: -1  # by default, train for 1 epoch

  val_period: 16
  val_batches: -1
  val_global_batch_size: 32
  val_micro_batch_size: 1
  val_at_start: false
  seed: 42

checkpointing:
  enabled: true
  checkpoint_dir: "results/rm"
  metric_name: "val_loss"
  higher_is_better: false
  keep_top_k: 3
  save_period: ${sft.val_period}

policy:
  model_name: "meta-llama/Llama-3.2-1B-Instruct"
  tokenizer:
    name: ${policy.model_name}  ## specify if you'd like to use a tokenizer different from the model's default
    # We don't use the "default" chat template because the Llama tokenizer inserts the current
    # date in the system prompt, which could make the reward model's output date-dependent.
    chat_template: "{{- bos_token }}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = '' %}\n{%- endif %}\n\n{#- System message #}\n{{- '<|start_header_id|>system<|end_header_id|>\n\n' }}\n{{- system_message }}\n{{- '<|eot_id|>' }}\n\n{%- for message in messages %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}\n{%- endif %}"
  reward_model_type: "bradley_terry"
  train_global_batch_size: 128
  train_micro_batch_size: 1
  max_total_sequence_length: 8192
  precision: "bfloat16"
  fsdp_offload_enabled: false
  activation_checkpointing_enabled: false

  dtensor_cfg:
    enabled: true
    cpu_offload: false
    sequence_parallel: false
    activation_checkpointing: false
    tensor_parallel_size: 1
    context_parallel_size: 1
    custom_parallel_plan: null

  dynamic_batching:
    enabled: false

  # makes the training sequence length divisible by the tensor parallel size
  # this is useful for sequence parallel training
  make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size}
  max_grad_norm: 1.0

  optimizer:
    name: "torch.optim.AdamW"
    kwargs:
      lr: 2.0e-6
      weight_decay: 0.1
      betas: [0.9, 0.98]
      eps: 1e-5
      # when using DTensor, we need to set `foreach` and `fused` to false
      foreach: false
      fused: false

  ## ignored since enabled=false, but needed for testing purposes
  megatron_cfg:
    enabled: false
    empty_unused_memory_level: 1
    activation_checkpointing: false
    tensor_model_parallel_size: 2
    pipeline_model_parallel_size: 2
    context_parallel_size: 1
    pipeline_dtype: ${policy.precision}
    num_layers_in_first_pipeline_stage: null
    num_layers_in_last_pipeline_stage: null
    sequence_parallel: false

    optimizer:
      optimizer: "adam"
      lr: 2.0e-6
      min_lr: 1.9999e-6
      weight_decay: 0.1
      bf16: false
      fp16: false
      params_dtype: "float32"

      # adam
      adam_beta1: 0.9
      adam_beta2: 0.98
      adam_eps: 1e-5

      # sgd
      sgd_momentum: 0.9

      # distributed optimizer
      use_distributed_optimizer: true
      use_precision_aware_optimizer: true

      clip_grad: ${policy.max_grad_norm}

    scheduler:
      start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      weight_decay_incr_style: "constant"
      lr_decay_style: "constant"
      lr_decay_iters: null
      lr_warmup_iters: 50
      lr_warmup_init: 1.9999e-6

    distributed_data_parallel_config:
      grad_reduce_in_fp32: false
      overlap_grad_reduce: true
      overlap_param_gather: false
      average_in_collective: true
      data_parallel_sharding_strategy: "optim_grads_params"

data:
  max_input_seq_length: ${policy.max_total_sequence_length}
  dataset_name: "HelpSteer3"

logger:
  log_dir: "logs"  # Base directory for all logs
  wandb_enabled: true  # Make sure you do a `wandb login [Your API key]` before running
  tensorboard_enabled: true
  monitor_gpus: true  # If true, will monitor GPU usage and log to wandb and/or tensorboard
  wandb:
    project: "rm-dev"
    name: "rm-dev-${data.dataset_name}"
  tensorboard:
    log_dir: "tb_logs-rm-dev-${data.dataset_name}"
  gpu_monitoring:
    collection_interval: 10  # How often to collect GPU usage metrics (in seconds)
    flush_interval: 10  # How often to flush GPU usage metrics to the loggers (in seconds)

cluster:
  gpus_per_node: 1
  num_nodes: 1
```
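Values like `${sft.val_period}` and `${policy.model_name}` in the config above are OmegaConf-style interpolations: at load time each is replaced by the value at the referenced dotted path, so for example the checkpoint `save_period` always tracks the validation period. The toy resolver below (illustrative only, not OmegaConf itself) shows the mechanic:

```python
import re

def lookup(cfg: dict, dotted_path: str):
    """Walk a nested dict by a dotted path like 'sft.val_period'."""
    node = cfg
    for key in dotted_path.split("."):
        node = node[key]
    return node

def resolve(cfg: dict, value):
    """Replace a whole-string ${dotted.path} reference with its target value."""
    if isinstance(value, str):
        match = re.fullmatch(r"\$\{([\w.]+)\}", value)
        if match:
            return lookup(cfg, match.group(1))
    return value

cfg = {
    "sft": {"val_period": 16},
    "checkpointing": {"save_period": "${sft.val_period}"},
    "policy": {
        "model_name": "meta-llama/Llama-3.2-1B-Instruct",
        "tokenizer": {"name": "${policy.model_name}"},
    },
}
print(resolve(cfg, cfg["checkpointing"]["save_period"]))  # 16
print(resolve(cfg, cfg["policy"]["tokenizer"]["name"]))   # meta-llama/Llama-3.2-1B-Instruct
```

One consequence of this design: changing `sft.val_period` in a single place (or via a command-line override) automatically updates every value that interpolates it.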