[rebase]rebase grpo-latest #6354

Open · wants to merge 44 commits into grpo-latest-ascend
Commits (44)
55eee12  move prompt-level-filtering to buffer side (YeAnbang, May 15, 2025)
a528921  move prompt-level-filtering to buffer side (YeAnbang, May 15, 2025)
11a5854  remove redundant code and fix bugs (YeAnbang, May 16, 2025)
116621d  merge reward and eval (YeAnbang, May 19, 2025)
3766338  fix metric calculation (YeAnbang, May 20, 2025)
88e3b09  merge grpo-latest (YeAnbang, May 20, 2025)
78a06f5  fix missing tags parameter (YeAnbang, May 21, 2025)
4c36568  address conversation (YeAnbang, May 28, 2025)
58f8c9b  Merge branch 'grpo-latest' of https://github.com/hpcaitech/ColossalAI… (YeAnbang, May 28, 2025)
c8b368c  add overlength sample count (#6332) (TongLi3701, May 28, 2025)
ee939d9  address conversation (YeAnbang, May 29, 2025)
7b921ac  merge grpo-latest (YeAnbang, May 29, 2025)
0d00811  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 29, 2025)
96faf54  fix typ and parameter description (YeAnbang, Jun 5, 2025)
ceb7065  Merge pull request #6312 from hpcaitech/grpo-latest-dev (YeAnbang, Jun 5, 2025)
dc3033e  support code generation tasks (YeAnbang, Jun 5, 2025)
3bed6ae  fix bug, tested (YeAnbang, Jun 9, 2025)
d0e12c5  remove debug code (YeAnbang, Jun 9, 2025)
9ca920c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 9, 2025)
c308b42  Merge pull request #6341 from hpcaitech/grpo-code (YeAnbang, Jun 9, 2025)
bb6f5d9  move out evaluation func (#6343) (TongLi3701, Jun 10, 2025)
21d517d  Manually schedule resources and support auto master address assigning (YeAnbang, Jun 10, 2025)
2559924  modify readme (YeAnbang, Jun 10, 2025)
dc29c74  update readme (YeAnbang, Jun 10, 2025)
1330b57  add ray timeout handling instruction (YeAnbang, Jun 10, 2025)
8992def  fix pp memory issue (#6344) (TongLi3701, Jun 11, 2025)
2f02a28  Update README.md (YeAnbang, Jun 12, 2025)
ac06935  Merge pull request #6316 from hpcaitech/grpo-support-multi-machine (YeAnbang, Jun 12, 2025)
51b7abe  fix num_update_per_episode (YeAnbang, Jun 12, 2025)
0e69b98  Merge pull request #6347 from hpcaitech/hotfix/fix_num_update_per_epi… (YeAnbang, Jun 12, 2025)
30a6859  optimize pp log_softmax OOM (YeAnbang, Jun 13, 2025)
e3d56cb  implement memory efficient logprob (YeAnbang, Jun 18, 2025)
6b06430  fix small bug (YeAnbang, Jun 19, 2025)
dd49444  Merge pull request #6348 from hpcaitech/grpo_optimization (YeAnbang, Jun 19, 2025)
8880b83  add dp rank for multi-dp (#6351) (TongLi3701, Jun 19, 2025)
d4ef7f5  [ColossalRL] Support ColossalRL on Ascend (#6324) (duanjunwen, May 28, 2025)
2dd59c0  [Ascend] Update README (#6331) (TongLi3701, May 28, 2025)
2783dde  update to conform to json format (May 28, 2025)
1304be4  [Hotfix] fix requirsments (#6338) (duanjunwen, Jun 4, 2025)
be5acb0  [feat][npu] Merge form grpo-latest (#6346) (xysheng-colossal, Jun 23, 2025)
9d7544a  Merge branch 'grpo-latest-ascend' into merge (flybird11111, Jun 23, 2025)
7891cc7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 23, 2025)
bd8b571  fix pp (flybird11111, Jun 25, 2025)
a14bd0b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 25, 2025)
6 changes: 6 additions & 0 deletions .gitignore
@@ -167,3 +167,9 @@ applications/ColossalChat/wandb
applications/ColossalChat/model
applications/ColossalChat/eval
applications/ColossalChat/rollouts
applications/ColossalChat/*.txt
applications/ColossalChat/*.db
applications/ColossalChat/stdin
applications/ColossalChat/*.zip
applications/ColossalChat/*.prof
applications/ColossalChat/*.png
35 changes: 27 additions & 8 deletions applications/ColossalChat/coati/dataset/loader.py
@@ -367,9 +367,9 @@ def apply_chat_template_and_mask(
}

# Format for RL.
gt_answer = None
if "messages" in chat and "gt_answer" in chat:
gt_answer = chat["gt_answer"]
if "messages" in chat:
gt_answer = chat.get("gt_answer", None)
test_cases = chat.get("test_cases", None)
chat = [chat["messages"]]

tokens = []
@@ -402,12 +402,14 @@ def apply_chat_template_and_mask(
labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx

if gt_answer is not None:
gt_answer = tokenizer.encode(
gt_answer, padding="max_length", truncation=True, max_length=128, return_tensors="pt"
)
gt_answer = gt_answer.squeeze(1)
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}

elif test_cases is not None:
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"test_cases": test_cases,
}
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
@@ -440,3 +442,20 @@ def __getitem__(self, index: int):
tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
self.tokenized_texts[index] = dict(tokens)
return self.tokenized_texts[index]


def collate_fn_grpo(batch):
input_ids = [item["input_ids"] for item in batch]
attention_mask = [item["attention_mask"] for item in batch]
labels = [item["labels"] for item in batch]
# Assume input_ids, attention_mask, labels are already of the same length,
# otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
input_ids = torch.stack(input_ids)
attention_mask = torch.stack(attention_mask)
labels = torch.stack(labels)
ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
if "test_cases" in batch[0]:
ret["test_cases"] = [item["test_cases"] for item in batch]
if "gt_answer" in batch[0]:
ret["gt_answer"] = [item["gt_answer"] for item in batch]
return ret
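
A minimal usage sketch of the new `collate_fn_grpo` (the `dataset` below is hypothetical): the fixed-length tensors are stacked, while the optional `gt_answer`/`test_cases` fields stay as plain Python lists rather than being padded into tensors.

```python
# Sketch: wiring collate_fn_grpo into a DataLoader. `dataset` is a hypothetical
# dataset whose items come from apply_chat_template_and_mask, each holding
# fixed-length "input_ids"/"attention_mask"/"labels" tensors plus, optionally,
# a "gt_answer" or "test_cases" entry.
from torch.utils.data import DataLoader

from coati.dataset.loader import collate_fn_grpo

loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn_grpo)
batch = next(iter(loader))
# batch["input_ids"]: LongTensor of shape [4, max_length]
# batch.get("test_cases") / batch.get("gt_answer"): lists of length 4, if present
```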
93 changes: 93 additions & 0 deletions applications/ColossalChat/coati/distributed/README.md
@@ -86,6 +86,50 @@ Then write IP node map to /etc/hosts

Set Ascend Multi-Node Config


This repository implements a distributed Reinforcement Learning (RL) training framework designed to fine-tune large language models with algorithms such as **GRPO** and **DAPO**. It supports multi-node and multi-GPU setups, scalable rollout generation backed by vLLM, and decoupled policy optimization. Currently, we support two Reinforcement Learning with Verifiable Reward (RLVR) tasks: solving math problems and code generation.

**Please note that this framework is still under intensive development; stay tuned.**

---

## 🚀 Features

* **Distributed Training with Ray**: Scalable to multiple machines and GPUs.
* **Support for GRPO and DAPO**: Choose your preferred policy optimization algorithm.
* **Model Backends**: Supports `vllm` as the inference backend.
* **Rollout and Policy Decoupling**: Efficient generation and consumption of data through a parallel inferencer-trainer architecture.
* **Evaluation Integration**: Easily plug in task-specific eval datasets.
* **Checkpoints and Logging**: Configurable intervals and directories.

---

## 🛠 Installation

### Prepare the Development Environment

Install ColossalAI & ColossalChat
```bash
git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI
git checkout grpo-latest
BUILD_EXT=1 pip install -e .

cd ./applications/ColossalChat
pip install -e .
```

Install vllm
```bash
pip install vllm==0.7.3
```

Install Ray
```bash
pip install ray
```

Install Other Dependencies

```bash
export ATB_LLM_HCCL_ENABLE=1
export ATB_LLM_COMM_BACKEND="hccl"
@@ -242,6 +286,7 @@ In addition to the two default training settings we provided--- original `GRPO`
train_tensor_parallelism_size)
```


---

## 🧪 Example: single machine 8-GPU Zero2 Strategy
@@ -292,6 +337,54 @@ plugin_config={
}, # for pp, tp
```

```bash
# Hint 1: replace /path/to/Qwen2.5-Math-7B/ with your model path
# Hint 2: replace /path/to/train_data.jsonl with your dataset path
python rl_example.py \
-m /path/to/Qwen2.5-Math-7B/ \
-d /path/to/train_data.jsonl \
--master_address '10.0.0.3' \
-t 16 \
-i 16 \
-p GRPO-Train-Align-Debug \
-g 2 \
-ibs 1 \
-tbs 2 \
-tMbs 1 \
-tmbs 2 \
-imbs 1 \
-s "Please reason step by step, and put your final answer within \\boxed{}."
```
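
(Our reading of the abbreviated flags, an assumption to be confirmed with `python rl_example.py --help`: `-t`/`-i` set the number of trainer and inferencer processes, `-g` the number of generations per prompt, `-s` the system prompt, `-ibs`/`-imbs` the inference batch and microbatch sizes, and `-tbs`/`-tMbs`/`-tmbs` the train batch, minibatch, and microbatch sizes.)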

## 🧪 Example: multi-machine TP+PP Strategy

### Create ray cluster on multi-machine
For example, suppose we have 4 nodes whose IPs are 10.0.0.3, 10.0.0.4, 10.0.0.5, and 10.0.0.6, and we use 10.0.0.3 as the master node. First, start the Ray head on 10.0.0.3:
```bash
ray start --head --node-ip-address=10.0.0.3
```

Then, on each worker node (10.0.0.4/10.0.0.5/10.0.0.6), join the Ray cluster with:
```bash
ray start --address='10.0.0.3:6379'
```
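
Before launching training, you can confirm that every node has joined using the standard Ray CLI:
```bash
ray status  # should report 4 nodes once all workers have joined
```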

Modify `plugin_config` in ./applications/ColossalChat/rl_example.py:
```python
plugin_config={
"tp_size": 4,
"pp_size": 2,
"microbatch_size": max(
1, args.train_microbatch_size // 2
), # microbatch size should be set to train_microbatch_size // pp_size
"zero_stage": 1,
"max_norm": 1.0,
}, # for pp, tp
```
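
With this config each model replica spans tp_size × pp_size = 8 GPUs, so the trainer's data-parallel size follows as in this small sketch (assuming 8 GPUs per node and that all 32 processes on the 4-node cluster above are trainers, which may not match your actual trainer/inferencer split):

```python
# Sketch: deriving the data-parallel size from the plugin_config above,
# assuming 4 nodes x 8 GPUs = 32 trainer processes.
n_trainer_procs = 32
tp_size, pp_size = 4, 2
dp_size = n_trainer_procs // (tp_size * pp_size)  # -> 4 replicas train in data parallel
print(dp_size)
```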

```bash
# Hint 1: replace /models/Qwen/Qwen2.5-7B with your model path
# Hint 2: replace /datasets/train-alignment.jsonl with your dataset path
# (rest of the example truncated in this view)
```
32 changes: 17 additions & 15 deletions applications/ColossalChat/coati/distributed/grpo_consumer.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,6 @@
from coati.distributed.loss import PolicyLoss
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -145,7 +144,10 @@ def __init__(
def setup(self):
super().setup()
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
self.plugin.pp_size > 1
and self.booster.plugin.stage_manager.is_last_stage()
and self.tp_rank == 0
and self.dp_rank == 0
):
self.wandb_run = wandb.init(
project=self.project_name,
@@ -237,7 +239,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
self.effective_sample_count += effective_samples.item()
pbar.set_postfix(
{
@@ -295,12 +296,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
)

if self.booster.plugin.stage_manager.is_last_stage():
reference_model_logits = reference_model_outputs["outputs"]["logits"]
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
reference_action_log_probs = memory_efficient_logprob(
reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
shard_config=self.plugin.shard_config,
)
else:
# Dummy reference logprobs for data iterator.
@@ -323,11 +323,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:

def _criterion(outputs, inputs):
action_logits = outputs.logits
action_log_probs = calc_action_log_probs(
action_log_probs = memory_efficient_logprob(
action_logits / self.generate_config["temperature"],
inputs["input_ids"],
num_action,
self.plugin.shard_config,
shard_config=self.plugin.shard_config,
)
if "reference_action_log_probs" in inputs:
per_token_kl = (
@@ -370,16 +370,15 @@ def _criterion(outputs, inputs):
mean_kl.append(kl)
mean_loss.append(all_reduce_mean(loss, self.plugin).data)
else:

policy_model_logits = self.policy_model(
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits
action_log_probs = calc_action_log_probs(
action_log_probs = memory_efficient_logprob(
policy_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
shard_config=self.plugin.shard_config,
)

if self.policy_loss_fn.beta > 0:
@@ -388,11 +387,11 @@ def _criterion(outputs, inputs):
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits
reference_action_log_probs = calc_action_log_probs(
reference_action_log_probs = memory_efficient_logprob(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
shard_config=self.plugin.shard_config,
)
per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
@@ -424,7 +423,10 @@ def _criterion(outputs, inputs):
mean_kl.append(kl.data)
mean_loss.append(loss.data)
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
self.plugin.pp_size > 1
and self.booster.plugin.stage_manager.is_last_stage()
and self.tp_rank == 0
and self.dp_rank == 0
):
reward = all_reduce_mean(reward.mean(), self.plugin)
format_acc = all_reduce_mean(format_acc.mean(), self.plugin)
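
The body of `memory_efficient_logprob` is not shown in this diff; as a rough, hypothetical sketch of the technique it stands for (which the "optimize pp log_softmax OOM" and "implement memory efficient logprob" commits point to), the per-token log-probs can be computed by chunking the log-softmax along the sequence so the full `[batch, seq, vocab]` log-prob tensor is never materialized at once:

```python
import torch

def chunked_action_log_probs(
    logits: torch.Tensor,      # [B, S, V] raw logits (already temperature-scaled)
    input_ids: torch.Tensor,   # [B, S] token ids fed to the model
    num_action: int,           # number of generated (action) tokens at the end; < S
    chunk_size: int = 1024,
) -> torch.Tensor:
    # Position t predicts token t + 1, so shift the logits left by one.
    pred = logits[:, -num_action - 1 : -1, :]   # [B, A, V]
    targets = input_ids[:, -num_action:]        # [B, A]
    pieces = []
    for start in range(0, num_action, chunk_size):
        # log_softmax is over the vocab dim, so chunking along the sequence is exact.
        log_probs = pred[:, start : start + chunk_size, :].log_softmax(dim=-1)
        pieces.append(
            log_probs.gather(-1, targets[:, start : start + chunk_size].unsqueeze(-1)).squeeze(-1)
        )
    return torch.cat(pieces, dim=1)             # [B, A] per-token log-probs
```

The signature used in the diff also threads `shard_config` through, presumably to handle tensor-parallel vocabulary sharding; the sketch above ignores that.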
applications/ColossalChat/coati/distributed/inference_backend.py (file name inferred from context)
@@ -74,9 +74,8 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
micro_batch_size = input_ids.size(0)
input_ids = input_ids.to(get_current_device())
attention_mask = attention_mask.to(get_current_device())
gt_answer = None
if "gt_answer" in kwargs:
gt_answer = kwargs.pop("gt_answer")
gt_answer = kwargs.pop("gt_answer", None)
test_cases = kwargs.pop("test_cases", None)
if self.num_generations > 1:
input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)
@@ -116,8 +115,9 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}

if gt_answer is not None:
# repeat gt_answer for each prompt.
data["gt_answer"] = gt_answer.repeat_interleave(self.num_generations, dim=1)
data["gt_answer"] = gt_answer
if test_cases is not None:
data["test_cases"] = test_cases
data = {k: v.to(get_current_device()) for k, v in data.items()}
return data

@@ -270,11 +270,11 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
}

data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}

if "gt_answer" in kwargs:
# repeat gt_answer for each prompt.
data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
data = {k: v.to(get_current_device()) for k, v in data.items()}
if "gt_answer" in kwargs:
data["gt_answer"] = kwargs["gt_answer"]
if "test_cases" in kwargs:
data["test_cases"] = kwargs["test_cases"]
return data

def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
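
As we read this diff, both backends now pass `gt_answer` and `test_cases` through per prompt instead of repeating them per generation via `repeat_interleave`; the buffer-side filtering and reward code introduced elsewhere in this PR is then expected to group samples by prompt when consuming them.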
4 changes: 1 addition & 3 deletions applications/ColossalChat/coati/distributed/launch.py
@@ -16,7 +16,7 @@ def get_jsonl_size_fast(path: str) -> int:
with open(path) as f:
lines = f.readlines()
lines = [line for line in lines if line.strip()]
return len(lines) - 1
return len(lines)


def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
@@ -36,7 +36,6 @@ def launch_distributed(
train_batch_size: int,
train_minibatch_size: int,
train_dataset_config: Dict[str, Any],
dataloaders_config: Dict[str, Any],
inference_model_config: Dict[str, Any],
generate_config: Dict[str, Any],
train_model_config: Dict[str, Any],
@@ -121,7 +120,6 @@
num_episodes=num_episodes,
batch_size=inference_batch_size,
train_dataset_config=train_dataset_config,
dataloaders_config=dataloaders_config,
model_config=inference_model_config,
generate_config=generate_config,
tokenizer_config=tokenizer_config,
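
A quick sanity check of the counting fix above (hypothetical file and import path): every non-empty JSONL line is one training prompt, so the old `- 1` header adjustment undercounted by one.

```python
# Sketch: the corrected get_jsonl_size_fast counts every non-empty line.
import json
import tempfile

from coati.distributed.launch import get_jsonl_size_fast  # import path assumed

with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
    for i in range(3):
        f.write(json.dumps({"messages": [{"role": "user", "content": f"prompt {i}"}]}) + "\n")
    path = f.name

assert get_jsonl_size_fast(path) == 3  # the old version returned 2 here
```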