
Commit 2559924

committed
modify readme
1 parent 21d517d commit 2559924

File tree

2 files changed: +321 -1 lines changed

Lines changed: 311 additions & 1 deletion
@@ -1,6 +1,316 @@
# Distributed RL Framework for Language Model Fine-Tuning

This repository implements a distributed Reinforcement Learning (RL) training framework designed to fine-tune large language models using algorithms such as **GRPO** and **DAPO**. It supports multi-node and multi-GPU setups, scalable rollout generation, and policy optimization with inference libraries such as vLLM. We currently support two Reinforcement Learning with Verifiable Reward (RLVR) tasks: math problem solving and code generation.

**Please note that we are still under active development; stay tuned.**

---

## 🚀 Features

* **Distributed Training with Ray**: Scalable to multiple machines and GPUs.
* **Support for GRPO and DAPO**: Choose your preferred policy optimization algorithm.
* **Model Backends**: Supports `vllm` as the inference backend.
* **Rollout and Policy Decoupling**: Efficient generation and consumption of data through a parallel inferencer-trainer architecture.
* **Evaluation Integration**: Easily plug in task-specific eval datasets.
* **Checkpoints and Logging**: Configurable intervals and directories.

---

## 🛠 Installation

### Prepare the Development Environment

Install ColossalAI & ColossalChat:
```bash
git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI
git checkout grpo-latest
BUILD_EXT=1 pip install -e .

cd ./applications/ColossalChat
pip install -e .
```

Install vLLM:
```bash
pip install vllm==0.7.3
```

Install Ray:
```bash
pip install ray
```

Install other dependencies:
```bash
pip install cupy-cuda12x
python -m cupyx.tools.install_library --cuda 12.x --library nccl
```
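
Optionally, you can sanity-check the environment with a quick import test such as the sketch below; it is only a convenience snippet, so adjust the package list to what you actually installed.

```python
# Quick environment check: verify that the core dependencies can be imported
# and report their versions. Adjust or skip packages you did not install.
import importlib

for pkg in ("torch", "colossalai", "coati", "vllm", "ray", "cupy"):
    try:
        mod = importlib.import_module(pkg)
        print(f"{pkg}: {getattr(mod, '__version__', 'unknown version')}")
    except ImportError as err:
        print(f"{pkg}: NOT AVAILABLE ({err})")
```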

Prepare the model & dataset:
```bash
huggingface-cli download --local-dir-use-symlinks False Qwen/Qwen2.5-7B --local-dir /models/Qwen/Qwen2.5-7B
```

## Architecture Design

<div align="center">
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/producer-consumer-pattern.png" width=700/>
</p>
</div>

Producer-Consumer Pattern: a classic software design pattern used for managing resources, data, or tasks between two different processes or threads.

* Producer: the inference engine, which rolls out examples and saves them into a shared buffer.
* Consumer: the training framework, which takes training examples from the shared buffer and trains the policy model.

Key features of the Producer-Consumer Pattern (see the sketch below):
* Buffer: Acts as a shared queue where the producer adds data and the consumer removes data.
* Concurrency: Rollout and training can work concurrently.
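
To make the pattern concrete, below is a minimal, framework-agnostic sketch using Python threads and a shared queue; it only illustrates the idea and is not the Ray-based implementation used in this repository.

```python
# Minimal producer-consumer illustration (not the actual Ray-based implementation).
# The producer simulates rollout generation; the consumer simulates a training step.
import queue
import threading
import time

buffer: "queue.Queue[dict]" = queue.Queue(maxsize=8)  # shared buffer between the two roles

def producer(num_rollouts: int) -> None:
    for i in range(num_rollouts):
        rollout = {"prompt_id": i, "response": f"rollout-{i}"}  # placeholder rollout
        buffer.put(rollout)  # blocks when the buffer is full
        print(f"[producer] generated rollout {i}")

def consumer(num_rollouts: int) -> None:
    for _ in range(num_rollouts):
        rollout = buffer.get()  # blocks until data is available
        time.sleep(0.01)  # stand-in for a policy update step
        print(f"[consumer] trained on rollout {rollout['prompt_id']}")
        buffer.task_done()

n = 16
threads = [threading.Thread(target=producer, args=(n,)), threading.Thread(target=consumer, args=(n,))]
for t in threads:
    t.start()
for t in threads:
    t.join()
```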

## 🧠 Data Format

Samples in the training or evaluation `.jsonl` file share a common format that depends on the type of task. We currently support two RLVR tasks: math problem solving and code generation.

### Math Data Format
```json
{
    "messages": {
        "role": "user",
        "content": "Simplify $\\sqrt[3]{1+8} \\cdot \\sqrt[3]{1+\\sqrt[3]{8}}$."
    },
    "gt_answer": "3"
}
```

### Code Data Format
We support the [Prime code dataset format](https://github.com/PRIME-RL/PRIME). Inputs and outputs in the test cases should be two lists that contain only strings and match in the number of elements. Your prompt must properly instruct the LLM to generate code that reads test cases from stdin and writes results to stdout.
```json
{
    "messages": {
        "role": "user",
        "content": "Solve the following coding problem using the programming language python:\n\nMikhail walks on a Cartesian plane. He starts at the point $(0, 0)$, and in one move he can go to any of eight adjacent points. For example, ..."
    },
    "test_cases": {
        "inputs": [
            "3\n2 2 3\n4 3 7\n10 1 9\n"
        ],
        "outputs": [
            "1\n6\n-1\n"
        ]
    }
}
```
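
For reference, a tiny script like the following (a hypothetical helper, not shipped with this repository) can be used to write samples in the math format to a `.jsonl` file; the code format works the same way, one JSON object per line.

```python
# Hypothetical helper: write math-format samples to a .jsonl file
# (one JSON object per line), matching the format documented above.
import json

samples = [
    {
        "messages": {
            "role": "user",
            "content": "Simplify $\\sqrt[3]{1+8} \\cdot \\sqrt[3]{1+\\sqrt[3]{8}}$.",
        },
        "gt_answer": "3",
    },
]

with open("train_data.jsonl", "w", encoding="utf-8") as f:
    for sample in samples:
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")
```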

---

## ⚙️ Hyperparameters & Arguments

| Argument | Description | Example |
| ---------------- | --------------------------------------- | ----------------- |
| `--model` | Model path or identifier | `/path/to/model` |
| `--dataset` | Path to training `.jsonl` | `/path/to/train_data.jsonl` |
| `--eval-dataset` | JSON of task\:eval\_dataset\_path pairs | `{"eval_1":"/path/to/eval_1.jsonl"}` |
| `--project` | Project name | `Project1` |
| `--num-episodes` | Number of training episodes | `1` |

### Distributed Training

| Argument | Description | Example |
| ----------------------------- | ------------------------------------- | ------- |
| `--num-trainers` | Number of trainer processes | `4` |
| `--num-inferencer` | Number of inferencer processes | `4` |
| `--inference-batch-size` | Prompts per inference step | `8` |
| `--inference-microbatch-size` | Per-GPU batch size for inference | `8` |
| `--train-batch-size` | Prompts per trainer step per dp group | `8` |
| `--train-minibatch-size` | Mini-batch size before forward pass | `8` |
| `--train-microbatch-size` | Per-GPU batch size for training | `2` |

### Sampling

| Argument | Description | Example |
| --------------------- | --------------------- | -------------- |
| `--backend` | Generation backend; currently only `vllm` is supported | `vllm` |
| `--temperature` | Sampling temperature for generation | `1.0` |
| `--top-k` | Top-K sampling parameter for generation | `None` |
| `--top-p` | Top-P sampling parameter for generation | `1.0` |
| `--system-prompt` | System prompt; optional, defaults to the default system prompt of the chosen reward type. For more information, refer to the [**reward type**](#-constraints-and-notes) section | `Please reason step by step, and put your final answer within \\boxed{}.` |
| `--max-new-tokens` | Max generation tokens | `3584` |
| `--max-prompt-tokens` | Max prompt tokens | `512` |

### GRPO Specific

| Argument | Description | Example |
| ----------------- | ---------------------------- | ------------------- |
| `--algo` | Algorithm (`GRPO` or `DAPO`); for more customization, refer to [GRPO Settings](#️-grpo-settings) | `GRPO` |
| `--learning-rate` | Learning rate | `1e-6` |
| `--kl-coeff` | KL penalty coefficient; if nonzero, a reference model is used | `0.01` |
| `--reward-type` | Reward signal type (choose from `think_answer_tags`, `boxed`, `code`). For more information, refer to the [**reward type**](#-constraints-and-notes) section | `think_answer_tags` |
| `--eval-interval` | Evaluation interval in number of training steps (positive value to enable evaluation) | `10` |

### Logging and Checkpointing

| Argument | Description | Example |
| -------------------- | ------------------------- | ------------ |
| `--save-interval` | Training steps between checkpoints | `20` |
| `--save-dir` | Checkpoint directory | `./model` |
| `--eval-save-dir` | Evaluation save path | `./eval` |
| `--rollout-save-dir` | Rollout logs directory | `./rollouts` |

### Miscellaneous

| Argument | Description | Example |
| ------------------ | --------------------------------------- | ------- |
| `--ray_dir` | Custom Ray temp dir of a running Ray cluster (optional) | `None` |
| `--master_address` | Master address of a running Ray cluster | `None` |
| `--master_port` | Master port for torch DDP | `29506` |

---

## ⚙️ GRPO Settings

In addition to the two default training settings we provide (original `GRPO` and `DAPO`), users can customize their training by changing the following hyperparameters in `grpo_config` in `rl_example.py`, as in the sketch after the table.

| Argument Name | Description | Default |
| ----------------------------- | ---------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
| `filter_range` | Filters out a rollout group if the success rate within that group is out of this range. | `[0.01, 0.99]` |
| `dynamic_batching` | Enables dynamic batching as described in the [DAPO paper](https://arxiv.org/abs/2503.14476). | `True` |
| `clip_eps_low` | epsilon_low in the clipping objective of the [DAPO paper](https://arxiv.org/abs/2503.14476) | `0.2` |
| `clip_eps_high` | epsilon_high in the clipping objective of the [DAPO paper](https://arxiv.org/abs/2503.14476) | `0.28` |
| `skip_threshold` | If the ratio is above this threshold, the sample is skipped to avoid instability. | `20.0` |
| `loss_variation` | Type of loss variation. Supports `"token_level"` for token-wise policy gradient loss and `"sample_level"` for the original GRPO loss. | `"token_level"` |
| `soft_over_length_punishment` | Whether to use the soft overlength penalty from the [DAPO paper](https://arxiv.org/abs/2503.14476). | `True` |
| `cache_length` | `L_cache` parameter for the soft overlength penalty in Eq. 13 of the [DAPO paper](https://arxiv.org/abs/2503.14476) | `min(1024, int(args.max_new_tokens / 4))` |
| `filter_truncated_response` | Mask out truncated responses in loss calculation. | `True` |
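
For example, a customized `grpo_config` might look like the sketch below; it only uses the hyperparameters documented above, and the exact surrounding code in `rl_example.py` (e.g., how `args` is parsed) may differ between versions.

```python
# Illustrative grpo_config override using the hyperparameters documented above.
# This snippet is meant to live inside rl_example.py, where `args` has already
# been parsed; adapt it to the structure of your copy of the script.
grpo_config = {
    "filter_range": [0.01, 0.99],          # drop groups whose success rate is outside this range
    "dynamic_batching": True,              # DAPO-style dynamic batching
    "clip_eps_low": 0.2,                   # epsilon_low in the DAPO clipping objective
    "clip_eps_high": 0.28,                 # epsilon_high in the DAPO clipping objective
    "skip_threshold": 20.0,                # skip samples whose ratio exceeds this value
    "loss_variation": "token_level",       # or "sample_level" for the original GRPO loss
    "soft_over_length_punishment": True,   # soft overlength penalty (DAPO Eq. 13)
    "cache_length": min(1024, int(args.max_new_tokens / 4)),
    "filter_truncated_response": True,     # mask truncated responses in the loss
}
```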
## 🔄 Constraints and Notes

* `num_inferencer + num_trainer == NUM_GPUs`
* `num_inferencer % num_trainer == 0`
* `(num_inferencer * inference_batch_size) % (num_trainer * train_batch_size) == 0`
* `train_batch_size >= train_minibatch_size >= train_microbatch_size`
* `inference_batch_size >= inference_microbatch_size`
* Set microbatch sizes based on **VRAM capacity**
* To use tensor parallelism on the inferencer:
  * set the backend to `vllm`
  * change `tensor_parallel_size` in `inference_model_config` in rl_example.py
  * set `num_inferencer = NUM_INFERENCE_GPUs / tensor_parallel_size`
* To set tensor parallelism / pipeline parallelism / zero stage:
  * change the corresponding settings in `plugin_config` in rl_example.py
* Ensure the rollout generation rate matches trainer consumption:

  ```
  num_inferencer * inference_batch_size % (
      num_trainer * train_batch_size /
      train_pipeline_parallelism_size /
      train_tensor_parallelism_size
  ) == 0
  ```
* Model weights sync every:

  ```
  (num_inferencer * inference_batch_size) /
  (num_trainer * train_batch_size /
      train_pipeline_parallelism_size /
      train_tensor_parallelism_size)
  ```
* Reward Type

  We currently support three reward types: `think_answer_tags`, `boxed`, and `code`. They differ in details such as how the answer is extracted and how the reward is calculated. Please choose `think_answer_tags` or `boxed` for math problem solving and `code` for code generation. The default system prompt for each reward type is listed below. Please make sure your system prompt provides enough information for the answer to be correctly extracted from model responses (an illustrative extraction sketch for the math reward types follows this list).

  * think_answer_tags

    Answer extraction: extract the content between the last `<answer>`, `</answer>` tags.

    ```
    You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n
    ```
  * boxed

    Answer extraction: extract the last content marked by `\\boxed{}`

    ```
    Please reason step by step, and put your final answer within \\boxed{}.
    ```
  * code

    Answer extraction: extract the code inside ` ```python\n...``` `; the generated code is then run against the provided test cases to compute the reward.

    ```
    You are a helpful assistant.
    ```
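
For illustration, answer extraction for the two math reward types could be sketched as follows; this is not the repository's reward implementation, just a minimal approximation of the extraction rules described above.

```python
# Illustrative answer extraction for the math reward types; the repository's
# actual reward implementation may differ in details.
import re
from typing import Optional

def extract_boxed(response: str) -> Optional[str]:
    """Return the content of the last \\boxed{...} (simple, non-nested braces only)."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", response)
    return matches[-1] if matches else None

def extract_answer_tags(response: str) -> Optional[str]:
    """Return the content between the last <answer> ... </answer> pair."""
    matches = re.findall(r"<answer>(.*?)</answer>", response, flags=re.DOTALL)
    return matches[-1].strip() if matches else None

print(extract_boxed("The product is $\\boxed{3}$."))                              # -> 3
print(extract_answer_tags("<think>1+8=9 ...</think><answer> 3 </answer>"))        # -> 3
```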
---

## 🧪 Example: Single-Machine 8-GPU Zero2 Strategy

```bash
python rl_example.py \
    --dataset /path/to/train_data.jsonl \
    --model /path/to/Qwen2.5-3B/ \
    -t 4 -i 4 \
    -b vllm \
    -rt boxed \
    -g 4 \
    -ibs 1 \
    -tbs 2 \
    -tMbs 8 \
    -tmbs 2 \
    -imbs 1 \
    -s "Please reason step by step, and put your final answer within \\boxed{}." \
    -p GRPO-Train-Align-Debug
```
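
Before launching, it can be useful to check that your chosen batch sizes satisfy the constraints from the [Constraints and Notes](#-constraints-and-notes) section; the following is a hypothetical sanity-check sketch with placeholder values.

```python
# Hypothetical sanity check for the batch-size constraints listed in
# "Constraints and Notes"; plug in your own values before launching.
num_gpus = 8
num_trainer = 4
num_inferencer = 4
inference_batch_size = 8
inference_microbatch_size = 1
train_batch_size = 8
train_minibatch_size = 4
train_microbatch_size = 2

assert num_inferencer + num_trainer == num_gpus
assert num_inferencer % num_trainer == 0
assert (num_inferencer * inference_batch_size) % (num_trainer * train_batch_size) == 0
assert train_batch_size >= train_minibatch_size >= train_microbatch_size
assert inference_batch_size >= inference_microbatch_size
print("Batch-size configuration looks consistent.")
```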

## 🧪 Example: Multi-Machine TP+PP Strategy

### Create a Ray cluster across multiple machines
For example, suppose we have 4 nodes with IPs 10.0.0.3, 10.0.0.4, 10.0.0.5, and 10.0.0.6, and we use 10.0.0.3 as the master node. First we start a Ray cluster on 10.0.0.3:
```bash
ray start --head --node-ip-address=10.0.0.3
```

Then, on each worker node (10.0.0.4/10.0.0.5/10.0.0.6), we join the Ray cluster with the following command:
```bash
ray start --address='10.0.0.3:6379'
```
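
If everything worked, running `ray status` on the master node should show all four nodes registered in the cluster.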

Modify `plugin_config` in ./applications/ColossalChat/rl_example.py:
```python
plugin_config={
    "tp_size": 4,
    "pp_size": 2,
    "microbatch_size": max(
        1, args.train_microbatch_size // 2
    ),  # microbatch size should be set to train_microbatch_size // pp_size
    "zero_stage": 1,
    "max_norm": 1.0,
},  # for pp, tp
```

```bash
# Hint 1: replace /path/to/Qwen2.5-Math-7B/ with your model path
# Hint 2: replace /path/to/train_data.jsonl with your dataset path
python rl_example.py \
    -m /path/to/Qwen2.5-Math-7B/ \
    -d /path/to/train_data.jsonl \
    --master_address '10.0.0.3' \
    -t 16 \
    -i 16 \
    -p GRPO-Train-Align-Debug \
    -g 2 \
    -ibs 1 \
    -tbs 2 \
    -tMbs 1 \
    -tmbs 2 \
    -imbs 1 \
    -b vllm \
    -e 2 \
    -rt boxed \
    -s "Please reason step by step, and put your final answer within \\boxed{}."
```

## Acknowledgement
Colossal-RL is a distributed version of ColossalChat and is inspired by several awesome open-source projects. We would like to express our gratitude to the Fuyao-ray team and the vllm-ascend team for their support throughout the development of this project. We also thank the following awesome open-source projects and algorithms: GRPO, DAPO, TRL, Verl, OpenRLHF, StreamRL, Qwen, Logic-RL.

applications/ColossalChat/rl_example.py

Lines changed: 10 additions & 0 deletions
@@ -6,6 +6,12 @@
import torch
from coati.distributed.launch import launch_distributed

DEFAUT_SYSTEM_PROMPT = {
    "think_answer_tags": "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n",
    "boxed": "Please reason step by step, and put your final answer within \\boxed{}.",
    "code": "You are a helpful assistant.",
}

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
@@ -295,6 +301,10 @@
    else:
        raise ValueError(f"Unsupported algorithm: {args.algo}")

    if args.system_prompt is None:
        # Default system prompt
        args.system_prompt = DEFAUT_SYSTEM_PROMPT[args.reward_type]

    launch_distributed(
        num_producers=args.num_inferencer,
        num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size),
