# RLHF example

This example uses RLHF (Reinforcement Learning from Human Feedback) to train a
language model to summarize Reddit posts.

## Getting started

Make sure you have PyTorch>=2.0 installed. You can find installation instructions
[here](https://pytorch.org/get-started/locally/).

From this directory, you can install extra requirements for running these
examples with

```sh
pip install -r requirements.txt
```

## Training the models

### Training the transformer

Once the data has been prepared, you can train the GPT model.

```sh
python train.py
```

The default configuration can be found in `config/train.yaml`, and any option can
be overridden with command-line arguments, for example to run the training
script with a different batch size:

```sh
python train.py --batch_size=128
```

> **_NOTE:_** Apple Silicon MacBook users should make sure to use `--device=mps`
> and prepend all commands with `PYTORCH_ENABLE_MPS_FALLBACK=1` to enable CPU fallback.

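For example, on an Apple Silicon machine the training command above becomes:

```sh
# run on the MPS device, falling back to CPU for any unsupported ops
PYTORCH_ENABLE_MPS_FALLBACK=1 python train.py --device=mps
```
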
### Training the reward model

Once you have completed supervised fine-tuning, copy the desired model
checkpoint to `./out`, or update the config to point `model.name_or_path` at
the relevant checkpoint in the timestamped working directory created by Hydra.
You can then train the reward model with:

```sh
python train_reward.py
```
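For instance, if you take the copy route, the session might look roughly like this (the source path below is a placeholder for wherever your supervised fine-tuning run saved its checkpoint):

```sh
# placeholder path: substitute the checkpoint produced by your SFT run
cp -r outputs/<sft-run-timestamp>/checkpoint ./out
python train_reward.py
```
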

### Training the final model with RLHF

Once again, make sure you have either updated the configuration to point
`reward_model.name_or_path` at the relevant timestamped working directory, or
copied the checkpoint to `./out_reward`.
You can then train the final model by running

```sh
python train_rlhf.py
```
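Alternatively, if the reward checkpoint is still sitting in its timestamped working directory, an override along these lines may work, assuming `train_rlhf.py` accepts the same `--key=value` style of override as `train.py` above (the path is a placeholder):

```sh
# hypothetical path: replace with the working directory of your reward-model run
python train_rlhf.py --reward_model.name_or_path=outputs/<reward-run-timestamp>
```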