
Commit 941c767

TongLi3701 (Tong Li) authored and committed
add overlength sample count (#6332)
Co-authored-by: Tong Li <tong.li35271158@gmail.com>
1 parent 5c4c8a6 commit 941c767

File tree

1 file changed: +21 -5 lines changed


applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 21 additions & 5 deletions
@@ -85,6 +85,8 @@ def __init__(
         self.effective_sample_count = 0
         self.effective_prompt_count = 0
         self.total_sample_count = 0
+        self.overlength_samples = 0
+        self.total_overlength_samples = 0
         self.project_name = project_name
         self.run_name = run_name
         self.wandb_group_name = wandb_group_name
@@ -208,11 +210,25 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
 
         # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
+            old_loss_mask = loss_mask.clone()
             loss_mask = torch.logical_and(
                 loss_mask,
                 action_mask[:, -1] == False,
             )
-        self.effective_prompt_count += group_reward.size(0) * self.dp_size
+
+            self.overlength_samples = (old_loss_mask & ~loss_mask).sum().item()
+            self.overlength_samples = all_reduce_sum(
+                torch.tensor(self.overlength_samples, device=loss_mask.device), self.plugin
+            )
+            self.total_overlength_samples += self.overlength_samples.item()
+
+        prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
+
+        # [minibatch_size] -> calculate the number of effective prompts
+        effective_prompts_mask = prompt_level_mask.any(dim=1)
+        effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
+        self.effective_prompt_count += effective_prompts.item()
+        excessive_prompts_idx = None
 
         mean_kl, mean_loss = [], []
 
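For context, the overlength filter works by mask arithmetic: a response that fills the whole generation window (its last action-mask position is still set) is treated as truncated, its row in loss_mask is zeroed, and the rows that were active before but not after the filter are counted as overlength. Below is a minimal single-process sketch of that logic; the all_reduce_sum across data-parallel ranks is omitted, and the shapes and mask values are illustrative, not taken from the repository.

import torch

# Illustrative shapes: 2 prompts, 3 generations each, generation window of 5 tokens.
minibatch_size, num_generations = 2, 3

action_mask = torch.tensor([
    [1, 1, 1, 0, 0],  # stopped early -> kept
    [1, 1, 1, 1, 1],  # filled the window -> treated as truncated
    [1, 1, 0, 0, 0],
    [1, 1, 1, 1, 1],  # truncated
    [1, 1, 1, 1, 0],
    [1, 1, 1, 1, 1],  # truncated
], dtype=torch.bool)

loss_mask = torch.ones(minibatch_size * num_generations, dtype=torch.bool)

old_loss_mask = loss_mask.clone()
# A sample counts as truncated iff its final action position is still unmasked.
loss_mask = torch.logical_and(loss_mask, action_mask[:, -1] == False)

# Rows active before the filter but masked out after it are the overlength samples.
overlength_samples = (old_loss_mask & ~loss_mask).sum().item()

# A prompt stays "effective" if at least one of its generations survived the filter.
prompt_level_mask = loss_mask.view(minibatch_size, num_generations)
effective_prompts = prompt_level_mask.any(dim=1).sum().item()

print(overlength_samples, effective_prompts)  # 3 2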
@@ -432,9 +448,11 @@ def _criterion(outputs, inputs):
             self.global_step += 1
             # no need to run all_reduce_sum on total_sample_count, because all dp ranks recieves a complete inference batch from all producers.
             sample_utilization = self.effective_sample_count / self.total_sample_count
+            overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
             self.total_sample_count = 0
+            self.total_overlength_samples = 0
             loss_scalar = self.accum_loss.item()
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
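Despite the name, overlength_samples_percentage is a fraction in [0, 1], computed the same way as sample_utilization: a running counter divided by the total number of samples seen since the last optimizer step, with the counters reset afterwards. A minimal sketch of the bookkeeping, with made-up counts:

# Made-up counts accumulated since the last optimizer step.
effective_sample_count = 48
total_overlength_samples = 6
total_sample_count = 64

sample_utilization = effective_sample_count / total_sample_count                 # 0.75
overlength_samples_percentage = total_overlength_samples / total_sample_count    # 0.09375

# The running counters are zeroed afterwards so each logging window starts fresh.
effective_sample_count = total_overlength_samples = total_sample_count = 0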
@@ -462,6 +480,7 @@ def _criterion(outputs, inputs):
                     f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                     f"Response Length: {raw_batch_response_len_mean:.4f}",
                     f"Sample_utilization: {sample_utilization:.4f}",
+                    f"Percentage of overlength samples: {overlength_samples_percentage:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
@@ -473,6 +492,7 @@ def _criterion(outputs, inputs):
                     "train/advantages": self.accum_advantages.item() / self.accum_count,
                     "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                     "train/sample_utilization": sample_utilization,
+                    "train/percentage_overlength_samples": overlength_samples_percentage,
                     "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                 }
                 if self.policy_loss_fn.beta > 0:
@@ -483,16 +503,12 @@ def _criterion(outputs, inputs):
                 self.accum_kl.zero_()
                 self.accum_advantages.zero_()
                 self.accum_count = 0
-<<<<<<< HEAD
-            return loss_scalar
-=======
 
             if excessive_prompts_idx is not None:
                 # All gather excessive prompts index across DP ranks.
                 excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
                 excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
             return loss_scalar, excessive_prompts_idx
->>>>>>> 3c42c0ce (Merge pull request #6309 from hpcaitech/grpo-eval-dev)
         else:
             return None
 
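The last hunk also removes leftover merge-conflict markers, keeping the branch that returns excessive prompt indices. Each rank's indices are local to its own minibatch, so they are shifted by dp_rank * minibatch_size to become globally unique before all_gather_tensors collects them across data-parallel ranks. A single-process sketch of the index arithmetic, with the distributed gather simulated by a plain loop over hypothetical ranks:

# Hypothetical setup: 2 data-parallel ranks, each holding a minibatch of 4 prompts.
minibatch_size = 4
local_excessive = {0: [1, 3], 1: [0, 2]}  # rank -> local prompt indices

# Each rank offsets its local indices into the global prompt numbering before the
# gather, mirroring `idx + self.dp_rank * self.minibatch_size` in the diff.
gathered = []
for dp_rank, idx_list in local_excessive.items():
    gathered.extend(idx + dp_rank * minibatch_size for idx in idx_list)

print(gathered)  # [1, 3, 4, 6]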