
Commit a9cf3aa

Merge branch 'grpo-latest' into grpo-latest-npu
2 parents 2305f93 + 137ec17

File tree: 6 files changed (+79, -28 lines)

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 2 additions & 0 deletions
@@ -74,6 +74,8 @@ def setup(self) -> None:
         )
         if plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
             plugin_config["microbatch_size"] = self.microbatch_size
+        if self.plugin_config.get("tp_size", 1) > 1:
+            plugin_config["parallel_output"] = False
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
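
The two added lines force `parallel_output=False` whenever tensor parallelism is enabled, presumably so the HybridParallelPlugin hands back gathered full-vocabulary logits rather than TP-sharded ones. A minimal sketch of the config-merge logic above, using a plain dict; the default values and the helper name `build_plugin_config` are assumptions, not the consumer's actual attributes:

```python
# Sketch only: mirrors the merge order above (defaults -> derived flags -> user overrides).
def build_plugin_config(user_config: dict, microbatch_size: int) -> dict:
    base_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)  # assumed defaults
    if user_config.get("pp_size", 1) > 1 and "num_microbatches" not in user_config:
        base_config["microbatch_size"] = microbatch_size
    if user_config.get("tp_size", 1) > 1:
        # tensor parallelism enabled: request gathered (non-sharded) logits
        base_config["parallel_output"] = False
    base_config.update(user_config)  # user-supplied plugin_config wins on conflicts
    return base_config

print(build_plugin_config({"tp_size": 2}, microbatch_size=4))
# {'tp_size': 2, 'pp_size': 1, 'precision': 'bf16', 'zero_stage': 2, 'parallel_output': False}
```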

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 40 additions & 4 deletions
@@ -56,6 +56,10 @@ def __init__(
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
+        self.accum_format_reward = torch.zeros(1, device=self.device)
+        self.accum_acc_reward = torch.zeros(1, device=self.device)
+        self.accum_advantages = torch.zeros(1, device=self.device)
+        self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0

         # Reference model is initialized from policy model.
@@ -80,7 +84,7 @@ def __init__(
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
         if use_wandb and self.rank == 0:
-            self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True)
+            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)

     def setup(self):
         super().setup()
@@ -106,6 +110,7 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
         action_mask = data["action_mask"]
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
+        response_length = torch.sum(action_mask, dim=1).to(torch.float32)

         need_update = (step_idx + 1) % self.num_microbatches == 0

@@ -133,9 +138,14 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
             )
             kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)

-            reward = self.reward_model(
+            reward_group = self.reward_model(
                 data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
             )
+
+            reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
+            format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
+            acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
+
             # [batch_size, num_generations]
             group_reward = reward.view(-1, self.num_generations)

@@ -159,9 +169,18 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
             loss = all_reduce_mean(loss, self.plugin)
             reward = all_reduce_mean(reward.mean(), self.plugin)
             kl = all_reduce_mean(kl.mean(), self.plugin)
+            format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
+            acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
+            advantages = all_reduce_mean(advantages.mean(), self.plugin)
+            response_length = all_reduce_mean(response_length.mean(), self.plugin)
+            # Calculate accumulate value.
             self.accum_loss.add_(loss.data)
             self.accum_reward.add_(reward.data)
             self.accum_kl.add_(kl.data)
+            self.accum_format_reward.add_(format_reward.data)
+            self.accum_acc_reward.add_(acc_reward.data)
+            self.accum_advantages.add_(advantages.data)
+            self.accum_response_length.add_(response_length.data)
             self.accum_count += 1
             if need_update:
                 self.optimizer.step()
@@ -171,21 +190,38 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                    print(
                        "Loss:",
                        self.accum_loss.item() / self.accum_count,
-                        "Reward:",
+                        "\nReward:",
                        self.accum_reward.item() / self.accum_count,
-                        "KL:",
+                        "\nFormat Reward:",
+                        self.accum_format_reward.item() / self.accum_count,
+                        "\nAcc Reward:",
+                        self.accum_acc_reward.item() / self.accum_count,
+                        "\nKL:",
                        self.accum_kl.item() / self.accum_count,
+                        "\nAdvantages:",
+                        self.accum_advantages.item() / self.accum_count,
+                        "\nResponse Length:",
+                        self.accum_response_length.item() / self.accum_count,
                    )
                    self.wandb_run.log(
                        {
                            "train/loss": self.accum_loss.item() / self.accum_count,
                            "train/reward": self.accum_reward.item() / self.accum_count,
+                            "train/format_reward": self.accum_format_reward.item() / self.accum_count,
+                            "train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
                            "train/kl": self.accum_kl.item() / self.accum_count,
+                            "train/advantages": self.accum_advantages.item() / self.accum_count,
+                            "train/response_length": self.accum_response_length.item() / self.accum_count,
                        }
                    )
                self.accum_loss.zero_()
                self.accum_reward.zero_()
+                self.accum_acc_reward.zero_()
+                self.accum_format_reward.zero_()
                self.accum_kl.zero_()
+                self.accum_advantages.zero_()
+                self.accum_response_length.zero_()
+
                self.accum_count = 0
            return loss_scalar

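The new format-reward, accuracy-reward, advantage, and response-length metrics follow the same accumulate-then-average pattern as the existing loss/reward/KL tracking: each microbatch contributes an all-reduced mean, the running sums are divided by `accum_count` when the optimizer steps, and the buffers are then zeroed. A standalone sketch of that pattern (hypothetical `MetricAccumulator` helper; the cross-rank all-reduce is omitted):

```python
import torch

# Illustrative only: the real code keeps these buffers directly on the consumer object.
class MetricAccumulator:
    def __init__(self, names, device="cpu"):
        self.buffers = {name: torch.zeros(1, device=device) for name in names}
        self.count = 0

    def add(self, **scalars):
        # accumulate the per-microbatch mean of each metric
        for name, value in scalars.items():
            self.buffers[name].add_(value.detach().mean())
        self.count += 1

    def flush(self):
        # average over the accumulated microbatches, then reset
        averaged = {name: buf.item() / self.count for name, buf in self.buffers.items()}
        for buf in self.buffers.values():
            buf.zero_()
        self.count = 0
        return averaged

metrics = MetricAccumulator(["reward", "format_reward", "acc_reward", "response_length"])
metrics.add(
    reward=torch.tensor([1.0, 10.0]),
    format_reward=torch.tensor([1.0, 1.0]),
    acc_reward=torch.tensor([0.0, 9.0]),
    response_length=torch.tensor([120.0, 87.0]),
)
print(metrics.flush())  # {'reward': 5.5, 'format_reward': 1.0, 'acc_reward': 4.5, 'response_length': 103.5}
```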

applications/ColossalChat/coati/distributed/loss.py

Lines changed: 3 additions & 8 deletions
@@ -26,15 +26,10 @@ def forward(
     ) -> torch.Tensor:
         skip = False
         if action_mask is None:
-            ratio_ = (log_probs - old_log_probs).exp()
+            ratio = (log_probs - log_probs.detach()).exp()
         else:
-            ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
+            ratio = ((log_probs - log_probs.detach()) * action_mask).exp()

-        # note that if dropout is disabled (recommanded), ratio will always be 1.
-        if ratio_.mean() > self.skip_threshold:
-            skip = True
-
-        ratio = ratio_.clamp(0.0, 10.0)
         surr1 = ratio * advantages
         surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
         loss = -torch.min(surr1, surr2) + self.beta * per_token_kl
@@ -44,4 +39,4 @@ def forward(
         else:
             loss = loss.mean(dim=1)
         loss = loss.mean()
-        return loss, skip, ratio_.max()
+        return loss, skip, ratio.max()
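
After this change the ratio is taken between `log_probs` and its own detached copy, so it is numerically 1 in the forward pass while still carrying the policy gradient, and the skip-threshold/clamp bookkeeping is dropped. A minimal sketch of the clipped-surrogate-plus-KL objective the forward now computes; the `clip_eps` and `beta` defaults are illustrative, and the masked mean is one reasonable reduction, not necessarily the file's exact one:

```python
import torch

def grpo_policy_loss(log_probs, advantages, per_token_kl, action_mask=None, clip_eps=0.2, beta=0.01):
    # exp(log_probs - log_probs.detach()) == 1 numerically, but keeps log_probs in the
    # autograd graph, so the gradient matches the usual policy-gradient estimator.
    ratio = (log_probs - log_probs.detach()).exp()
    surr1 = ratio * advantages
    surr2 = ratio.clamp(1 - clip_eps, 1 + clip_eps) * advantages
    loss = -torch.min(surr1, surr2) + beta * per_token_kl
    if action_mask is not None:
        # average over valid (non-padded) action tokens of each sequence
        loss = (loss * action_mask).sum(dim=1) / action_mask.sum(dim=1)
    else:
        loss = loss.mean(dim=1)
    return loss.mean()

B, T = 2, 5
log_probs = torch.randn(B, T, requires_grad=True)
advantages = torch.randn(B, 1).expand(B, T)
per_token_kl = torch.rand(B, T)
action_mask = torch.ones(B, T)
grpo_policy_loss(log_probs, advantages, per_token_kl, action_mask).backward()
```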

applications/ColossalChat/coati/distributed/reward/reward_fn.py

Lines changed: 20 additions & 8 deletions
@@ -4,8 +4,12 @@


 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
+    format_score = 1.0
+    acc_score = 9.0
     tokenizer = kwargs["tokenizer"]
-    reward = torch.tensor(0.0).to(input_ids.device)
+    reward = torch.tensor(0.0)
+    format_reward = torch.tensor(0.0)
+    acc_reward = torch.tensor(0.0)
     s, e = response_idx[0], response_idx[1]
     if gt_answer is None:
         return reward
@@ -15,13 +19,21 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     final_answer, processed_str = extract_solution(decoded_final_answer)

     format_valid = validate_response_structure(processed_str, kwargs["tags"])
-    if not format_valid:
-        return reward
-    else:
-        reward += 1.0
-        if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
-            reward = reward + 2.0
-        return reward
+
+    # Check format accuracy
+    if format_valid:
+        format_reward += format_score
+        reward += format_score
+
+    # Check answer accuracy
+    if (
+        final_answer is not None
+        and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
+    ):
+        acc_reward += acc_score
+        reward += acc_score
+
+    return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device)


 def gsm8k_reward_fn(input_ids, **kwargs):
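
With `format_score = 1.0` and `acc_score = 9.0`, the function now scores response format and answer correctness independently (rather than returning early on a bad format) and returns `[total, format_reward, acc_reward]` instead of a single scalar. A toy sketch of just the composition step, with the decoding and validation helpers left out:

```python
import torch

# format_valid / answer_correct stand in for validate_response_structure(...) and
# the string comparison against gt_answer in the function above.
def compose_reward(format_valid: bool, answer_correct: bool) -> torch.Tensor:
    format_score, acc_score = 1.0, 9.0
    format_reward = format_score if format_valid else 0.0
    acc_reward = acc_score if answer_correct else 0.0
    return torch.tensor([format_reward + acc_reward, format_reward, acc_reward])

print(compose_reward(True, True))    # tensor([10.,  1.,  9.])
print(compose_reward(True, False))   # tensor([1., 1., 0.])
print(compose_reward(False, False))  # tensor([0., 0., 0.])
```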

applications/ColossalChat/coati/distributed/reward/verifiable_reward.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ def __call__(
         # Get batch size
         bs = input_ids.size(0)
         # Initialize reward
-        rewards = torch.zeros(bs, device=input_ids.device)
+        rewards = torch.zeros((bs, 3), device=input_ids.device)

         # Loop through reward functions
         for reward_fn in self.reward_fns:
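
The buffer is widened from shape `(bs,)` to `(bs, 3)` so each sample keeps the three components (total, format, accuracy) now returned by `math_reward_fn`. A hedged sketch of how such a buffer could be filled; the actual loop body is not shown in this hunk, and the function name below is illustrative:

```python
import torch

def batch_rewards(reward_fns, input_ids, **kwargs):
    # one row of [total, format, acc] per sample, summed over all reward functions
    bs = input_ids.size(0)
    rewards = torch.zeros((bs, 3), device=input_ids.device)
    for reward_fn in reward_fns:
        for i in range(bs):
            # slice the per-sample kwargs (e.g. gt_answer[i], response_idx[i])
            rewards[i] += reward_fn(input_ids[i], **{k: v[i] for k, v in kwargs.items()})
    return rewards
```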

colossalai/shardformer/policies/qwen2.py

Lines changed: 13 additions & 7 deletions
@@ -11,6 +11,7 @@
     Linear1D_Row,
     LinearWithGradAccum,
     PaddingEmbedding,
+    PaddingLMHead,
     RMSNorm,
     VocabParallelEmbedding1D,
     VocabParallelLMHead1D,
@@ -449,13 +450,18 @@ def module_policy(self):
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
                         suffix="lm_head",
-                        target_module=LinearWithGradAccum,
-                        kwargs=dict(
-                            gather_output=not self.shard_config.parallel_output,
-                            fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=use_zbv,
-                        ),
-                    )
+                        target_module=PaddingLMHead,
+                        kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="lm_head",
+                        target_module=VocabParallelLMHead1D,
+                        kwargs={
+                            "gather_output": not self.shard_config.parallel_output,
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
+                    ),
                 ],
                 method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
             )
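
The single `LinearWithGradAccum` replacement for `lm_head` is swapped for a `PaddingLMHead` description plus a `VocabParallelLMHead1D` description. As an interpretation (not a statement about shardformer's dispatch mechanism): the padded head covers the non-tensor-parallel case and the vocab-parallel head covers tensor parallelism, roughly as sketched with this hypothetical helper:

```python
# Conceptual sketch only; PaddingLMHead / VocabParallelLMHead1D are referenced by name.
def pick_lm_head_module(tp_size: int) -> str:
    if tp_size > 1:
        return "VocabParallelLMHead1D"  # vocab dimension sharded column-wise across TP ranks
    return "PaddingLMHead"              # full vocab, padded so its size divides evenly

assert pick_lm_head_module(1) == "PaddingLMHead"
assert pick_lm_head_module(4) == "VocabParallelLMHead1D"
```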
