
Commit fe19cf5

Vincent Moens, apbard, and tcbegley authored

[Algorithm] RLHF end-to-end, clean (#1597)

Co-authored-by: Alessandro Pietro Bardelli <[email protected]>
Co-authored-by: Tom Begley <[email protected]>
1 parent f09b0c8 commit fe19cf5

26 files changed: +1402 −38 lines changed

.github/unittest/linux_examples/scripts/run_test.sh

Lines changed: 3 additions & 1 deletion
@@ -282,8 +282,10 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/sac
   train.minibatch_size=100 \
   logger.backend=

-
 python .github/unittest/helpers/coverage_run_parallel.py examples/bandits/dqn.py --n_steps=100

+## RLHF
+# RLHF tests are executed in the dedicated workflow
+
 coverage combine
 coverage xml -i

.github/unittest/linux_libs/scripts_rlhf/run_test.sh

Lines changed: 9 additions & 0 deletions
@@ -22,5 +22,14 @@ conda deactivate && conda activate ./env
 python -c "import transformers, datasets"

 python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_rlhf.py --instafail -v --durations 200 --capture no --error-for-skips
+
+python .github/unittest/helpers/coverage_run_parallel.py examples/rlhf/train_rlhf.py \
+  sys.device=cuda:0 sys.ref_device=cuda:0 \
+  model.name_or_path=gpt2 train.max_epochs=2 \
+  data.batch_size=2 train.ppo.ppo_batch_size=2 \
+  train.ppo.ppo_num_epochs=1 reward_model.name_or_path= \
+  train.ppo.episode_length=8 train.ppo.num_rollouts_per_epoch=4 \
+  data.block_size=110 io.logger=csv
+
 coverage combine
 coverage xml -i

examples/rlhf/.gitignore

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
*.png
*.bin
*.pt
*.json

examples/rlhf/README.md

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
# RLHF example

This example uses RLHF (Reinforcement Learning from Human Feedback) to train a
language model to summarize Reddit posts.

## Getting started

Make sure you have PyTorch>=2.0 installed. You can find installation instructions
[here](https://pytorch.org/get-started/locally/).

From this directory, you can install extra requirements for running these
examples with

```sh
pip install -r requirements.txt
```

## Training the models

### Training the transformer

Once the data has been prepared, you can train the GPT model.

```sh
python train.py
```

The default configuration can be found in `config/train.yaml`, and any option can
be overridden with command-line arguments, for example to run the training
script with a different batch size:
```sh
python train.py data.batch_size=128
```

> **_NOTE:_** Users of Apple Silicon MacBooks should make sure to use `sys.device=mps`
> and prepend all commands with `PYTORCH_ENABLE_MPS_FALLBACK=1` to enable CPU fallback.

### Training the reward model

Once you have completed supervised fine-tuning, copy the desired model
checkpoint to `./out`, or update the config to point `model.name_or_path` at
the relevant checkpoint in the timestamped working directory created by Hydra.
You can then train the reward model with:

```sh
python train_reward.py
```

### Training the final model with RLHF

Once again, make sure you have either updated the configuration to point
`reward_model.name_or_path` at the relevant timestamped working directory, or
copied the checkpoint to `./out_reward`.
You can then train the final model by running

```sh
python train_rlhf.py
```

examples/rlhf/config/train.yaml

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
io:
  eval_interval: 200
  log_interval: 50
  eval_iters: 100
data:
  batch_size: 16  # if gradient_accumulation_steps > 1, this is the micro-batch size
  block_size: 550
model:
  name_or_path: gpt2  # gpt2 for pre-trained, local path for checkpoint
  out_dir: ./out
  dropout: 0.1  # for pretraining 0 is good, for finetuning try 0.1+
train:
  grad_clip: 1.0  # clip gradients at this value, or disable if == 0.0
  max_iters: 5000  # total number of training iterations
  gradient_accumulation_steps: 2  # used to simulate larger batch sizes
  always_save_checkpoint: False  # if True, always save a checkpoint after each evaluation in out_dir
  decay_lr: True  # whether to decay the learning rate
  optimizer:
    # keyword arguments for torch.optim.AdamW
    lr: 1.0e-5
    weight_decay: 1.0e-1
    betas: [0.9, 0.95]
  scheduler:
    # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR
    T_max: 5000  # maximum number of iterations
    eta_min: 1.0e-6  # minimum learning rate
sys:
  device: cuda  # examples: cpu, cuda, cuda:0, cuda:1 etc., or try mps on macbooks
  dtype: bfloat16  # float32, bfloat16, or float16; the latter will automatically use a GradScaler
  compile: True  # use PyTorch 2.0 to compile the model to be faster
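For orientation, the `optimizer` and `scheduler` sections are documented as keyword arguments for `torch.optim.AdamW` and `torch.optim.lr_scheduler.CosineAnnealingLR`. The snippet below is a minimal sketch of how such a config could be wired up, assuming the optimizer/scheduler blocks sit under `train` as rendered above and that the config is loaded with OmegaConf; it is not the example's actual training code, and the stand-in model is hypothetical.

```python
# Hypothetical sketch (not train.py): build the optimizer and scheduler from the
# train.optimizer / train.scheduler sections of config/train.yaml.
import torch
from omegaconf import OmegaConf

cfg = OmegaConf.load("config/train.yaml")  # assumes we run from examples/rlhf
model = torch.nn.Linear(8, 8)              # stand-in for the GPT model

optimizer = torch.optim.AdamW(
    model.parameters(), **OmegaConf.to_container(cfg.train.optimizer)
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, **OmegaConf.to_container(cfg.train.scheduler)
)
```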
examples/rlhf/config/train_reward.yaml

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
io:
  eval_interval: 200
  log_interval: 50
  eval_iters: 100
data:
  batch_size: 16  # if gradient_accumulation_steps > 1, this is the micro-batch size
  block_size: 550
model:
  name_or_path: ./out
  dropout: 0.1  # for pretraining 0 is good, for finetuning try 0.1+
reward_model:
  out_dir: ./out_reward
  init_from: scratch  # 'scratch' or 'resume' - if "resume", the model will be loaded from out_dir_reward
train:
  grad_clip: 1.0  # clip gradients at this value, or disable if == 0.0
  max_iters: 20000  # total number of training iterations
  gradient_accumulation_steps: 2  # used to simulate larger batch sizes
  always_save_checkpoint: False  # if True, always save a checkpoint after each eval
  decay_lr: False  # whether to decay the learning rate
  optimizer:
    # keyword arguments for torch.optim.AdamW
    lr: 1.0e-5
    weight_decay: 1.0e-1
    betas: [0.9, 0.95]
  scheduler:
    # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR
    T_max: 20000
    eta_min: 1.0e-6
sys:
  device: cuda  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
  dtype: bfloat16  # 'float32', 'bfloat16', or 'float16'; the latter will automatically use a GradScaler
  compile: True  # use PyTorch 2.0 to compile the model to be faster
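The `gradient_accumulation_steps` comment says it is used to simulate larger batch sizes. The snippet below is a generic illustration of that pattern (a sketch under the assumption of a plain supervised loop, not code taken from the example's training scripts): with `batch_size: 16` and `gradient_accumulation_steps: 2`, the effective batch size is 32.

```python
# Generic gradient-accumulation sketch for the comment above; not from train_reward.py.
import torch

grad_accum_steps = 2
model = torch.nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0e-5)

optimizer.zero_grad(set_to_none=True)
for step in range(100):
    x, y = torch.randn(16, 16), torch.randn(16, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    # Scale the loss so the accumulated gradient averages over micro-batches.
    (loss / grad_accum_steps).backward()
    if (step + 1) % grad_accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # cf. train.grad_clip
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
```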
examples/rlhf/config/train_rlhf.yaml

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
io:
  eval_interval: 6
  log_interval: 1
  eval_iters: 10
  logger: wandb
data:
  batch_size: 4  # if gradient_accumulation_steps > 1, this is the micro-batch size
  block_size: 550
  num_workers: 1
model:
  name_or_path: ./out
  out_dir: ./out_rlhf
  dropout: 0.1  # for pretraining 0 is good, for finetuning try 0.1+
reward_model:
  name_or_path: ./out_reward
train:
  grad_clip: 1.0
  max_epochs: 1000  # total number of training iterations
  always_save_checkpoint: True  # if True, always save a checkpoint after each eval
  decay_lr: True
  optimizer:
    # keyword arguments for torch.optim.AdamW
    lr: 5.0e-5
    weight_decay: 0.0  # 01
    betas: [0.9, 0.999]
  scheduler:
    # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR
    T_max: 3000  # max_epochs * num_rollouts / ppo_batch_size
    eta_min: 5.0e-6
  ppo:
    episode_length: 50
    ppo_batch_size: 16
    ppo_num_epochs: 3
    num_rollouts_per_epoch: 32
sys:
  device: cuda  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
  ref_device: cuda:1  # device of reference model
  dtype: bfloat16  # 'float32', 'bfloat16', or 'float16'; the latter will automatically use a GradScaler
  compile: False  # use PyTorch 2.0 to compile the model to be faster
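The `sys.dtype` comment notes that `float16` implies a GradScaler. The snippet below is a generic sketch of the standard PyTorch mixed-precision pattern that comment refers to, assuming a CUDA device as in `sys.device`; it is not code from `train_rlhf.py`, and the toy model and data are placeholders.

```python
# Generic mixed-precision sketch illustrating the sys.dtype comment above; not from train_rlhf.py.
# float16 needs loss scaling (GradScaler enabled); bfloat16 and float32 do not.
import torch

dtype = torch.bfloat16  # from sys.dtype: torch.float32, torch.bfloat16, or torch.float16
scaler = torch.cuda.amp.GradScaler(enabled=dtype == torch.float16)

model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=5.0e-5)
x, y = torch.randn(4, 16, device="cuda"), torch.randn(4, 1, device="cuda")

with torch.autocast(device_type="cuda", dtype=dtype, enabled=dtype != torch.float32):
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()   # no-op scaling when the scaler is disabled
scaler.unscale_(optimizer)      # unscale before clipping (no-op when disabled)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # cf. train.grad_clip
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
```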

examples/rlhf/data/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
from torchrl.data.rlhf.prompt import get_prompt_dataloader_tldr

__all__ = ["get_prompt_dataloader_tldr"]

examples/rlhf/models/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
examples/rlhf/models/actor_critic.py

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torchrl.modules.tensordict_module.actors import LMHeadActorValueOperator
from torchrl.modules.tensordict_module.common import VmapModule

from .transformer import init_transformer

__all__ = ["init_actor_critic"]


def init_actor_critic(model_cfg, sys_cfg):

    transformer_name_or_path = model_cfg.name_or_path
    dropout = model_cfg.dropout

    device = sys_cfg.device
    compile_model = sys_cfg.compile
    base_model = init_transformer(
        transformer_name_or_path,
        dropout,
        device,
        as_tensordictmodule=False,
        compile_model=compile_model,
        inference=True,
    )
    model = LMHeadActorValueOperator(base_model)
    model.to(device)
    model.eval()
    actor = model.get_policy_operator()
    critic = model.get_value_operator()
    critic_head = model.get_value_head()

    return actor, VmapModule(critic), critic_head, base_model
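A brief usage sketch of this helper follows. The real call site lives in `train_rlhf.py`, which is not shown in this diff; the module path and the use of OmegaConf to load the config are assumptions for illustration.

```python
# Hypothetical usage sketch for init_actor_critic; not the example's actual entry point.
from omegaconf import OmegaConf

from models.actor_critic import init_actor_critic  # assumed module path

cfg = OmegaConf.load("config/train_rlhf.yaml")
actor, critic, critic_head, base_model = init_actor_critic(cfg.model, cfg.sys)

# `actor` is the policy operator over the language-model head, `critic` is the
# value operator wrapped in VmapModule, and both share the transformer trunk
# held in `base_model`.
```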
