@@ -84,6 +84,12 @@ def __init__(
84
84
self .project_name = project_name
85
85
self .effective_sample_count = 0
86
86
self .effective_prompt_count = 0
87
+ < << << << HEAD
88
+ == == == =
89
+ self .total_sample_count = 0
90
+ self .overlength_samples = 0
91
+ self .total_overlength_samples = 0
92
+ >> >> >> > c8b368c2 (add overlength sample count (#6332))
87
93
self .project_name = project_name
88
94
self .run_name = run_name
89
95
self .wandb_group_name = wandb_group_name
@@ -207,11 +213,25 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
207
213
208
214
# filter out overlength samples
209
215
if self .filter_truncated_response and action_mask .size (1 ) == self .max_length :
216
+ old_loss_mask = loss_mask .clone ()
210
217
loss_mask = torch .logical_and (
211
218
loss_mask ,
212
219
action_mask [:, - 1 ] == False ,
213
220
)
214
- self .effective_prompt_count += group_reward .size (0 ) * self .dp_size
221
+
222
+ self .overlength_samples = (old_loss_mask & ~ loss_mask ).sum ().item ()
223
+ self .overlength_samples = all_reduce_sum (
224
+ torch .tensor (self .overlength_samples , device = loss_mask .device ), self .plugin
225
+ )
226
+ self .total_overlength_samples + = self .overlength_samples .item ()
227
+
228
+ prompt_level_mask = loss_mask .view (self .minibatch_size , self .num_generations )
229
+
230
+ # [minibatch_size] -> calculate the number of effective prompts
231
+ effective_prompts_mask = prompt_level_mask .any (dim = 1 )
232
+ effective_prompts = all_reduce_sum (torch .sum (effective_prompts_mask ), self .plugin )
233
+ self .effective_prompt_count + = effective_prompts .item ()
234
+ excessive_prompts_idx = None
215
235
216
236
mean_kl , mean_loss = [], []
217
237
@@ -428,9 +448,18 @@ def _criterion(outputs, inputs):
428
448
self .optimizer .step ()
429
449
self .optimizer .zero_grad ()
430
450
self .global_step + = 1
451
+ << << < << HEAD
431
452
sample_utilization = self .effective_sample_count / len (self .raw_train_batch_reward ) / self .num_generations
432
453
self .effective_prompt_count = 0
433
454
self .effective_sample_count = 0
455
+ == == == =
456
+ sample_utilization = self .effective_sample_count / self .total_sample_count
457
+ overlength_samples_percentage = self .total_overlength_samples / self .total_sample_count
458
+ self .effective_prompt_count = 0
459
+ self .effective_sample_count = 0
460
+ self .total_sample_count = 0
461
+ self .total_overlength_samples = 0
462
+ >> >> > >> c8b368c2 (add overlength sample count (#6332))
434
463
loss_scalar = self .accum_loss .item ()
435
464
if not self .plugin .pp_size > 1 or (
436
465
self .plugin .pp_size > 1 and self .booster .plugin .stage_manager .is_last_stage () and self .tp_rank == 0
@@ -458,6 +487,7 @@ def _criterion(outputs, inputs):
458
487
f"Advantages: { self .accum_advantages .item () / self .accum_count :.4f} " ,
459
488
f"Response Length: { raw_batch_response_len_mean :.4f} " ,
460
489
f"Sample_utilization: { sample_utilization :.4f} " ,
490
+ f"Percentage of overlength samples: { overlength_samples_percentage :.4f} " ,
461
491
] + ([f"KL: { self .accum_kl .item () / self .accum_count :.4f} " ] if self .policy_loss_fn .beta > 0 else [])
462
492
print ("\n " .join (to_log_msg ))
463
493
metrics = {
@@ -469,6 +499,7 @@ def _criterion(outputs, inputs):
469
499
"train/advantages" : self .accum_advantages .item () / self .accum_count ,
470
500
"train/learning_rate" : self .lr_scheduler .get_last_lr ()[0 ],
471
501
"train/sample_utilization" : sample_utilization ,
502
+ "train/percentage_overlength_samples" : overlength_samples_percentage ,
472
503
"rollout/temperature" : data ["temperature" ].cpu ().numpy ()[0 ][0 ],
473
504
}
474
505
if self .policy_loss_fn .beta > 0 :
0 commit comments