@@ -211,6 +211,17 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                     loss_mask,
                     action_mask[:, -1] == False,
                 )
+            if self.filter_range is not None and self.grpo_config.get("dynamic_batching", False) == False:
+                # filter out samples whose group mean accuracy falls outside the range;
+                # if dynamic batching is enabled, out-of-range groups are already filtered out before training
+                group_ans_acc_mean = ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1)
+                loss_mask = torch.logical_and(
+                    loss_mask,
+                    torch.logical_and(
+                        group_ans_acc_mean > self.filter_range[0],
+                        group_ans_acc_mean < self.filter_range[1],
+                    ),
+                )
             self.effective_prompt_count += group_reward.size(0) * self.dp_size

             mean_kl, mean_loss = [], []
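
For readers unfamiliar with the reshaping trick above, here is a minimal standalone sketch (toy values, not from the trainer) of how the per-group mean accuracy is broadcast back to every rollout and compared against `filter_range` to mask entire groups:

```python
import torch

# Toy values: 3 prompts x 4 rollouts each (hypothetical, not from a real run).
num_generations = 4
ans_acc = torch.tensor([1.0, 1.0, 1.0, 1.0,   # group 0: mean 1.0 -> masked (all correct)
                        0.0, 0.0, 0.0, 0.0,   # group 1: mean 0.0 -> masked (all wrong)
                        1.0, 0.0, 1.0, 0.0])  # group 2: mean 0.5 -> kept
filter_range = (0.01, 0.99)

# Per-group mean, repeated so each rollout carries its group's mean.
group_ans_acc_mean = ans_acc.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=-1)
keep = torch.logical_and(group_ans_acc_mean > filter_range[0], group_ans_acc_mean < filter_range[1])
print(keep.view(-1, num_generations))
# tensor([[False, False, False, False],
#         [False, False, False, False],
#         [ True,  True,  True,  True]])
```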
@@ -229,8 +240,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
             pbar.set_postfix(
                 {
                     "Global Step": self.global_step,
-                    "Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
-                    "Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
+                    "Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples",
                 }
             )

@@ -428,6 +438,7 @@ def _criterion(outputs, inputs):
             self.optimizer.step()
             self.optimizer.zero_grad()
             self.global_step += 1
+            # no need to run all-reduce, as raw_train_batch_* are not split across dp ranks
             sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
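
As a rough illustration of the `sample_utilization` arithmetic (all numbers below are made up, and it assumes the denominator works out to prompts × generations in the accumulated raw batch):

```python
# Hypothetical numbers for one optimizer step on a single dp rank.
num_prompts = 8              # prompts accumulated in raw_train_batch_*
num_generations = 4          # rollouts sampled per prompt
effective_sample_count = 20  # rollouts that survived the loss-mask filtering

sample_utilization = effective_sample_count / num_prompts / num_generations
print(f"{sample_utilization:.4f}")  # 0.6250 -> 62.5% of rollouts contribute to the loss
```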
@@ -438,14 +449,12 @@ def _criterion(outputs, inputs):
             if (not self.plugin.pp_size > 1 and self.rank == 0) or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
-                raw_batch_reward_mean = sum(self.raw_train_batch_reward) / len(self.raw_train_batch_reward)
-                raw_batch_format_acc_mean = sum(self.raw_train_batch_format_acc) / len(
-                    self.raw_train_batch_format_acc
-                )
-                raw_batch_ans_acc_mean = sum(self.raw_train_batch_ans_acc) / len(self.raw_train_batch_ans_acc)
-                raw_batch_response_len_mean = sum(self.raw_train_batch_response_len) / len(
-                    self.raw_train_batch_response_len
-                )
+                raw_batch_reward_mean = torch.cat(self.raw_train_batch_reward, dim=0).mean().cpu().item()
+                raw_batch_format_acc_mean = torch.cat(self.raw_train_batch_format_acc, dim=0).mean().cpu().item()
+                raw_batch_ans_acc_mean = torch.cat(self.raw_train_batch_ans_acc, dim=0).mean().cpu().item()
+                raw_batch_response_len = torch.cat(self.raw_train_batch_response_len, dim=0)
+                raw_batch_response_len_mean = raw_batch_response_len.mean().cpu().item()
+                overlength_samples_ratio = (raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item()  # not an exact figure, but a close estimate
                 self.raw_train_batch_reward = []
                 self.raw_train_batch_format_acc = []
                 self.raw_train_batch_ans_acc = []
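
The new `overlength_samples_ratio` counts responses whose length reaches the width of `action_mask`, i.e. ones that were most likely truncated at the generation limit; since the mask width is only a proxy for the configured max length, the value is an estimate, as the inline comment notes. A small self-contained sketch with hypothetical lengths:

```python
import torch

# Hypothetical response lengths and generation cap (stand-ins for
# raw_batch_response_len and action_mask.size(-1) in the trainer).
max_response_len = 512
raw_batch_response_len = torch.tensor([37.0, 512.0, 480.0, 512.0, 129.0])

# Responses that hit the cap were almost certainly cut off mid-generation.
overlength_samples_ratio = (raw_batch_response_len >= max_response_len).to(float).mean().item()
print(f"{overlength_samples_ratio:.4f}")  # 0.4000
```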
@@ -458,6 +467,7 @@ def _criterion(outputs, inputs):
                     f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                     f"Response Length: {raw_batch_response_len_mean:.4f}",
                     f"Sample_utilization: {sample_utilization:.4f}",
+                    f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
@@ -469,6 +479,7 @@ def _criterion(outputs, inputs):
                     "train/advantages": self.accum_advantages.item() / self.accum_count,
                     "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                     "train/sample_utilization": sample_utilization,
+                    "train/overlength_samples_ratio": overlength_samples_ratio,
                     "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                 }
                 if self.policy_loss_fn.beta > 0: