@@ -85,6 +85,8 @@ def __init__(
85
85
self .effective_sample_count = 0
86
86
self .effective_prompt_count = 0
87
87
self .total_sample_count = 0
88
+ self .overlength_samples = 0
89
+ self .total_overlength_samples = 0
88
90
self .project_name = project_name
89
91
self .run_name = run_name
90
92
self .wandb_group_name = wandb_group_name
@@ -227,10 +229,18 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
227
229
228
230
# filter out overlength samples
229
231
if self .filter_truncated_response and action_mask .size (1 ) == self .max_length :
232
+ old_loss_mask = loss_mask .clone ()
230
233
loss_mask = torch .logical_and (
231
234
loss_mask ,
232
235
action_mask [:, - 1 ] == False ,
233
236
)
237
+
238
+ self .overlength_samples = (old_loss_mask & ~ loss_mask ).sum ().item ()
239
+ self .overlength_samples = all_reduce_sum (
240
+ torch .tensor (self .overlength_samples , device = loss_mask .device ), self .plugin
241
+ )
242
+ self .total_overlength_samples += self .overlength_samples .item ()
243
+
234
244
prompt_level_mask = loss_mask .view (self .minibatch_size , self .num_generations )
235
245
236
246
# [minibatch_size] -> calculate the number of effective prompts
@@ -484,9 +494,11 @@ def _criterion(outputs, inputs):
484
494
self .optimizer .zero_grad ()
485
495
self .global_step += 1
486
496
sample_utilization = self .effective_sample_count / self .total_sample_count
497
+ overlength_samples_percentage = self .total_overlength_samples / self .total_sample_count
487
498
self .effective_prompt_count = 0
488
499
self .effective_sample_count = 0
489
500
self .total_sample_count = 0
501
+ self .total_overlength_samples = 0
490
502
loss_scalar = self .accum_loss .item ()
491
503
if not self .plugin .pp_size > 1 or (
492
504
self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self .tp_rank == 0
@@ -502,6 +514,7 @@ def _criterion(outputs, inputs):
502
514
f"Advantages: { self .accum_advantages .item () / self .accum_count :.4f} " ,
503
515
f"Response Length: { self .accum_response_length .item () / self .accum_count :.4f} " ,
504
516
f"Sample_utilization: { sample_utilization :.4f} " ,
517
+ f"Percentage of overlength samples: { overlength_samples_percentage :.4f} " ,
505
518
] + ([f"KL: { self .accum_kl .item () / self .accum_count :.4f} " ] if self .policy_loss_fn .beta > 0 else [])
506
519
print ("\n " .join (to_log_msg ))
507
520
metrics = {
@@ -513,6 +526,7 @@ def _criterion(outputs, inputs):
513
526
"train/advantages" : self .accum_advantages .item () / self .accum_count ,
514
527
"train/learning_rate" : self .lr_scheduler .get_last_lr ()[0 ],
515
528
"train/sample_utilization" : sample_utilization ,
529
+ "train/percentage_overlength_samples" : overlength_samples_percentage ,
516
530
"rollout/temperature" : data ["temperature" ].cpu ().numpy ()[0 ][0 ],
517
531
}
518
532
if self .policy_loss_fn .beta > 0 :
0 commit comments