@@ -211,6 +211,17 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
             loss_mask,
             action_mask[:, -1] == False,
         )
+        if self.filter_range is not None and self.grpo_config.get("dynamic_batching", False) == False:
+            # filter out samples with reward outside the range
+            # if dynamic batching is enabled, we filter out out-of-range groups before training
+            group_ans_acc_mean = ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1)
+            loss_mask = torch.logical_and(
+                loss_mask,
+                torch.logical_and(
+                    group_ans_acc_mean > self.filter_range[0],
+                    group_ans_acc_mean < self.filter_range[1],
+                ),
+            )
         self.effective_prompt_count += group_reward.size(0) * self.dp_size

         mean_kl, mean_loss = [], []
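
For reference, here is a minimal standalone sketch of what the added filter computes: each prompt's `num_generations` completions share one mean answer accuracy, and whole groups whose mean falls outside `filter_range` are masked out of the loss. The tensor values and module-level names below are invented for illustration.

```python
import torch

num_generations = 4            # completions sampled per prompt
filter_range = (0.1, 0.9)      # keep only groups whose mean accuracy lies strictly inside this range

# per-sample answer accuracy for 3 prompts x 4 generations = 12 samples (made-up values)
ans_acc = torch.tensor(
    [1, 1, 1, 1,   # group 0: mean 1.0 -> dropped (too easy, no learning signal)
     0, 0, 0, 0,   # group 1: mean 0.0 -> dropped (too hard)
     1, 0, 1, 0],  # group 2: mean 0.5 -> kept
    dtype=torch.float,
)
loss_mask = torch.ones_like(ans_acc, dtype=torch.bool)

# broadcast each group's mean accuracy back onto its individual samples
group_ans_acc_mean = ans_acc.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=-1)
loss_mask = torch.logical_and(
    loss_mask,
    torch.logical_and(group_ans_acc_mean > filter_range[0], group_ans_acc_mean < filter_range[1]),
)
print(loss_mask)  # only the last four samples (group 2) remain True
```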
@@ -229,8 +240,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         pbar.set_postfix(
             {
                 "Global Step": self.global_step,
-                "Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
-                "Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
+                "Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples",
             }
         )

@@ -428,6 +438,7 @@ def _criterion(outputs, inputs):
             self.optimizer.step()
             self.optimizer.zero_grad()
             self.global_step += 1
+            # no need to run all-reduce, as raw_train_batch_* are not split across DP ranks
             sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
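
A small sanity check of the comment above, with invented numbers: when every data-parallel rank keeps the complete `raw_train_batch_*` buffers rather than a shard of them, the locally computed mean already equals the global one, so no `all_reduce` is required.

```python
import torch

full_buffer = torch.tensor([0.2, 0.8, 0.5, 0.9])   # complete batch, present on every rank
local_mean_on_each_rank = full_buffer.mean()        # identical on all ranks -> nothing to reduce

# By contrast, if the buffer were split across two equally sized shards, the
# per-rank means would have to be combined (e.g. via torch.distributed.all_reduce)
# to recover the same global statistic:
shard_rank0, shard_rank1 = full_buffer[:2], full_buffer[2:]
global_mean_from_shards = (shard_rank0.mean() + shard_rank1.mean()) / 2

assert torch.isclose(local_mean_on_each_rank, global_mean_from_shards)
```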
@@ -438,14 +449,12 @@ def _criterion(outputs, inputs):
             if (not self.plugin.pp_size > 1 and self.rank == 0) or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
-                raw_batch_reward_mean = sum(self.raw_train_batch_reward) / len(self.raw_train_batch_reward)
-                raw_batch_format_acc_mean = sum(self.raw_train_batch_format_acc) / len(
-                    self.raw_train_batch_format_acc
-                )
-                raw_batch_ans_acc_mean = sum(self.raw_train_batch_ans_acc) / len(self.raw_train_batch_ans_acc)
-                raw_batch_response_len_mean = sum(self.raw_train_batch_response_len) / len(
-                    self.raw_train_batch_response_len
-                )
+                raw_batch_reward_mean = torch.cat(self.raw_train_batch_reward, dim=0).mean().cpu().item()
+                raw_batch_format_acc_mean = torch.cat(self.raw_train_batch_format_acc, dim=0).mean().cpu().item()
+                raw_batch_ans_acc_mean = torch.cat(self.raw_train_batch_ans_acc, dim=0).mean().cpu().item()
+                raw_batch_response_len = torch.cat(self.raw_train_batch_response_len, dim=0)
+                raw_batch_response_len_mean = raw_batch_response_len.mean().cpu().item()
+                overlength_samples_ratio = (raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item()  # not an exact figure, but a close estimate
                 self.raw_train_batch_reward = []
                 self.raw_train_batch_format_acc = []
                 self.raw_train_batch_ans_acc = []
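
As a rough illustration of the overlength estimate above: a response that used the entire generation budget was almost certainly truncated, so counting lengths at or above the cap approximates the truncation rate. The lengths below are made up, and `max_response_len` stands in for `action_mask.size(-1)`.

```python
import torch

max_response_len = 512  # stands in for action_mask.size(-1), i.e. the generation length cap
raw_batch_response_len = torch.tensor([120.0, 512.0, 512.0, 300.0, 480.0])  # invented lengths

# responses that hit the cap are counted as (probably) truncated
overlength_samples_ratio = (raw_batch_response_len >= max_response_len).to(float).mean().cpu().item()
print(overlength_samples_ratio)  # 0.4 -> 2 of 5 responses used the full length budget
```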
@@ -458,6 +467,7 @@ def _criterion(outputs, inputs):
                     f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                     f"Response Length: {raw_batch_response_len_mean:.4f}",
                     f"Sample_utilization: {sample_utilization:.4f}",
+                    f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
@@ -469,6 +479,7 @@ def _criterion(outputs, inputs):
                     "train/advantages": self.accum_advantages.item() / self.accum_count,
                     "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                     "train/sample_utilization": sample_utilization,
+                    "train/overlength_samples_ratio": overlength_samples_ratio,
                     "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                 }
                 if self.policy_loss_fn.beta > 0: