@@ -91,9 +91,6 @@ def __init__(
         self.project_name = project_name
         self.effective_sample_count = 0
         self.effective_prompt_count = 0
-        self.total_sample_count = 0
-        self.overlength_samples = 0
-        self.total_overlength_samples = 0
         self.project_name = project_name
         self.run_name = run_name
         self.wandb_group_name = wandb_group_name
@@ -207,7 +204,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:

         # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
-            old_loss_mask = loss_mask.clone()
             loss_mask = torch.logical_and(
                 loss_mask,
                 action_mask[:, -1] == False,
@@ -225,15 +221,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                     group_ans_acc_mean < self.filter_range[1],
                 ),
             )
-        self.total_overlength_samples += self.overlength_samples.item()
-
-        prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
-
-        # [minibatch_size] -> calculate the number of effective prompts
-        effective_prompts_mask = prompt_level_mask.any(dim=1)
-        effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
-        self.effective_prompt_count += effective_prompts.item()
-        excessive_prompts_idx = None
+        self.effective_prompt_count += group_reward.size(0) * self.dp_size

         mean_kl, mean_loss = [], []

@@ -250,8 +238,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         pbar.set_postfix(
             {
                 "Global Step": self.global_step,
-                "Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
-                "Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
+                "Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples",
             }
         )

@@ -477,12 +464,10 @@ def _criterion(outputs, inputs):
             self.optimizer.step()
             self.optimizer.zero_grad()
             self.global_step += 1
-            sample_utilization = self.effective_sample_count / self.total_sample_count
-            overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
+            # no need to run all-reduce as raw_train_batch_* are not split across dp ranks
+            sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
-            self.total_sample_count = 0
-            self.total_overlength_samples = 0
             loss_scalar = self.accum_loss.item()
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
@@ -545,4 +530,4 @@ def state_dict(self):
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
         state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
-        return state_dict
+        return state_dict
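
For context, a minimal sketch of the bookkeeping this diff moves to: effective counts are accumulated from the filtered `loss_mask`, and sample utilization is now derived from the locally held `raw_train_batch_reward` (assumed to hold one entry per prompt and to not be split across data-parallel ranks) instead of an all-reduced `total_sample_count`. The function name, tensor shapes, and the one-entry-per-prompt assumption below are illustrative only, not part of the patch.

```python
import torch

def compute_sample_utilization(
    loss_mask: torch.Tensor,       # [num_prompts * num_generations], bool; True = sample kept after filtering
    raw_train_batch_reward: list,  # assumption: one entry per prompt, held locally (not dp-split)
    num_generations: int,
) -> float:
    """Illustrative re-derivation of the patched utilization formula.

    effective_sample_count counts samples surviving the overlength and
    reward-range filters; the denominator is prompts * num_generations,
    so no all-reduce over data-parallel ranks is needed.
    """
    effective_sample_count = int(loss_mask.sum().item())
    total_sample_count = len(raw_train_batch_reward) * num_generations
    return effective_sample_count / total_sample_count

# Example: 16 prompts x 8 generations, 96 samples kept -> utilization 0.75
mask = torch.zeros(16 * 8, dtype=torch.bool)
mask[:96] = True
print(compute_sample_utilization(mask, raw_train_batch_reward=[None] * 16, num_generations=8))
```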