@@ -91,9 +91,6 @@ def __init__(
         self.project_name = project_name
         self.effective_sample_count = 0
         self.effective_prompt_count = 0
-        self.total_sample_count = 0
-        self.overlength_samples = 0
-        self.total_overlength_samples = 0
         self.project_name = project_name
         self.run_name = run_name
         self.wandb_group_name = wandb_group_name
@@ -207,7 +204,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
 
         # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
-            old_loss_mask = loss_mask.clone()
             loss_mask = torch.logical_and(
                 loss_mask,
                 action_mask[:, -1] == False,
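For reference, a minimal standalone sketch of the truncation filter above (the tensor shapes and values are illustrative assumptions, not taken from the PR): a response whose `action_mask` is still active at the final position was cut off at `max_length`, so its `loss_mask` entry is zeroed.

import torch

# Hypothetical mini-batch: 3 responses, generation length capped at 4.
# action_mask[i, -1] == True means response i was still generating at the
# length limit, i.e. it was truncated rather than finishing naturally.
action_mask = torch.tensor(
    [
        [1, 1, 1, 1],  # truncated: runs all the way to max_length
        [1, 1, 0, 0],  # finished early
        [1, 1, 1, 0],  # finished early
    ],
    dtype=torch.bool,
)
loss_mask = torch.ones(3, dtype=torch.bool)

# Same logic as the diff: drop truncated responses from the loss.
loss_mask = torch.logical_and(loss_mask, action_mask[:, -1] == False)
print(loss_mask)  # tensor([False,  True,  True])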
@@ -225,15 +221,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                     group_ans_acc_mean < self.filter_range[1],
                 ),
             )
-            self.total_overlength_samples += self.overlength_samples.item()
-
-            prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
-
-            # [minibatch_size] -> calculate the number of effective prompts
-            effective_prompts_mask = prompt_level_mask.any(dim=1)
-            effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
-            self.effective_prompt_count += effective_prompts.item()
-            excessive_prompts_idx = None
+            self.effective_prompt_count += group_reward.size(0) * self.dp_size
 
             mean_kl, mean_loss = [], []
 
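A small sketch of the new prompt accounting (the shape of `group_reward` and the value of `dp_size` below are illustrative assumptions): each data-parallel rank counts its own `group_reward.size(0)` prompt groups and scales by `dp_size`, rather than reducing a per-prompt mask across ranks.

import torch

# Hypothetical values for illustration only.
dp_size = 4                       # number of data-parallel ranks
group_reward = torch.randn(8, 5)  # 8 prompt groups x 5 generations on this rank

# New accounting from the diff: assuming every rank holds the same number of
# prompt groups, the global count is local_count * dp_size, no all-reduce needed.
effective_prompt_count = 0
effective_prompt_count += group_reward.size(0) * dp_size
print(effective_prompt_count)  # 32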
@@ -250,8 +238,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
             pbar.set_postfix(
                 {
                     "Global Step": self.global_step,
-                    "Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
-                    "Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
+                    "Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples",
                 }
             )
 
@@ -477,12 +464,10 @@ def _criterion(outputs, inputs):
             self.optimizer.step()
             self.optimizer.zero_grad()
             self.global_step += 1
-            sample_utilization = self.effective_sample_count / self.total_sample_count
-            overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
+            # no need to run all-reduce, as raw_train_batch_* are not split across dp ranks
+            sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
-            self.total_sample_count = 0
-            self.total_overlength_samples = 0
             loss_scalar = self.accum_loss.item()
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
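A hedged sketch of the new utilization metric (the structure of `raw_train_batch_reward` is an assumption here: one entry per prompt in the raw, un-split training batch): utilization is the fraction of generated samples that survive filtering, with the denominator derived locally instead of from an all-reduced total sample count.

# Illustrative numbers only; in the trainer these come from the rollout buffer.
num_generations = 8
raw_train_batch_reward = [None] * 16   # assumed: one entry per prompt, not split across dp ranks
effective_sample_count = 96            # samples that survived truncation/reward filtering

# New computation from the diff: total samples = prompts * generations per prompt.
sample_utilization = effective_sample_count / len(raw_train_batch_reward) / num_generations
print(sample_utilization)  # 0.75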
@@ -545,4 +530,4 @@ def state_dict(self):
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
         state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
-        return state_dict
+        return state_dict