
Commit c308b42

Merge pull request #6341 from hpcaitech/grpo-code
[feat] Support Code Generation RFT, Move Reward Calculation to Producer
2 parents ceb7065 + 9ca920c commit c308b42

13 files changed: +1029 additions, −130 deletions


.gitignore
Lines changed: 4 additions & 0 deletions

@@ -167,3 +167,7 @@ applications/ColossalChat/wandb
 applications/ColossalChat/model
 applications/ColossalChat/eval
 applications/ColossalChat/rollouts
+applications/ColossalChat/*.txt
+applications/ColossalChat/*.db
+applications/ColossalChat/stdin
+applications/ColossalChat/*.zip

applications/ColossalChat/coati/dataset/loader.py
Lines changed: 27 additions & 8 deletions

@@ -367,9 +367,9 @@ def apply_chat_template_and_mask(
     }

     # Format for RL.
-    gt_answer = None
-    if "messages" in chat and "gt_answer" in chat:
-        gt_answer = chat["gt_answer"]
+    if "messages" in chat:
+        gt_answer = chat.get("gt_answer", None)
+        test_cases = chat.get("test_cases", None)
         chat = [chat["messages"]]

     tokens = []
@@ -402,12 +402,14 @@ def apply_chat_template_and_mask(
     labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx

     if gt_answer is not None:
-        gt_answer = tokenizer.encode(
-            gt_answer, padding="max_length", truncation=True, max_length=128, return_tensors="pt"
-        )
-        gt_answer = gt_answer.squeeze(1)
         return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}
-
+    elif test_cases is not None:
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "labels": labels,
+            "test_cases": test_cases,
+        }
     return {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
@@ -440,3 +442,20 @@ def __getitem__(self, index: int):
            tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
            self.tokenized_texts[index] = dict(tokens)
        return self.tokenized_texts[index]
+
+
+def collate_fn_grpo(batch):
+    input_ids = [item["input_ids"] for item in batch]
+    attention_mask = [item["attention_mask"] for item in batch]
+    labels = [item["labels"] for item in batch]
+    # Assume input_ids, attention_mask, labels are already of the same length,
+    # otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+    input_ids = torch.stack(input_ids)
+    attention_mask = torch.stack(attention_mask)
+    labels = torch.stack(labels)
+    ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
+    if "test_cases" in batch[0]:
+        ret["test_cases"] = [item["test_cases"] for item in batch]
+    if "gt_answer" in batch[0]:
+        ret["gt_answer"] = [item["gt_answer"] for item in batch]
+    return ret
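
Note: a minimal usage sketch of the new collate_fn_grpo with a standard PyTorch DataLoader. The stand-in dataset, import path, and sizes are assumptions for illustration, not part of this diff; real items come from apply_chat_template_and_mask.

```python
import torch
from torch.utils.data import DataLoader

from coati.dataset.loader import collate_fn_grpo  # assumed import path for this module

# Tiny stand-in dataset; in practice items are produced by apply_chat_template_and_mask.
samples = [
    {
        "input_ids": torch.zeros(16, dtype=torch.long),
        "attention_mask": torch.ones(16, dtype=torch.long),
        "labels": torch.zeros(16, dtype=torch.long),
        "test_cases": ["assert add(1, 2) == 3"],  # raw, untokenized code tests
    }
    for _ in range(4)
]

loader = DataLoader(samples, batch_size=2, collate_fn=collate_fn_grpo)
batch = next(iter(loader))
print(batch["input_ids"].shape)  # torch.Size([2, 16]) -- tensors are stacked
print(batch["test_cases"])       # gt_answer / test_cases stay as plain Python lists
```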

applications/ColossalChat/coati/distributed/consumer.py
Lines changed: 9 additions & 14 deletions

@@ -123,21 +123,16 @@ def loop(self) -> None:
                     # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
                     # we need to calculate the metrics before filtering here for logging
                     # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
-                    raw_batch_with_reward = self.calculate_reward(
-                        {k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
-                    )
-                    raw_batch_with_reward = {
+                    raw_batch = {
                         k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
-                        for k, v in raw_batch_with_reward.items()
+                        for k, v in raw_batch.items()
                     }
                     # [batch_size, num_generations] -> [batch_size]
-                    reward = raw_batch_with_reward["reward"][:, :, 0]
-                    format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
-                    ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
+                    reward = raw_batch["reward"][:, :, 0]
+                    format_acc = raw_batch["format_acc"][:, :, 0]
+                    ans_acc = raw_batch["ans_acc"][:, :, 0]
                     response_len = (
-                        raw_batch_with_reward["response_idx"][:, :, 1]
-                        - raw_batch_with_reward["response_idx"][:, :, 0]
-                        + 1
+                        raw_batch["response_idx"][:, :, 1] - raw_batch["response_idx"][:, :, 0] + 1
                     ).type(torch.float32)
                     effective_group_mask = None
                     if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
@@ -146,8 +141,8 @@ def loop(self) -> None:
                         effective_group_mask = torch.logical_and(
                             group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
                         )
-                    raw_batch_with_reward = unbind_batch(raw_batch_with_reward)  # List[Dict[str, torch.Tensor]]
-                    for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
+                    raw_batch = unbind_batch(raw_batch)  # List[Dict[str, torch.Tensor]]
+                    for group_idx, group_with_reward in enumerate(raw_batch):
                         self.buffer.append(
                             [
                                 (
@@ -163,7 +158,7 @@ def loop(self) -> None:
                         )
                     if effective_group_mask is not None:
                         print(
-                            f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
+                            f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
                         )
                     # mapping the effective group to the raw group for indexing
                     effective_group_to_raw_group_mapping = {}
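
Note: a small self-contained sketch of the reshape and accuracy-based group filtering that the updated loop performs, assuming the producer has already attached per-sample rewards. The sizes, filter range, and the group-mean step are illustrative assumptions inferred from the surrounding code.

```python
import torch

# Illustrative sizes; in the real loop these come from the received raw_batch and the config.
batch_size, num_generations = 2, 4
filter_range = (0.05, 0.95)  # stands in for self.filter_range with dynamic_batching enabled

# The producer now ships accuracies with the rollout: [batch_size * num_generations, 1]
ans_acc_flat = torch.rand(batch_size * num_generations, 1)

# [batch_size * num_generations, 1] -> [batch_size, num_generations, 1]
ans_acc = ans_acc_flat.view(-1, num_generations, ans_acc_flat.size(-1))

# Per-group mean accuracy: [batch_size, num_generations] -> [batch_size]
group_ans_acc_mean = ans_acc[:, :, 0].mean(dim=1)

# Keep only groups whose mean accuracy lies strictly inside the filter range.
effective_group_mask = torch.logical_and(
    group_ans_acc_mean > filter_range[0], group_ans_acc_mean < filter_range[1]
)
print(f"Filter recv data: {batch_size} -> {effective_group_mask.sum().item()} effective groups")
```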

applications/ColossalChat/coati/distributed/grpo_consumer.py
Lines changed: 2 additions & 51 deletions

@@ -1,13 +1,11 @@
 from contextlib import nullcontext
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 import ray
 import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
-from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -119,20 +117,7 @@ def __init__(
                 "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
             )
         # Initialize verifiable reward.
-        response_format_tags = grpo_config.get("response_format_tags", None)
-        reward_model_kwargs = {
-            k: v
-            for k, v in grpo_config.items()
-            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
-        }
-        self.reward_model = VerifiableReward(
-            reward_fns=[
-                math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
-            ],
-            tokenizer=self.tokenizer,
-            tags=response_format_tags,
-            **reward_model_kwargs,
-        )
+        grpo_config.get("response_format_tags", None)
         self.global_step = 0

         self.lr_scheduler = CosineAnnealingWarmupLR(
@@ -498,40 +483,6 @@ def _criterion(outputs, inputs):
         else:
             return None

-    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Calculate the group reward for the given rollout group.
-
-        Args:
-            rollout_group (Dict[str, Any]):
-                a group of samples generated by the model from the same prompt
-                contain the following keys:
-                "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
-                "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
-                "action_mask": torch.Tensor, [num_of_generation, response_length]
-                "action_log_probs": torch.Tensor, [num_of_generation, response_length]
-                "response_idx": int, torch.Tensor, [num_of_generation, 2]
-                "gt_answer": torch.Tensor, [num_of_generation, 128]
-                "temperature": torch.Tensor, [] (scalar)
-
-        Returns:
-            Dict[str, Any]: The new group data with calculated reward.
-        """
-        reward_model_output = self.reward_model(
-            rollout["input_ids"],
-            gt_answer=rollout["gt_answer"],
-            response_idx=rollout["response_idx"],
-        )
-        # [num_of_generation]
-        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
-        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
-        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
-
-        rollout["reward"] = reward.view((-1, 1))
-        rollout["format_acc"] = format_acc.view((-1, 1))
-        rollout["ans_acc"] = ans_acc.view((-1, 1))
-        return rollout
-
     def state_dict(self):
         self.policy_model._force_wait_all_gather()
         model = self.policy_model.unwrap()
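
Note: with calculate_reward removed, the consumer expects each rollout group to arrive with the reward fields already filled in by the producer. A hedged sketch of that per-group payload; shapes follow the removed docstring and the fields read in consumer.loop, with dummy sizes for illustration.

```python
import torch

num_of_generation, prompt_length, response_length = 4, 32, 96
seq_length = prompt_length + response_length

# Shapes follow the removed calculate_reward docstring; reward / format_acc / ans_acc
# are now filled upstream by the producer instead of being computed in the consumer.
rollout_group = {
    "input_ids": torch.zeros(num_of_generation, seq_length, dtype=torch.long),
    "attention_mask": torch.ones(num_of_generation, seq_length, dtype=torch.long),
    "action_mask": torch.ones(num_of_generation, response_length, dtype=torch.long),
    "action_log_probs": torch.zeros(num_of_generation, response_length),
    "response_idx": torch.zeros(num_of_generation, 2, dtype=torch.long),
    "temperature": torch.tensor(1.0),
    # Filled by the producer before the batch reaches the consumer:
    "reward": torch.zeros(num_of_generation, 1),
    "format_acc": torch.zeros(num_of_generation, 1),
    "ans_acc": torch.zeros(num_of_generation, 1),
}
```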

applications/ColossalChat/coati/distributed/inference_backend.py
Lines changed: 9 additions & 9 deletions

@@ -74,9 +74,8 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
         micro_batch_size = input_ids.size(0)
         input_ids = input_ids.to(get_current_device())
         attention_mask = attention_mask.to(get_current_device())
-        gt_answer = None
-        if "gt_answer" in kwargs:
-            gt_answer = kwargs.pop("gt_answer")
+        gt_answer = kwargs.pop("gt_answer", None)
+        test_cases = kwargs.pop("test_cases", None)
         if self.num_generations > 1:
             input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
             attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)
@@ -116,8 +115,9 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
         data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}

         if gt_answer is not None:
-            # repeat gt_answer for each prompt.
-            data["gt_answer"] = gt_answer.repeat_interleave(self.num_generations, dim=1)
+            data["gt_answer"] = gt_answer
+        if test_cases is not None:
+            data["test_cases"] = test_cases
         data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data

@@ -269,11 +269,11 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
         }

         data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}
-
-        if "gt_answer" in kwargs:
-            # repeat gt_answer for each prompt.
-            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
         data = {k: v.to(get_current_device()) for k, v in data.items()}
+        if "gt_answer" in kwargs:
+            data["gt_answer"] = kwargs["gt_answer"]
+        if "test_cases" in kwargs:
+            data["test_cases"] = kwargs["test_cases"]
         return data

     def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
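
Note: a standalone illustration of the new pass-through behaviour in the inference backends: gt_answer and test_cases arrive via the generate(...) kwargs and are echoed back on the output dict, untokenized and without repeat_interleave, so the producer can compute rewards from the raw values. generate_passthrough below is a mock written for this note, not the real backend method.

```python
import torch

def generate_passthrough(input_ids, attention_mask, **kwargs):
    # Mirrors the updated kwargs handling in the backends' generate().
    gt_answer = kwargs.pop("gt_answer", None)    # raw ground-truth answers
    test_cases = kwargs.pop("test_cases", None)  # raw code test cases
    data = {"input_ids": input_ids, "attention_mask": attention_mask}
    if gt_answer is not None:
        data["gt_answer"] = gt_answer            # echoed back unchanged
    if test_cases is not None:
        data["test_cases"] = test_cases
    return data

out = generate_passthrough(
    torch.randint(0, 100, (2, 16)),
    torch.ones(2, 16, dtype=torch.long),
    gt_answer=["42", "7"],
    test_cases=[["assert solve(1) == 2"], ["assert solve(3) == 4"]],
)
print(sorted(out.keys()))  # ['attention_mask', 'gt_answer', 'input_ids', 'test_cases']
```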

applications/ColossalChat/coati/distributed/launch.py
Lines changed: 1 addition & 4 deletions

@@ -37,7 +37,6 @@ def launch_distributed(
     train_batch_size: int,
     train_minibatch_size: int,
     train_dataset_config: Dict[str, Any],
-    dataloaders_config: Dict[str, Any],
     inference_model_config: Dict[str, Any],
     generate_config: Dict[str, Any],
     train_model_config: Dict[str, Any],
@@ -89,7 +88,6 @@ def launch_distributed(
            num_episodes=num_episodes,
            batch_size=inference_batch_size,
            train_dataset_config=train_dataset_config,
-           dataloaders_config=dataloaders_config,
            model_config=inference_model_config,
            generate_config=generate_config,
            tokenizer_config=tokenizer_config,
@@ -99,8 +97,7 @@ def launch_distributed(
            consumer_plugin_config=plugin_config,
            eval_dataset_config=eval_dataset_config,
            eval_interval=eval_interval,
-           evaluation_function_type=grpo_config["reward_fn_type"],
-           response_format_tags=grpo_config["response_format_tags"],
+           grpo_config=grpo_config,
            eval_save_dir=eval_save_dir,
            eval_generation_config=eval_generation_config,
            project_name=project_name,
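
Note: for callers, the visible change in launch.py is that dataloaders_config disappears and the individual reward-related arguments are replaced by forwarding the whole grpo_config dict. A hedged sketch of what that dict might carry, using only keys referenced elsewhere in this diff; the values are placeholders, not defaults.

```python
# Keys below appear in this PR (reward_fn_type, response_format_tags, dynamic_batching,
# soft_over_length_punishment); the values are illustrative.
grpo_config = {
    "reward_fn_type": "think_answer_tags",  # selects math_reward_fn vs boxed_math_reward_fn
    "response_format_tags": None,           # optional tags for the verifiable reward
    "dynamic_batching": True,               # enables effective-group filtering in consumer.loop
    "soft_over_length_punishment": False,   # reward kwarg previously unpacked in the consumer
}

# Before this change: launch_distributed(..., dataloaders_config=...,
#     evaluation_function_type=grpo_config["reward_fn_type"],
#     response_format_tags=grpo_config["response_format_tags"], ...)
# After this change:  launch_distributed(..., grpo_config=grpo_config, ...)
```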
