
Commit c8b368c

TongLi3701 and Tong Li authored

add overlength sample count (#6332)

Co-authored-by: Tong Li <[email protected]>
1 parent de2ad3b commit c8b368c

File tree

1 file changed: +14 −0 lines changed


applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 14 additions & 0 deletions
@@ -85,6 +85,8 @@ def __init__(
         self.effective_sample_count = 0
         self.effective_prompt_count = 0
         self.total_sample_count = 0
+        self.overlength_samples = 0
+        self.total_overlength_samples = 0
         self.project_name = project_name
         self.run_name = run_name
         self.wandb_group_name = wandb_group_name
@@ -227,10 +229,18 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
 
         # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
+            old_loss_mask = loss_mask.clone()
             loss_mask = torch.logical_and(
                 loss_mask,
                 action_mask[:, -1] == False,
             )
+
+            self.overlength_samples = (old_loss_mask & ~loss_mask).sum().item()
+            self.overlength_samples = all_reduce_sum(
+                torch.tensor(self.overlength_samples, device=loss_mask.device), self.plugin
+            )
+            self.total_overlength_samples += self.overlength_samples.item()
+
         prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
 
         # [minibatch_size] -> calculate the number of effective prompts
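For readers skimming the diff: after the truncation filter runs, the expression (old_loss_mask & ~loss_mask) marks exactly the samples that were newly dropped for reaching max_length, and that count is what gets all-reduced and accumulated. Below is a minimal, self-contained sketch of that counting on toy tensors; it is single-process, so the all_reduce_sum / plugin step from the commit is omitted.

import torch

# Toy setup: 4 sampled responses, all padded/truncated to max_length = 5 tokens.
action_mask = torch.tensor(
    [
        [1, 1, 1, 0, 0],  # stopped early -> not truncated
        [1, 1, 1, 1, 1],  # filled max_length -> truncated (overlength)
        [1, 1, 0, 0, 0],  # stopped early -> not truncated
        [1, 1, 1, 1, 1],  # filled max_length -> truncated (overlength)
    ],
    dtype=torch.bool,
)
loss_mask = torch.tensor([True, True, False, True])  # one sample already filtered elsewhere

# Same rule as the commit: keep only samples whose last action position is masked,
# i.e. responses that produced an end-of-sequence token before max_length.
old_loss_mask = loss_mask.clone()
loss_mask = torch.logical_and(loss_mask, action_mask[:, -1] == False)

# Samples that were kept before but are dropped now are the overlength ones.
overlength_samples = (old_loss_mask & ~loss_mask).sum().item()
print(loss_mask)           # tensor([ True, False, False, False])
print(overlength_samples)  # 2

In the distributed path, the commit wraps this count in a tensor and sums it across ranks via all_reduce_sum and the booster plugin before adding it to total_overlength_samples.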
@@ -484,9 +494,11 @@ def _criterion(outputs, inputs):
             self.optimizer.zero_grad()
             self.global_step += 1
             sample_utilization = self.effective_sample_count / self.total_sample_count
+            overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
             self.total_sample_count = 0
+            self.total_overlength_samples = 0
             loss_scalar = self.accum_loss.item()
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
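This hunk turns the accumulated counter into a per-update ratio and then resets it, so the logged value always covers the samples consumed since the previous optimizer step. A small illustration with made-up counts (the numbers here are hypothetical, not from the commit):

# Hypothetical counts accumulated across the micro-batches of one optimizer step.
total_sample_count = 512        # all generated samples seen since the last update
total_overlength_samples = 37   # samples dropped for reaching max_length

overlength_samples_percentage = total_overlength_samples / total_sample_count
print(f"Percentage of overlength samples: {overlength_samples_percentage:.4f}")  # 0.0723

# Both counters are zeroed after logging so the next update starts from scratch.
total_sample_count = 0
total_overlength_samples = 0

Note that despite the name, the logged quantity is a fraction in [0, 1] rather than a value scaled by 100.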
@@ -502,6 +514,7 @@ def _criterion(outputs, inputs):
                     f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                     f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
                     f"Sample_utilization: {sample_utilization:.4f}",
+                    f"Percentage of overlength samples: {overlength_samples_percentage:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
@@ -513,6 +526,7 @@ def _criterion(outputs, inputs):
                     "train/advantages": self.accum_advantages.item() / self.accum_count,
                     "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                     "train/sample_utilization": sample_utilization,
+                    "train/percentage_overlength_samples": overlength_samples_percentage,
                     "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                 }
                 if self.policy_loss_fn.beta > 0:

0 commit comments
