
Commit a246bf2

TongLi3701 (Tong Li) authored and committed
add overlength sample count (#6332)
Co-authored-by: Tong Li <[email protected]>
1 parent 6051001 commit a246bf2

File tree

1 file changed

+24 -2 lines changed

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 24 additions & 2 deletions
@@ -84,6 +84,9 @@ def __init__(
         self.project_name = project_name
         self.effective_sample_count = 0
         self.effective_prompt_count = 0
+        self.total_sample_count = 0
+        self.overlength_samples = 0
+        self.total_overlength_samples = 0
         self.project_name = project_name
         self.run_name = run_name
         self.wandb_group_name = wandb_group_name
@@ -207,11 +210,25 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
 
         # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
+            old_loss_mask = loss_mask.clone()
             loss_mask = torch.logical_and(
                 loss_mask,
                 action_mask[:, -1] == False,
             )
-        self.effective_prompt_count += group_reward.size(0) * self.dp_size
+
+            self.overlength_samples = (old_loss_mask & ~loss_mask).sum().item()
+            self.overlength_samples = all_reduce_sum(
+                torch.tensor(self.overlength_samples, device=loss_mask.device), self.plugin
+            )
+            self.total_overlength_samples += self.overlength_samples.item()
+
+        prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
+
+        # [minibatch_size] -> calculate the number of effective prompts
+        effective_prompts_mask = prompt_level_mask.any(dim=1)
+        effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
+        self.effective_prompt_count += effective_prompts.item()
+        excessive_prompts_idx = None
 
         mean_kl, mean_loss = [], []
 
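A note for readers skimming the hunk above: a sample counts as overlength when it was live in old_loss_mask but is zeroed by the truncation filter, i.e. old_loss_mask & ~loss_mask, and a prompt remains effective while any of its num_generations samples keeps a nonzero loss mask. A minimal single-process sketch of that logic (toy shapes, bool masks, and no all_reduce_sum; none of these values come from the source):

import torch

minibatch_size, num_generations, seq_len = 2, 3, 5

# Toy masks: loss_mask has one entry per sample; a True in the last column of
# action_mask means the response ran all the way to max_length (truncated).
loss_mask = torch.tensor([1, 1, 0, 1, 1, 1], dtype=torch.bool)
action_mask = torch.zeros(minibatch_size * num_generations, seq_len, dtype=torch.bool)
action_mask[0, -1] = True  # sample 0 hit max_length
action_mask[4, -1] = True  # sample 4 hit max_length

# Mirror the filtering in the diff: drop samples whose last action token is set.
old_loss_mask = loss_mask.clone()
loss_mask = torch.logical_and(loss_mask, action_mask[:, -1] == False)

overlength_samples = (old_loss_mask & ~loss_mask).sum().item()
print(overlength_samples)  # 2 -> samples 0 and 4 were dropped

# A prompt is effective if any of its generations survived the filter.
prompt_level_mask = loss_mask.view(minibatch_size, num_generations)
print(prompt_level_mask.any(dim=1).sum().item())  # 2 -> both prompts still count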
@@ -428,9 +445,12 @@ def _criterion(outputs, inputs):
             self.optimizer.step()
             self.optimizer.zero_grad()
             self.global_step += 1
-            sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
+            sample_utilization = self.effective_sample_count / self.total_sample_count
+            overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
+            self.total_sample_count = 0
+            self.total_overlength_samples = 0
             loss_scalar = self.accum_loss.item()
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
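Both ratios divide by the same running denominator, and all four counters are reset to zero right after they are read, so each optimizer step reports statistics for exactly the samples consumed since the previous one. A quick worked example with made-up numbers:

# Made-up counts for one optimizer step (illustrative only).
effective_sample_count = 96       # samples that kept a nonzero loss mask
total_sample_count = 128          # every sample consumed since the last step
total_overlength_samples = 8      # samples dropped by the truncation filter

sample_utilization = effective_sample_count / total_sample_count                # 0.75
overlength_samples_percentage = total_overlength_samples / total_sample_count  # 0.0625
print(sample_utilization, overlength_samples_percentage)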
@@ -458,6 +478,7 @@ def _criterion(outputs, inputs):
                     f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                     f"Response Length: {raw_batch_response_len_mean:.4f}",
                     f"Sample_utilization: {sample_utilization:.4f}",
+                    f"Percentage of overlength samples: {overlength_samples_percentage:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
@@ -469,6 +490,7 @@ def _criterion(outputs, inputs):
                     "train/advantages": self.accum_advantages.item() / self.accum_count,
                     "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                     "train/sample_utilization": sample_utilization,
+                    "train/percentage_overlength_samples": overlength_samples_percentage,
                     "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                 }
                 if self.policy_loss_fn.beta > 0:
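The new train/percentage_overlength_samples entry lands in the metrics dict next to train/sample_utilization. Assuming the consumer forwards that dict to Weights & Biases (the class carries project_name, run_name, and wandb_group_name, so wandb logging is the likely sink), the hand-off would look roughly like this sketch; the project, group, and run names below are placeholders:

import wandb

# Placeholder identifiers; the real consumer would use self.project_name,
# self.run_name, and self.wandb_group_name instead.
run = wandb.init(project="demo-project", group="demo-group", name="demo-run")
metrics = {
    "train/sample_utilization": 0.75,
    "train/percentage_overlength_samples": 0.0625,
}
run.log(metrics, step=1)  # forward the dict, keyed by training step
run.finish()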
