@@ -85,6 +85,8 @@ def __init__(
         self.effective_sample_count = 0
         self.effective_prompt_count = 0
         self.total_sample_count = 0
+        self.overlength_samples = 0
+        self.total_overlength_samples = 0
         self.project_name = project_name
         self.run_name = run_name
         self.wandb_group_name = wandb_group_name
@@ -208,11 +210,25 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
 
         # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
+            old_loss_mask = loss_mask.clone()
             loss_mask = torch.logical_and(
                 loss_mask,
                 action_mask[:, -1] == False,
             )
-        self.effective_prompt_count += group_reward.size(0) * self.dp_size
+
+            self.overlength_samples = (old_loss_mask & ~loss_mask).sum().item()
+            self.overlength_samples = all_reduce_sum(
+                torch.tensor(self.overlength_samples, device=loss_mask.device), self.plugin
+            )
+            self.total_overlength_samples += self.overlength_samples.item()
+
+        prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
+
+        # [minibatch_size] -> calculate the number of effective prompts
+        effective_prompts_mask = prompt_level_mask.any(dim=1)
+        effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
+        self.effective_prompt_count += effective_prompts.item()
+        excessive_prompts_idx = None
 
         mean_kl, mean_loss = [], []
 
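Note: the overlength filter and the prompt-level effectiveness check above reduce to a few tensor operations. A minimal single-process sketch (toy shapes, hypothetical values, all_reduce_sum omitted):

import torch

minibatch_size, num_generations = 2, 2                      # hypothetical shapes for illustration
loss_mask = torch.tensor([True, True, True, True])
# last_token_active stands in for action_mask[:, -1]; True marks a response cut off at max_length
last_token_active = torch.tensor([False, True, False, False])

old_loss_mask = loss_mask.clone()
loss_mask = torch.logical_and(loss_mask, last_token_active == False)

overlength_samples = (old_loss_mask & ~loss_mask).sum().item()      # -> 1 newly masked sample

prompt_level_mask = loss_mask.view(minibatch_size, num_generations)
effective_prompts = prompt_level_mask.any(dim=1).sum().item()       # -> 2, each prompt keeps >= 1 sample

print(overlength_samples, effective_prompts)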
@@ -432,9 +448,11 @@ def _criterion(outputs, inputs):
             self.global_step += 1
             # no need to run all_reduce_sum on total_sample_count, because all dp ranks receive a complete inference batch from all producers.
             sample_utilization = self.effective_sample_count / self.total_sample_count
+            overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
             self.total_sample_count = 0
+            self.total_overlength_samples = 0
             loss_scalar = self.accum_loss.item()
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
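Both ratios above are plain fractions of the per-update totals, reset after each optimizer step. A rough sketch with made-up counts:

effective_sample_count = 192
total_overlength_samples = 16
total_sample_count = 256

sample_utilization = effective_sample_count / total_sample_count                # 0.75
overlength_samples_percentage = total_overlength_samples / total_sample_count   # 0.0625, a fraction in [0, 1] despite the name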
@@ -462,6 +480,7 @@ def _criterion(outputs, inputs):
462480 f"Advantages: { self .accum_advantages .item () / self .accum_count :.4f} " ,
463481 f"Response Length: { raw_batch_response_len_mean :.4f} " ,
464482 f"Sample_utilization: { sample_utilization :.4f} " ,
483+ f"Percentage of overlength samples: { overlength_samples_percentage :.4f} " ,
465484 ] + ([f"KL: { self .accum_kl .item () / self .accum_count :.4f} " ] if self .policy_loss_fn .beta > 0 else [])
466485 print ("\n " .join (to_log_msg ))
467486 metrics = {
@@ -473,6 +492,7 @@ def _criterion(outputs, inputs):
473492 "train/advantages" : self .accum_advantages .item () / self .accum_count ,
474493 "train/learning_rate" : self .lr_scheduler .get_last_lr ()[0 ],
475494 "train/sample_utilization" : sample_utilization ,
495+ "train/percentage_overlength_samples" : overlength_samples_percentage ,
476496 "rollout/temperature" : data ["temperature" ].cpu ().numpy ()[0 ][0 ],
477497 }
478498 if self .policy_loss_fn .beta > 0 :
@@ -483,16 +503,12 @@ def _criterion(outputs, inputs):
                 self.accum_kl.zero_()
                 self.accum_advantages.zero_()
                 self.accum_count = 0
-<<<<<<< HEAD
-            return loss_scalar
-=======
 
             if excessive_prompts_idx is not None:
                 # All gather excessive prompts index across DP ranks.
                 excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
                 excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
                 return loss_scalar, excessive_prompts_idx
->>>>>>> 3c42c0ce (Merge pull request #6309 from hpcaitech/grpo-eval-dev)
             else:
                 return None
 
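The index remapping before the all-gather shifts each DP rank's local prompt indices into the global batch, so the gathered list is unambiguous. A small sketch assuming 2 DP ranks and minibatch_size = 4, with all_gather_tensors replaced by a plain per-rank list for illustration:

minibatch_size = 4
local_excessive = {0: [1, 3], 1: [0]}   # hypothetical per-rank local prompt indices

global_per_rank = [
    [idx + dp_rank * minibatch_size for idx in local_excessive[dp_rank]]
    for dp_rank in (0, 1)
]
# rank 0 -> [1, 3], rank 1 -> [4]; global view of the excessive prompts: [1, 3, 4]
print(global_per_rank)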