Commit 9d9d516

Author: Tong Li (committed)
update grpo
1 parent eb6337f commit 9d9d516

File tree

1 file changed: +27 -1 lines changed


applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 27 additions & 1 deletion
@@ -56,6 +56,9 @@ def __init__(
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
+        self.accum_format_reward = torch.zeros(1, device=self.device)
+        self.accum_acc_reward = torch.zeros(1, device=self.device)
+        self.accum_advantages = torch.zeros(1, device=self.device)
         self.accum_count = 0

         # Reference model is initialized from policy model.
@@ -131,9 +134,14 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
             )
             kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)

-            reward = self.reward_model(
+            reward_group = self.reward_model(
                 data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
             )
+
+            reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
+            format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
+            acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
+
             # [batch_size, num_generations]
             group_reward = reward.view(-1, self.num_generations)

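The hunk above changes the reward handling: the reward model's return value is now treated as a collection of per-sample entries, and the value[0] / value[1] / value[2] indexing suggests each entry is a triple of (total reward, format reward, accuracy reward). Below is a minimal sketch of that unpacking with dummy triples standing in for the real reward model output; the helper name and the triple layout are assumptions for illustration, not part of this commit.

import torch

def split_reward_group(reward_group, device):
    # Assumed per-sample layout: (total_reward, format_reward, acc_reward),
    # mirroring the value[0] / value[1] / value[2] indexing in the diff above.
    reward = torch.tensor([value[0] for value in reward_group], device=device)
    format_reward = torch.tensor([value[1] for value in reward_group], device=device)
    acc_reward = torch.tensor([value[2] for value in reward_group], device=device)
    return reward, format_reward, acc_reward

# Dummy triples for two generations of one prompt.
reward, format_reward, acc_reward = split_reward_group(
    [(1.0, 0.5, 0.5), (0.5, 0.5, 0.0)], device="cpu"
)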
@@ -157,9 +165,16 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
             loss = all_reduce_mean(loss, self.plugin)
             reward = all_reduce_mean(reward.mean(), self.plugin)
             kl = all_reduce_mean(kl.mean(), self.plugin)
+            format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
+            acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
+            advantages = all_reduce_mean(advantages.mean(), self.plugin)
+            # Calculate accumulate value.
             self.accum_loss.add_(loss.data)
             self.accum_reward.add_(reward.data)
             self.accum_kl.add_(kl.data)
+            self.accum_format_reward.add_(format_reward.data)
+            self.accum_acc_reward.add_(acc_reward.data)
+            self.accum_advantages.add_(advantages.data)
             self.accum_count += 1
             if need_update:
                 self.optimizer.step()
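Each new metric follows the same pattern already used for loss, reward, and KL: take the mean over the local batch, average it across data-parallel workers with all_reduce_mean, and add the result to a running accumulator. The sketch below shows only the averaging step using plain torch.distributed; it is not ColossalAI's all_reduce_mean (which also takes the booster plugin), just an illustration of the assumed semantics.

import torch
import torch.distributed as dist

def all_reduce_mean_sketch(value: torch.Tensor) -> torch.Tensor:
    # Sum the scalar (e.g. reward.mean()) over all data-parallel ranks,
    # then divide by the world size to obtain the global mean.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        value = value / dist.get_world_size()
    return value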
@@ -173,17 +188,28 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                         self.accum_reward.item() / self.accum_count,
                         "KL:",
                         self.accum_kl.item() / self.accum_count,
+                        "Format Reward:",
+                        self.accum_format_reward.item() / self.accum_count,
+                        "Acc Reward:",
+                        self.accum_acc_reward.item() / self.accum_count,
+                        "Advantages:",
+                        self.accum_advantages.item() / self.accum_count,
                     )
                     self.wandb_run.log(
                         {
                             "train/loss": self.accum_loss.item() / self.accum_count,
                             "train/reward": self.accum_reward.item() / self.accum_count,
                             "train/kl": self.accum_kl.item() / self.accum_count,
+                            "train/format_reward": self.accum_format_reward.item() / self.accum_count,
+                            "train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                            "train/advantages": self.accum_advantages.item() / self.accum_count,
                         }
                     )
                 self.accum_loss.zero_()
                 self.accum_reward.zero_()
                 self.accum_kl.zero_()
+                self.accum_acc_reward.zero_()
+                self.accum_format_reward.zero_()
                 self.accum_count = 0
             return loss_scalar

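The logging hunk above reports running means: each accumulator holds the sum of the per-step, already all-reduced means, so dividing by accum_count at print/log time gives the average since the last optimizer update, after which the accumulators are reset. A small self-contained illustration of that bookkeeping, using a single metric and hypothetical names:

import torch

accum_reward = torch.zeros(1)
accum_count = 0

# Pretend these are the all-reduced mean rewards of three micro-batches.
for step_reward in (0.2, 0.4, 0.6):
    accum_reward.add_(torch.tensor([step_reward]))
    accum_count += 1

print("Reward:", accum_reward.item() / accum_count)  # averages to 0.4

# Reset after logging, as the diff does with .zero_() and accum_count = 0.
accum_reward.zero_()
accum_count = 0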