@@ -85,6 +85,8 @@ def __init__(
         self.effective_sample_count = 0
         self.effective_prompt_count = 0
         self.total_sample_count = 0
+        self.overlength_samples = 0
+        self.total_overlength_samples = 0
         self.project_name = project_name
         self.run_name = run_name
         self.wandb_group_name = wandb_group_name
@@ -227,10 +229,18 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:

         # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
+            old_loss_mask = loss_mask.clone()
             loss_mask = torch.logical_and(
                 loss_mask,
                 action_mask[:, -1] == False,
             )
+
+            self.overlength_samples = (old_loss_mask & ~loss_mask).sum().item()
+            self.overlength_samples = all_reduce_sum(
+                torch.tensor(self.overlength_samples, device=loss_mask.device), self.plugin
+            )
+            self.total_overlength_samples += self.overlength_samples.item()
+
         prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)

         # [minibatch_size] -> calculate the number of effective prompts
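A minimal, single-process sketch of the counting logic in this hunk; it omits the `all_reduce_sum` over data-parallel ranks, and the tensor shapes and values are invented for illustration:

```python
import torch

max_length = 4
# action_mask[i, -1] == True means sample i was still generating at max_length,
# i.e. its response was truncated rather than finished.
action_mask = torch.tensor(
    [
        [1, 1, 1, 1],  # truncated at max_length
        [1, 1, 0, 0],  # finished early
        [1, 1, 1, 0],  # finished early
    ],
    dtype=torch.bool,
)
loss_mask = torch.tensor([True, True, False])  # sample 2 already filtered out elsewhere

old_loss_mask = loss_mask.clone()
# Drop samples whose last action token is still active (truncated responses).
loss_mask = torch.logical_and(loss_mask, action_mask[:, -1] == False)

# Count only the samples this filter newly removed (kept before, dropped now).
overlength_samples = (old_loss_mask & ~loss_mask).sum().item()
print(loss_mask.tolist(), overlength_samples)  # [False, True, False] 1
```

In the hunk above, that per-step count is then summed across ranks with `all_reduce_sum` and accumulated into `total_overlength_samples`.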
@@ -484,9 +494,11 @@ def _criterion(outputs, inputs):
                 self.optimizer.zero_grad()
                 self.global_step += 1
                 sample_utilization = self.effective_sample_count / self.total_sample_count
+                overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
                 self.effective_prompt_count = 0
                 self.effective_sample_count = 0
                 self.total_sample_count = 0
+                self.total_overlength_samples = 0
                 loss_scalar = self.accum_loss.item()
                 if not self.plugin.pp_size > 1 or (
                     self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
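At each optimizer update the accumulated counters are turned into ratios and then reset. A toy arithmetic sketch (all counts invented); note that despite the name, `overlength_samples_percentage` is a fraction in [0, 1] and is logged with `:.4f` rather than scaled to 100:

```python
# Invented counts accumulated since the previous optimizer update.
total_sample_count = 256       # every generated sample seen in the accumulation window
effective_sample_count = 192   # samples that survived all filtering
total_overlength_samples = 16  # samples dropped solely for hitting max_length

sample_utilization = effective_sample_count / total_sample_count               # 0.75
overlength_samples_percentage = total_overlength_samples / total_sample_count  # 0.0625

# The counters are reset afterwards so the next window starts from zero.
total_sample_count = effective_sample_count = total_overlength_samples = 0
```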
@@ -502,6 +514,7 @@ def _criterion(outputs, inputs):
                     f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                     f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
                     f"Sample_utilization: {sample_utilization:.4f}",
+                    f"Percentage of overlength samples: {overlength_samples_percentage:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
@@ -513,6 +526,7 @@ def _criterion(outputs, inputs):
                     "train/advantages": self.accum_advantages.item() / self.accum_count,
                     "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                     "train/sample_utilization": sample_utilization,
+                    "train/percentage_overlength_samples": overlength_samples_percentage,
                     "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                 }
                 if self.policy_loss_fn.beta > 0:
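The `project_name`, `run_name`, and `wandb_group_name` fields set in `__init__` suggest this metrics dict is forwarded to Weights & Biases, though the logging call itself is outside this diff. A hedged, self-contained sketch of reporting the new key (project/run names and values are placeholders; `mode="offline"` keeps it runnable without a wandb account):

```python
import wandb

# Offline mode writes the run locally instead of requiring a login.
run = wandb.init(project="grpo-demo", name="overlength-metric-sketch", mode="offline")
metrics = {
    "train/sample_utilization": 0.75,
    "train/percentage_overlength_samples": 0.0625,  # fraction of samples dropped for hitting max_length
}
run.log(metrics, step=1)
run.finish()
```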