Changes from all commits (116 commits)
43c9b5f
[chat] add distributed impl (#6210)
ver217 Feb 21, 2025
de282dd
[feature] fit RL style generation (#6213)
ver217 Feb 21, 2025
8e6c9a4
add reward related function
TongLi3701 Feb 23, 2025
ffd3878
add simple grpo
TongLi3701 Feb 23, 2025
f736d74
update grpo
TongLi3701 Feb 25, 2025
070907d
polish
TongLi3701 Feb 28, 2025
c15225b
modify data loader
Mar 6, 2025
b96d690
grpo consumer
Mar 6, 2025
678f5a9
update loss
Mar 6, 2025
d03cdea
update reward fn
Mar 6, 2025
7f2ceac
update example
Mar 6, 2025
812f4b7
update loader
Mar 6, 2025
0f566cc
add algo selection
Mar 6, 2025
ab5b6d8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2025
0cc0c84
add save
Mar 6, 2025
0590f10
update select algo
Mar 6, 2025
22cc155
Merge branch 'grpo-latest' of github.com:hpcaitech/ColossalAI into gr…
Mar 6, 2025
eb6337f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2025
9d9d516
update grpo
Mar 10, 2025
754b16d
update reward fn
Mar 10, 2025
71a0181
update reward
Mar 10, 2025
abca66e
fix reward score
Mar 11, 2025
47d6493
add response length
Mar 11, 2025
704866a
detach
Mar 11, 2025
131eece
fix tp bug
Mar 13, 2025
afddfde
fix consumer
Mar 13, 2025
4702d57
convert to 8 generation
Mar 13, 2025
45ac6c6
print results
Mar 13, 2025
57b49da
setup update
Mar 13, 2025
bc0171d
fix transformers backend
YeAnbang Mar 14, 2025
7795d4c
[Feature] Support Distributed LogProb for GRPO Training (#6247)
duanjunwen Mar 18, 2025
7ee4452
fix vllm
YeAnbang Mar 19, 2025
0472f44
fix logprob, add filtering, temperature annealing, lr descent
YeAnbang Mar 21, 2025
d8eaf0d
simplify vllm preprocessing input ids
YeAnbang Mar 21, 2025
2aa7385
update logging
YeAnbang Mar 21, 2025
489f215
Merge pull request #6250 from hpcaitech/grpo-latest-dev
YeAnbang Mar 21, 2025
5015300
[feat] add microbatch forwarding (#6251)
YeAnbang Mar 28, 2025
ed43a4b
[Distributed RLHF] Integration of PP (#6257)
YeAnbang Apr 9, 2025
9467c10
[hot-fix] Fix memory leakage bug, support TP+PP (#6258)
YeAnbang Apr 10, 2025
03f4b1d
add prompt template (#6273)
TongLi3701 Apr 22, 2025
b823c6e
[feat] Add final save at the end (#6274)
TongLi3701 Apr 23, 2025
26d859f
[feat] Support DAPO (#6263)
YeAnbang Apr 25, 2025
3800885
fix checkpoint naming; add num_epoch parameter (#6277)
YeAnbang Apr 26, 2025
28795f5
fix save issue (#6279)
TongLi3701 Apr 27, 2025
2ca1e3c
fix pp+tp, fix dataloader (#6280)
YeAnbang Apr 28, 2025
14f237c
[feat] Support boxed math reward (#6284)
YeAnbang Apr 29, 2025
5fd4bcb
[feat] Sync shard model (#6289)
TongLi3701 Apr 30, 2025
57a8839
Support evaluation during training
YeAnbang Apr 30, 2025
bd61918
reuse comm-group
YeAnbang Apr 30, 2025
01640eb
fix bug
YeAnbang Apr 30, 2025
a6085ff
upgrade reward math verification
YeAnbang Apr 30, 2025
d06042b
rewrite reward fn
YeAnbang May 1, 2025
17928ad
Merge pull request #6292 from hpcaitech/grpo-latest-dev-reward-update
YeAnbang May 3, 2025
eb6b5dd
[fix] revert reward update and evaluation (#6295)
YeAnbang May 7, 2025
b920af4
update pad seq (#6303)
TongLi3701 May 13, 2025
47a7dc7
Support evaluation during training
YeAnbang Apr 30, 2025
aca5476
[feat] Support prompt level dynamic (#6300)
TongLi3701 May 14, 2025
50070c1
move logging to producer
YeAnbang May 14, 2025
094f119
merge
YeAnbang May 14, 2025
4ec7329
use consumer global step
YeAnbang May 15, 2025
957e3a5
disable wandb tb syncing
YeAnbang May 15, 2025
55eee12
move prompt-level-filtering to buffer side
YeAnbang May 15, 2025
a528921
move prompt-level-filtering to buffer side
YeAnbang May 15, 2025
1644adf
handle empty index
May 15, 2025
6abffb9
fix evaluation
YeAnbang May 16, 2025
ab95624
handle empty index (#6311)
TongLi3701 May 16, 2025
11a5854
remove redundant code and fix bugs
YeAnbang May 16, 2025
203dfb1
address conversation
YeAnbang May 16, 2025
021914c
support logging rollouts to wandb
YeAnbang May 16, 2025
3c42c0c
Merge pull request #6309 from hpcaitech/grpo-eval-dev
YeAnbang May 16, 2025
03b41d6
upgrade reward functions
YeAnbang May 16, 2025
107470a
fix logging rollouts
YeAnbang May 17, 2025
116621d
merge reward and eval
YeAnbang May 19, 2025
f8bd2db
add uuid to rollout log
YeAnbang May 20, 2025
bcf2459
Merge pull request #6314 from hpcaitech/grpo-reward-dev
YeAnbang May 20, 2025
32afa7b
fix empty tensor (#6319)
TongLi3701 May 20, 2025
3766338
fix metric calculation
YeAnbang May 20, 2025
88e3b09
merge grpo-latest
YeAnbang May 20, 2025
78a06f5
fix missing tags parameter
YeAnbang May 21, 2025
de2ad3b
fix default eval setting (#6321)
TongLi3701 May 22, 2025
4c36568
address conversation
YeAnbang May 28, 2025
58f8c9b
Merge branch 'grpo-latest' of https://github.com/hpcaitech/ColossalAI…
YeAnbang May 28, 2025
c8b368c
add overlength sample count (#6332)
TongLi3701 May 28, 2025
ee939d9
address conversation
YeAnbang May 29, 2025
7b921ac
merge grpo-latest
YeAnbang May 29, 2025
0d00811
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 29, 2025
96faf54
fix typ and parameter description
YeAnbang Jun 5, 2025
ceb7065
Merge pull request #6312 from hpcaitech/grpo-latest-dev
YeAnbang Jun 5, 2025
dc3033e
support code generation tasks
YeAnbang Jun 5, 2025
3bed6ae
fix bug, tested
YeAnbang Jun 9, 2025
d0e12c5
remove debug code
YeAnbang Jun 9, 2025
9ca920c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2025
c308b42
Merge pull request #6341 from hpcaitech/grpo-code
YeAnbang Jun 9, 2025
bb6f5d9
move out evaluation func (#6343)
TongLi3701 Jun 10, 2025
21d517d
Manually schedule resources and support auto master address assigning
YeAnbang Jun 10, 2025
2559924
modify readme
YeAnbang Jun 10, 2025
dc29c74
update readme
YeAnbang Jun 10, 2025
1330b57
add ray timeout handling instruction
YeAnbang Jun 10, 2025
8992def
fix pp memory issue (#6344)
TongLi3701 Jun 11, 2025
2f02a28
Update README.md
YeAnbang Jun 12, 2025
ac06935
Merge pull request #6316 from hpcaitech/grpo-support-multi-machine
YeAnbang Jun 12, 2025
51b7abe
fix num_update_per_episode
YeAnbang Jun 12, 2025
0e69b98
Merge pull request #6347 from hpcaitech/hotfix/fix_num_update_per_epi…
YeAnbang Jun 12, 2025
30a6859
optimize pp log_softmax OOM
YeAnbang Jun 13, 2025
e3d56cb
implement memory efficient logprob
YeAnbang Jun 18, 2025
6b06430
fix small bug
YeAnbang Jun 19, 2025
dd49444
Merge pull request #6348 from hpcaitech/grpo_optimization
YeAnbang Jun 19, 2025
8880b83
add dp rank for multi-dp (#6351)
TongLi3701 Jun 19, 2025
b1f646c
[feat] Support one-behind to reduce bubble time. Add profiling code (…
YeAnbang Jun 30, 2025
a992da9
fix code evaluation
YeAnbang Jul 14, 2025
d850475
fix style
YeAnbang Jul 14, 2025
f5c155a
Merge pull request #6361 from hpcaitech/grpo-latest-fix-code-reward
YeAnbang Jul 14, 2025
4cf5ce2
add entropy (#6363)
YeAnbang Jul 17, 2025
57e9210
hotfix entropy calculation (#6364)
YeAnbang Jul 22, 2025
d7b140d
fix: wrong dp-rank condition when enable pp
liuqh16 Jul 30, 2025
d0bc901
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 30, 2025
10 changes: 10 additions & 0 deletions .gitignore
@@ -163,3 +163,13 @@ coverage.xml
# log, test files - ColossalChat
applications/ColossalChat/logs
applications/ColossalChat/tests/logs
applications/ColossalChat/wandb
applications/ColossalChat/model
applications/ColossalChat/eval
applications/ColossalChat/rollouts
applications/ColossalChat/*.txt
applications/ColossalChat/*.db
applications/ColossalChat/stdin
applications/ColossalChat/*.zip
applications/ColossalChat/*.prof
applications/ColossalChat/*.png
61 changes: 50 additions & 11 deletions applications/ColossalChat/coati/dataset/loader.py
@@ -352,14 +352,30 @@ def apply_chat_template_and_mask(
tokenizer: PreTrainedTokenizer,
chat: List[Dict[str, str]],
max_length: Optional[int] = None,
system_prompt: str = None,
padding: bool = True,
truncation: bool = True,
ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:

if system_prompt is None:
system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"

system_element = {
"role": "system",
"content": system_prompt,
}

# Format for RL.
if "messages" in chat:
gt_answer = chat.get("gt_answer", None)
test_cases = chat.get("test_cases", None)
chat = [chat["messages"]]

tokens = []
assistant_mask = []
for i, msg in enumerate(chat):
msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True)
msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True)
# remove unexpected bos token
if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
msg_tokens = msg_tokens[1:]
@@ -372,14 +388,10 @@ def apply_chat_template_and_mask(
if max_length is not None:
if padding and len(tokens) < max_length:
to_pad = max_length - len(tokens)
if tokenizer.padding_side == "right":
tokens.extend([tokenizer.pad_token_id] * to_pad)
assistant_mask.extend([False] * to_pad)
attention_mask.extend([0] * to_pad)
else:
tokens = [tokenizer.pad_token_id] * to_pad + tokens
assistant_mask = [False] * to_pad + assistant_mask
attention_mask = [0] * to_pad + attention_mask
# Left padding for generation.
tokens = [tokenizer.pad_token_id] * to_pad + tokens
assistant_mask = [False] * to_pad + assistant_mask
attention_mask = [0] * to_pad + attention_mask
if truncation and len(tokens) > max_length:
tokens = tokens[:max_length]
assistant_mask = assistant_mask[:max_length]
@@ -389,6 +401,15 @@ def apply_chat_template_and_mask(
labels = input_ids.clone()
labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx

if gt_answer is not None:
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}
elif test_cases is not None:
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"test_cases": test_cases,
}
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
@@ -402,21 +423,39 @@ class RawConversationDataset(Dataset):
Each instance is a dictionary with fields `system`, `roles`, `messages`, `offset`, `sep_style`, `seps`.
"""

def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int) -> None:
def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int, system_prompt: str) -> None:
self.tokenizer = tokenizer
self.raw_texts = []
with jsonlines.open(input_file) as f:
for line in f:
self.raw_texts.append(line)
self.tokenized_texts = [None] * len(self.raw_texts)
self.max_length = max_length
self.system_prompt = system_prompt

def __len__(self) -> int:
return len(self.raw_texts)

def __getitem__(self, index: int):
if self.tokenized_texts[index] is None:
message = self.raw_texts[index]
tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length)
tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
self.tokenized_texts[index] = dict(tokens)
return self.tokenized_texts[index]


def collate_fn_grpo(batch):
input_ids = [item["input_ids"] for item in batch]
attention_mask = [item["attention_mask"] for item in batch]
labels = [item["labels"] for item in batch]
# Assume input_ids, attention_mask, labels are already of the same length,
# otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
input_ids = torch.stack(input_ids)
attention_mask = torch.stack(attention_mask)
labels = torch.stack(labels)
ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
if "test_cases" in batch[0]:
ret["test_cases"] = [item["test_cases"] for item in batch]
if "gt_answer" in batch[0]:
ret["gt_answer"] = [item["gt_answer"] for item in batch]
return ret
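
Below is a minimal usage sketch (not part of the diff) of how the new RawConversationDataset and collate_fn_grpo could be wired into a PyTorch DataLoader. The model name, data path, batch size, and max_length here are placeholder assumptions, and the JSONL file is assumed to follow the RL format handled above: each line carries a "messages" list plus an optional "gt_answer" or "test_cases" field.

# Hypothetical wiring of the new loader pieces; model name, path, and sizes are placeholders.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from coati.dataset.loader import RawConversationDataset, collate_fn_grpo

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # assumed base model
dataset = RawConversationDataset(
    tokenizer,
    "data/train.jsonl",  # each line: {"messages": [...], "gt_answer": ...} or {"messages": [...], "test_cases": ...}
    max_length=2048,
    system_prompt=None,  # None falls back to the default <think>/<answer> system prompt added in this PR
)
loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn_grpo)

batch = next(iter(loader))
# input_ids / attention_mask / labels come back as stacked tensors of equal length (left-padded to max_length);
# gt_answer / test_cases, when present, remain Python lists for the reward functions to consume.
print(batch["input_ids"].shape, batch["attention_mask"].shape)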