Commit b38248d

Merge pull request #6376 from hpcaitech/grpo-latest-rebase-main

[feat] Add distributed RLFT training framework

2 parents: edd65a8 + fe1f429
40 files changed: +4686 -41 lines

.github/workflows/run_chatgpt_examples.yml (6 additions, 1 deletion)

```diff
@@ -21,7 +21,7 @@ jobs:
     container:
       image: image-cloud.luchentech.com/hpcaitech/pytorch-cuda:2.2.2-12.1.0
       options: --gpus all --rm -v /data/scratch/examples-data:/data/scratch/examples-data --shm-size=10.24gb
-    timeout-minutes: 60
+    timeout-minutes: 180
     defaults:
       run:
         shell: bash
@@ -34,7 +34,12 @@ jobs:
           pip install --no-cache-dir -v -e .

       - name: Install ChatGPT
+        env:
+          CFLAGS: "-O1"
+          CXXFLAGS: "-O1"
+          MAX_JOBS: 4
         run: |
+          pip install flash-attn --no-build-isolation
           cd applications/ColossalChat
           pip install --no-cache-dir -v .
           pip install --no-cache-dir -r examples/requirements.txt
```

.github/workflows/run_chatgpt_unit_tests.yml (6 additions, 1 deletion)

```diff
@@ -21,7 +21,7 @@ jobs:
     container:
       image: image-cloud.luchentech.com/hpcaitech/pytorch-cuda:2.2.2-12.1.0
       options: --gpus all --rm -v /data/scratch/examples-data:/data/scratch/examples-data
-    timeout-minutes: 30
+    timeout-minutes: 180
     defaults:
       run:
         shell: bash
@@ -30,7 +30,12 @@ jobs:
         uses: actions/checkout@v2

       - name: Install ChatGPT
+        env:
+          CFLAGS: "-O1"
+          CXXFLAGS: "-O1"
+          MAX_JOBS: 4
         run: |
+          pip install flash-attn --no-build-isolation
           cd applications/ColossalChat
           pip install -v .
           pip install pytest
```
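Both workflow files make the same two changes: the job timeout rises to 180 minutes to leave room for compiling flash-attn from source, and the install step builds flash-attn with `CFLAGS`/`CXXFLAGS` set to `-O1` and `MAX_JOBS=4`, trading compile speed for a build that fits the CI runner's memory. A minimal local sketch of that install step (the pip command and env values are copied from the workflow; the wrapper script itself is an assumption):

```python
# Sketch: run the CI install step locally. -O1 lowers compiler memory use;
# MAX_JOBS=4 caps parallel compilation jobs during the flash-attn CUDA build.
import os
import subprocess

env = dict(os.environ, CFLAGS="-O1", CXXFLAGS="-O1", MAX_JOBS="4")
subprocess.run(
    ["pip", "install", "flash-attn", "--no-build-isolation"],
    env=env,
    check=True,
)
```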

.gitignore (10 additions, 0 deletions)

```diff
@@ -163,3 +163,13 @@ coverage.xml
 # log, test files - ColossalChat
 applications/ColossalChat/logs
 applications/ColossalChat/tests/logs
+applications/ColossalChat/wandb
+applications/ColossalChat/model
+applications/ColossalChat/eval
+applications/ColossalChat/rollouts
+applications/ColossalChat/*.txt
+applications/ColossalChat/*.db
+applications/ColossalChat/stdin
+applications/ColossalChat/*.zip
+applications/ColossalChat/*.prof
+applications/ColossalChat/*.png
```

applications/ColossalChat/coati/dataset/loader.py (50 additions, 11 deletions)

```diff
@@ -352,14 +352,30 @@ def apply_chat_template_and_mask(
     tokenizer: PreTrainedTokenizer,
     chat: List[Dict[str, str]],
     max_length: Optional[int] = None,
+    system_prompt: str = None,
     padding: bool = True,
     truncation: bool = True,
     ignore_idx: int = -100,
 ) -> Dict[str, torch.Tensor]:
+
+    if system_prompt is None:
+        system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
+
+    system_element = {
+        "role": "system",
+        "content": system_prompt,
+    }
+
+    # Format for RL.
+    if "messages" in chat:
+        gt_answer = chat.get("gt_answer", None)
+        test_cases = chat.get("test_cases", None)
+        chat = [chat["messages"]]
+
     tokens = []
     assistant_mask = []
     for i, msg in enumerate(chat):
-        msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True)
+        msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True)
         # remove unexpected bos token
         if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
             msg_tokens = msg_tokens[1:]
```
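The new `"messages"` branch lets a raw sample carry RL metadata alongside a single-turn conversation. A sketch of the record shape this branch expects (field names are from the diff; the values are hypothetical):

```python
# Hypothetical JSONL record for the RL path. apply_chat_template_and_mask
# pulls out gt_answer / test_cases and rewraps the chat as [sample["messages"]].
sample = {
    "messages": {"role": "user", "content": "What is 17 * 3?"},
    "gt_answer": "51",  # or "test_cases": [...] for code-generation tasks
}
```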
```diff
@@ -372,14 +388,10 @@ def apply_chat_template_and_mask(
     if max_length is not None:
         if padding and len(tokens) < max_length:
             to_pad = max_length - len(tokens)
-            if tokenizer.padding_side == "right":
-                tokens.extend([tokenizer.pad_token_id] * to_pad)
-                assistant_mask.extend([False] * to_pad)
-                attention_mask.extend([0] * to_pad)
-            else:
-                tokens = [tokenizer.pad_token_id] * to_pad + tokens
-                assistant_mask = [False] * to_pad + assistant_mask
-                attention_mask = [0] * to_pad + attention_mask
+            # Left padding for generation.
+            tokens = [tokenizer.pad_token_id] * to_pad + tokens
+            assistant_mask = [False] * to_pad + assistant_mask
+            attention_mask = [0] * to_pad + attention_mask
     if truncation and len(tokens) > max_length:
         tokens = tokens[:max_length]
         assistant_mask = assistant_mask[:max_length]
```
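Padding is now unconditionally on the left, which is what decoder-only generation needs: with right padding, pad tokens would sit between the prompt and the tokens the model generates next. A minimal illustration (pad id and token values are made up):

```python
pad_id = 0                           # hypothetical pad_token_id
tokens = [101, 2023, 2003]           # hypothetical prompt token ids
to_pad = 5 - len(tokens)
tokens = [pad_id] * to_pad + tokens
print(tokens)                        # [0, 0, 101, 2023, 2003]: the prompt ends
                                     # flush with the sequence, so generation
                                     # appends directly after real tokens
```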
```diff
@@ -389,6 +401,15 @@ def apply_chat_template_and_mask(
     labels = input_ids.clone()
     labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx

+    if gt_answer is not None:
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}
+    elif test_cases is not None:
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "labels": labels,
+            "test_cases": test_cases,
+        }
     return {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
```
```diff
@@ -402,21 +423,39 @@ class RawConversationDataset(Dataset):
     Each instance is a dictionary with fields `system`, `roles`, `messages`, `offset`, `sep_style`, `seps`.
     """

-    def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int) -> None:
+    def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int, system_prompt: str) -> None:
         self.tokenizer = tokenizer
         self.raw_texts = []
         with jsonlines.open(input_file) as f:
             for line in f:
                 self.raw_texts.append(line)
         self.tokenized_texts = [None] * len(self.raw_texts)
         self.max_length = max_length
+        self.system_prompt = system_prompt

     def __len__(self) -> int:
         return len(self.raw_texts)

     def __getitem__(self, index: int):
         if self.tokenized_texts[index] is None:
             message = self.raw_texts[index]
-            tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length)
+            tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
             self.tokenized_texts[index] = dict(tokens)
         return self.tokenized_texts[index]
+
+
+def collate_fn_grpo(batch):
+    input_ids = [item["input_ids"] for item in batch]
+    attention_mask = [item["attention_mask"] for item in batch]
+    labels = [item["labels"] for item in batch]
+    # Assume input_ids, attention_mask, labels are already of the same length,
+    # otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+    input_ids = torch.stack(input_ids)
+    attention_mask = torch.stack(attention_mask)
+    labels = torch.stack(labels)
+    ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
+    if "test_cases" in batch[0]:
+        ret["test_cases"] = [item["test_cases"] for item in batch]
+    if "gt_answer" in batch[0]:
+        ret["gt_answer"] = [item["gt_answer"] for item in batch]
+    return ret
```
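A hedged usage sketch for the new collate function (model name, data path, and batch size are assumptions):

```python
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from coati.dataset.loader import RawConversationDataset, collate_fn_grpo

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # hypothetical model
dataset = RawConversationDataset(tokenizer, "data/train.jsonl", max_length=512, system_prompt=None)
loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn_grpo)

for batch in loader:
    # input_ids / attention_mask / labels arrive stacked as (8, 512) tensors;
    # gt_answer / test_cases stay as plain Python lists for the reward step.
    break
```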
