[feat] Add distributed RLFT training framework #6376

Open · wants to merge 103 commits into base: main
Commits (103)
162bb42
[chat] add distributed impl (#6210)
ver217 Feb 21, 2025
7a2d455
[feature] fit RL style generation (#6213)
ver217 Feb 21, 2025
fa1272f
add reward related function
TongLi3701 Feb 23, 2025
40d6018
add simple grpo
TongLi3701 Feb 23, 2025
1f07b71
update grpo
TongLi3701 Feb 25, 2025
718c4b7
polish
TongLi3701 Feb 28, 2025
b7842f8
modify data loader
Mar 6, 2025
5f178a7
grpo consumer
Mar 6, 2025
9754a11
update loss
Mar 6, 2025
f8899dd
update reward fn
Mar 6, 2025
5c75d5b
update example
Mar 6, 2025
cc4cc78
update loader
Mar 6, 2025
1f15dc7
add algo selection
Mar 6, 2025
88eb6e5
add save
Mar 6, 2025
246f16d
update select algo
Mar 6, 2025
f71d422
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2025
bc538ba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2025
fe017d3
update grpo
Mar 10, 2025
c8db826
update reward fn
Mar 10, 2025
a537aa1
update reward
Mar 10, 2025
a4862a2
fix reward score
Mar 11, 2025
b951d0b
add response length
Mar 11, 2025
69a1a32
detach
Mar 11, 2025
b19355f
fix tp bug
Mar 13, 2025
a2ae82a
fix consumer
Mar 13, 2025
30c7ddd
convert to 8 generation
Mar 13, 2025
bfc4582
print results
Mar 13, 2025
e224673
setup update
Mar 13, 2025
35dabd7
fix transformers backend
YeAnbang Mar 14, 2025
4551853
[Feature] Support Distributed LogProb for GRPO Training (#6247)
duanjunwen Mar 18, 2025
f983071
fix vllm
YeAnbang Mar 19, 2025
16e68a0
fix logprob, add filtering, temperature annealing, lr descent
YeAnbang Mar 21, 2025
23aac43
simplify vllm preprocessing input ids
YeAnbang Mar 21, 2025
c627b60
update logging
YeAnbang Mar 21, 2025
12da4d1
[feat] add microbatch forwarding (#6251)
YeAnbang Mar 28, 2025
5d79b9e
[Distributed RLHF] Integration of PP (#6257)
YeAnbang Apr 9, 2025
3bd6fa3
[hot-fix] Fix memory leakage bug, support TP+PP (#6258)
YeAnbang Apr 10, 2025
befd4f1
add prompt template (#6273)
TongLi3701 Apr 22, 2025
b34d707
[feat] Add final save at the end (#6274)
TongLi3701 Apr 23, 2025
5f913e8
[feat] Support DAPO (#6263)
YeAnbang Apr 25, 2025
673682e
fix checkpoint naming; add num_epoch parameter (#6277)
YeAnbang Apr 26, 2025
37a8be7
fix save issue (#6279)
TongLi3701 Apr 27, 2025
fb4e507
fix pp+tp, fix dataloader (#6280)
YeAnbang Apr 28, 2025
e181318
[feat] Support boxed math reward (#6284)
YeAnbang Apr 29, 2025
6a1bd83
[feat] Sync shard model (#6289)
TongLi3701 Apr 30, 2025
16600f3
Support evaluation during training
YeAnbang Apr 30, 2025
de0c267
reuse comm-group
YeAnbang Apr 30, 2025
1be993d
fix bug
YeAnbang Apr 30, 2025
9642b75
upgrade reward math verification
YeAnbang Apr 30, 2025
06b892b
rewrite reward fn
YeAnbang May 1, 2025
9544c51
[fix] revert reward update and evaluation (#6295)
YeAnbang May 7, 2025
4ac7d06
update pad seq (#6303)
TongLi3701 May 13, 2025
af4366f
Support evaluation during training
YeAnbang Apr 30, 2025
3416a4f
move logging to producer
YeAnbang May 14, 2025
5a6e4a6
[feat] Support prompt level dynamic (#6300)
TongLi3701 May 14, 2025
280aa0b
use consumer global step
YeAnbang May 15, 2025
0d0fef7
disable wandb tb syncing
YeAnbang May 15, 2025
f79dbdb
move prompt-level-filtering to buffer side
YeAnbang May 15, 2025
d19f1f2
move prompt-level-filtering to buffer side
YeAnbang May 15, 2025
88f49dd
remove redundant code and fix bugs
YeAnbang May 16, 2025
6ebd813
handle empty index
May 15, 2025
e7f61be
fix evaluation
YeAnbang May 16, 2025
654aefc
address conversation
YeAnbang May 16, 2025
6095274
support logging rollouts to wandb
YeAnbang May 16, 2025
9cbc5dd
upgrade reward functions
YeAnbang May 16, 2025
c7c73df
fix logging rollouts
YeAnbang May 17, 2025
06cfbe3
fix metric calculation
YeAnbang May 20, 2025
70c3daa
add uuid to rollout log
YeAnbang May 20, 2025
5bbfe15
fix empty tensor (#6319)
TongLi3701 May 20, 2025
4b1c515
fix missing tags parameter
YeAnbang May 21, 2025
2a39d3a
address conversation
YeAnbang May 28, 2025
382307a
fix default eval setting (#6321)
TongLi3701 May 22, 2025
6051001
address conversation
YeAnbang May 29, 2025
a246bf2
add overlength sample count (#6332)
TongLi3701 May 28, 2025
8d52441
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 29, 2025
a9a3f37
fix typo and parameter description
YeAnbang Jun 5, 2025
1771447
support code generation tasks
YeAnbang Jun 5, 2025
de40c73
fix bug, tested
YeAnbang Jun 9, 2025
9dbb0ff
remove debug code
YeAnbang Jun 9, 2025
72b2d98
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2025
6ae54a6
move out evaluation func (#6343)
TongLi3701 Jun 10, 2025
3a4681f
fix pp memory issue (#6344)
TongLi3701 Jun 11, 2025
3b3c48d
Manually schedule resources and support auto master address assigning
YeAnbang Jun 10, 2025
6a0b809
modify readme
YeAnbang Jun 10, 2025
79a7b99
update readme
YeAnbang Jun 10, 2025
80c576f
add ray timeout handling instruction
YeAnbang Jun 10, 2025
73384be
Update README.md
YeAnbang Jun 12, 2025
0f71c79
fix num_update_per_episode
YeAnbang Jun 12, 2025
a960990
optimize pp log_softmax OOM
YeAnbang Jun 13, 2025
245c8c2
implement memory efficient logprob
YeAnbang Jun 18, 2025
b314da1
fix small bug
YeAnbang Jun 19, 2025
685e0bd
add dp rank for multi-dp (#6351)
TongLi3701 Jun 19, 2025
594c2c6
[feat] Support one-behind to reduce bubble time. Add profiling code (…
YeAnbang Jun 30, 2025
352a8e0
fix code evaluation
YeAnbang Jul 14, 2025
eafbc89
fix style
YeAnbang Jul 14, 2025
3d9dd34
add entropy (#6363)
YeAnbang Jul 17, 2025
c782976
hotfix entropy calculation (#6364)
YeAnbang Jul 22, 2025
118a66f
[Fix] Add L2 Regularization (#6372)
YeAnbang Jul 29, 2025
3746f73
fix missing or wrong file during rebase
YeAnbang Aug 5, 2025
32b2148
tested after rebasing, fix importance sampling bug
YeAnbang Aug 6, 2025
08a1244
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2025
b6a5f67
reduce memory consumption
BurkeHulk Aug 13, 2025
9db9892
reduce memory consumption
BurkeHulk Aug 13, 2025
4 changes: 4 additions & 0 deletions .github/workflows/run_chatgpt_examples.yml
@@ -34,6 +34,10 @@ jobs:
           pip install --no-cache-dir -v -e .
 
       - name: Install ChatGPT
+        env:
+          CFLAGS: "-O0"
+          CXXFLAGS: "-O0"
+          MAX_JOBS: 4
         run: |
           cd applications/ColossalChat
           pip install --no-cache-dir -v .
4 changes: 4 additions & 0 deletions .github/workflows/run_chatgpt_unit_tests.yml
@@ -30,6 +30,10 @@ jobs:
         uses: actions/checkout@v2
 
       - name: Install ChatGPT
+        env:
+          CFLAGS: "-O0"
+          CXXFLAGS: "-O0"
+          MAX_JOBS: 4
         run: |
           cd applications/ColossalChat
           pip install -v .
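Note: the env block added to both workflows above presumably exists to keep the from-source installs within CI resource limits. CFLAGS/CXXFLAGS set to -O0 disable C/C++ compiler optimization for any extension code compiled during pip install, and MAX_JOBS: 4 caps the parallel compile jobs spawned by PyTorch's extension build system, trading slower binaries for lower peak memory and more predictable build times.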
10 changes: 10 additions & 0 deletions .gitignore
@@ -163,3 +163,13 @@ coverage.xml
 # log, test files - ColossalChat
 applications/ColossalChat/logs
 applications/ColossalChat/tests/logs
+applications/ColossalChat/wandb
+applications/ColossalChat/model
+applications/ColossalChat/eval
+applications/ColossalChat/rollouts
+applications/ColossalChat/*.txt
+applications/ColossalChat/*.db
+applications/ColossalChat/stdin
+applications/ColossalChat/*.zip
+applications/ColossalChat/*.prof
+applications/ColossalChat/*.png
61 changes: 50 additions & 11 deletions applications/ColossalChat/coati/dataset/loader.py
@@ -352,14 +352,30 @@ def apply_chat_template_and_mask(
     tokenizer: PreTrainedTokenizer,
     chat: List[Dict[str, str]],
     max_length: Optional[int] = None,
+    system_prompt: str = None,
     padding: bool = True,
     truncation: bool = True,
     ignore_idx: int = -100,
 ) -> Dict[str, torch.Tensor]:
+
+    if system_prompt is None:
+        system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
+
+    system_element = {
+        "role": "system",
+        "content": system_prompt,
+    }
+
+    # Format for RL.
+    if "messages" in chat:
+        gt_answer = chat.get("gt_answer", None)
+        test_cases = chat.get("test_cases", None)
+        chat = [chat["messages"]]
+
     tokens = []
     assistant_mask = []
     for i, msg in enumerate(chat):
-        msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True)
+        msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True)
         # remove unexpected bos token
         if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
             msg_tokens = msg_tokens[1:]
@@ -372,14 +388,10 @@ def apply_chat_template_and_mask(
     if max_length is not None:
         if padding and len(tokens) < max_length:
             to_pad = max_length - len(tokens)
-            if tokenizer.padding_side == "right":
-                tokens.extend([tokenizer.pad_token_id] * to_pad)
-                assistant_mask.extend([False] * to_pad)
-                attention_mask.extend([0] * to_pad)
-            else:
-                tokens = [tokenizer.pad_token_id] * to_pad + tokens
-                assistant_mask = [False] * to_pad + assistant_mask
-                attention_mask = [0] * to_pad + attention_mask
+            # Left padding for generation.
+            tokens = [tokenizer.pad_token_id] * to_pad + tokens
+            assistant_mask = [False] * to_pad + assistant_mask
+            attention_mask = [0] * to_pad + attention_mask
         if truncation and len(tokens) > max_length:
             tokens = tokens[:max_length]
             assistant_mask = assistant_mask[:max_length]
@@ -389,6 +401,15 @@ def apply_chat_template_and_mask(
     labels = input_ids.clone()
     labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx
 
+    if gt_answer is not None:
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}
+    elif test_cases is not None:
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "labels": labels,
+            "test_cases": test_cases,
+        }
     return {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
@@ -402,21 +423,39 @@ class RawConversationDataset(Dataset):
     Each instance is a dictionary with fields `system`, `roles`, `messages`, `offset`, `sep_style`, `seps`.
     """
 
-    def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int) -> None:
+    def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int, system_prompt: str) -> None:
         self.tokenizer = tokenizer
         self.raw_texts = []
         with jsonlines.open(input_file) as f:
             for line in f:
                 self.raw_texts.append(line)
         self.tokenized_texts = [None] * len(self.raw_texts)
         self.max_length = max_length
+        self.system_prompt = system_prompt
 
     def __len__(self) -> int:
         return len(self.raw_texts)
 
     def __getitem__(self, index: int):
         if self.tokenized_texts[index] is None:
             message = self.raw_texts[index]
-            tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length)
+            tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
             self.tokenized_texts[index] = dict(tokens)
         return self.tokenized_texts[index]
+
+
+def collate_fn_grpo(batch):
+    input_ids = [item["input_ids"] for item in batch]
+    attention_mask = [item["attention_mask"] for item in batch]
+    labels = [item["labels"] for item in batch]
+    # Assume input_ids, attention_mask, labels are already of the same length,
+    # otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+    input_ids = torch.stack(input_ids)
+    attention_mask = torch.stack(attention_mask)
+    labels = torch.stack(labels)
+    ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
+    if "test_cases" in batch[0]:
+        ret["test_cases"] = [item["test_cases"] for item in batch]
+    if "gt_answer" in batch[0]:
+        ret["gt_answer"] = [item["gt_answer"] for item in batch]
+    return ret
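For reviewers, a minimal usage sketch of the new dataset and collate function. This is not code from the PR: the checkpoint name and JSONL path are placeholders, and each JSONL line is assumed to hold a "messages" list plus an optional "gt_answer" or "test_cases" field.

# Hypothetical usage sketch; checkpoint and data path are placeholders.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from coati.dataset.loader import RawConversationDataset, collate_fn_grpo

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # placeholder checkpoint
dataset = RawConversationDataset(
    tokenizer,
    input_file="prompts.jsonl",  # placeholder: one {"messages": [...], "gt_answer": ...} object per line
    max_length=2048,
    system_prompt=None,  # None falls back to the built-in <think>/<answer> system prompt
)
loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn_grpo)

batch = next(iter(loader))
# Tensor fields are stacked; gt_answer/test_cases pass through as per-sample lists.
print(batch["input_ids"].shape, batch.get("gt_answer"))

Because apply_chat_template_and_mask pads or truncates every sample to max_length, the torch.stack calls in collate_fn_grpo are safe without further padding.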