
Commit 47d6493

Authored and committed by Tong Li

add response length

1 parent abca66e · commit 47d6493

File tree

1 file changed (+19 −9 lines)


applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 19 additions & 9 deletions
@@ -59,6 +59,7 @@ def __init__(
         self.accum_format_reward = torch.zeros(1, device=self.device)
         self.accum_acc_reward = torch.zeros(1, device=self.device)
         self.accum_advantages = torch.zeros(1, device=self.device)
+        self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0

         # Reference model is initialized from policy model.
@@ -83,7 +84,7 @@ def __init__(
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
         if use_wandb and self.rank == 0:
-            self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True)
+            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)

     def setup(self):
         super().setup()
@@ -109,6 +110,7 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
         action_mask = data["action_mask"]
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
+        response_length = torch.sum(action_mask, dim=1).to(torch.float32)

         need_update = (step_idx + 1) % self.num_microbatches == 0
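
The added `response_length` line counts how many tokens each response contains: `action_mask` is, by all appearances, a 0/1 mask over the generated positions, so summing it along `dim=1` gives a per-sample token count, and the cast to float32 makes it averageable for the later all-reduce. A standalone sketch with made-up values:

# Standalone sketch (values are illustrative): summing a 0/1 action mask
# along the sequence dimension yields each sample's response length.
import torch

action_mask = torch.tensor(
    [[1, 1, 1, 0, 0],   # 3 generated tokens, 2 masked/padded positions
     [1, 1, 1, 1, 1]],  # 5 generated tokens
)
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
print(response_length)         # tensor([3., 5.])
print(response_length.mean())  # tensor(4.) -- the per-batch value that is logged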

@@ -168,13 +170,15 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
             format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
             acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
             advantages = all_reduce_mean(advantages.mean(), self.plugin)
+            response_length = all_reduce_mean(response_length.mean(), self.plugin)
             # Calculate accumulate value.
             self.accum_loss.add_(loss.data)
             self.accum_reward.add_(reward.data)
             self.accum_kl.add_(kl.data)
             self.accum_format_reward.add_(format_reward.data)
             self.accum_acc_reward.add_(acc_reward.data)
             self.accum_advantages.add_(advantages.data)
+            self.accum_response_length.add_(response_length.data)
             self.accum_count += 1
         if need_update:
             self.optimizer.step()
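
`all_reduce_mean(..., self.plugin)` averages each scalar metric across the data-parallel ranks before it is accumulated, so every worker ends up logging the same global value. The helper's internals are not shown in this diff; the sketch below only illustrates the assumed behavior with plain `torch.distributed` and is not the ColossalAI implementation.

# Sketch of what a mean all-reduce does, using plain torch.distributed
# (assumption: this mirrors the behavior of ColossalAI's all_reduce_mean,
# whose implementation is not part of this diff). Requires an initialized
# process group.
import torch
import torch.distributed as dist

def all_reduce_mean_sketch(x: torch.Tensor) -> torch.Tensor:
    dist.all_reduce(x, op=dist.ReduceOp.SUM)  # sum the value over all ranks
    return x / dist.get_world_size()          # divide by world size -> global mean

# Illustrative usage: response_length.mean() is the local batch average;
# after the reduction every rank holds the same cross-rank average.
# response_length = all_reduce_mean_sketch(response_length.mean())
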
@@ -184,32 +188,38 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                 print(
                     "Loss:",
                     self.accum_loss.item() / self.accum_count,
-                    "Reward:",
+                    "\nReward:",
                     self.accum_reward.item() / self.accum_count,
-                    "KL:",
-                    self.accum_kl.item() / self.accum_count,
-                    "Format Reward:",
+                    "\nFormat Reward:",
                     self.accum_format_reward.item() / self.accum_count,
-                    "Acc Reward:",
+                    "\nAcc Reward:",
                     self.accum_acc_reward.item() / self.accum_count,
-                    "Advantages:",
+                    "\nKL:",
+                    self.accum_kl.item() / self.accum_count,
+                    "\nAdvantages:",
                     self.accum_advantages.item() / self.accum_count,
+                    "\nResponse Length:",
+                    self.accum_response_length.item() / self.accum_count,
                 )
                 self.wandb_run.log(
                     {
                         "train/loss": self.accum_loss.item() / self.accum_count,
                         "train/reward": self.accum_reward.item() / self.accum_count,
-                        "train/kl": self.accum_kl.item() / self.accum_count,
                         "train/format_reward": self.accum_format_reward.item() / self.accum_count,
                         "train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                        "train/kl": self.accum_kl.item() / self.accum_count,
                         "train/advantages": self.accum_advantages.item() / self.accum_count,
+                        "train/response_length": self.accum_response_length.item() / self.accum_count,
                     }
                 )
             self.accum_loss.zero_()
             self.accum_reward.zero_()
-            self.accum_kl.zero_()
             self.accum_acc_reward.zero_()
             self.accum_format_reward.zero_()
+            self.accum_kl.zero_()
+            self.accum_advantages.zero_()
+            self.accum_response_length.zero_()
+
             self.accum_count = 0
             return loss_scalar
