@@ -72,12 +72,12 @@ def __init__(
         self.policy_model.gradient_checkpointing_enable()
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
         self.accum_loss = torch.zeros(1, device=self.device)
-        self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
-        self.accum_format_acc = torch.zeros(1, device=self.device)
-        self.accum_ans_acc = torch.zeros(1, device=self.device)
         self.accum_advantages = torch.zeros(1, device=self.device)
-        self.accum_response_length = torch.zeros(1, device=self.device)
+        self.raw_train_batch_reward = []
+        self.raw_train_batch_format_acc = []
+        self.raw_train_batch_ans_acc = []
+        self.raw_train_batch_response_len = []
         self.accum_count = 0
         self.generate_config = generate_config
         self.grpo_config = grpo_config
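
The hunk above swaps four scalar GPU accumulators for plain Python lists that collect raw per-sample statistics. One visible consequence is statistical: the old scheme divided a running sum of per-step values by accum_count (a mean of means), while the new lists are averaged over every raw sample, which gives a different number whenever the per-step value is itself a mini-batch mean and the mini-batches contribute unequal sample counts. Below is a minimal standalone sketch of that arithmetic difference, using made-up reward values that are not taken from this patch:

import torch

# Two mini-batches of rewards with different sizes (hypothetical values).
mini_batch_rewards = [torch.tensor([1.0, 0.0, 0.0, 0.0]), torch.tensor([1.0, 1.0])]

# Old-style accumulation: running sum of per-step means, divided by the step count.
accum_reward = torch.zeros(1)
accum_count = 0
for rewards in mini_batch_rewards:
    accum_reward.add_(rewards.mean())
    accum_count += 1
mean_of_means = (accum_reward / accum_count).item()  # (0.25 + 1.0) / 2 = 0.625

# New-style accumulation: keep the raw values and average once at logging time.
raw_train_batch_reward = []
for rewards in mini_batch_rewards:
    raw_train_batch_reward.extend(rewards.tolist())
raw_batch_reward_mean = sum(raw_train_batch_reward) / len(raw_train_batch_reward)  # 3.0 / 6 = 0.5

print(mean_of_means, raw_batch_reward_mean)
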
@@ -186,7 +186,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
             [minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
         """
         # Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
-        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
+        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items() if "raw_train_mini_batch_" not in k}
+        self.raw_train_batch_reward.extend(kwargs["raw_train_mini_batch_reward"])
+        self.raw_train_batch_format_acc.extend(kwargs["raw_train_mini_batch_format_acc"])
+        self.raw_train_batch_ans_acc.extend(kwargs["raw_train_mini_batch_ans_acc"])
+        self.raw_train_batch_response_len.extend(kwargs["raw_train_mini_batch_response_len"])
         action_mask = data["action_mask"]
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
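
The new raw_train_mini_batch_* entries in kwargs are per-sample metric lists rather than [minibatch_size, num_of_generation, seq_len] tensors, so the comprehension filters them out before calling .view() and they are appended to the running buffers instead. A small self-contained sketch of that split, with hypothetical shapes and only the reward buffer shown:

import torch

def split_step_kwargs(raw_train_batch_reward: list, **kwargs) -> dict:
    """Reshape tensor inputs and divert raw per-sample metrics into a running list."""
    # Tensors shaped [minibatch_size, num_of_generation, seq_len] are flattened over the
    # first two dimensions; the raw metric lists are excluded from the reshape entirely.
    data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items() if "raw_train_mini_batch_" not in k}
    raw_train_batch_reward.extend(kwargs["raw_train_mini_batch_reward"])
    return data

# Hypothetical inputs: 2 prompts x 4 generations x sequence length 8.
raw_rewards = []
data = split_step_kwargs(
    raw_rewards,
    input_ids=torch.zeros(2, 4, 8, dtype=torch.long),
    action_mask=torch.ones(2, 4, 8, dtype=torch.bool),
    raw_train_mini_batch_reward=[1.0, 0.0, 1.0, 0.5, 0.0, 0.0, 1.0, 1.0],
)
print(data["input_ids"].shape)  # torch.Size([8, 8])
print(len(raw_rewards))         # 8
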
@@ -430,11 +434,7 @@ def _criterion(outputs, inputs):
         self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
         if self.policy_loss_fn.beta > 0:
             self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
-        self.accum_reward.add_(reward.data)
-        self.accum_format_acc.add_(format_acc.data)
-        self.accum_ans_acc.add_(ans_acc.data)
         self.accum_advantages.add_(advantages.data)
-        self.accum_response_length.add_(response_length.data)
         self.accum_count += 1
         if need_update:
             self.optimizer.step()
@@ -452,21 +452,33 @@ def _criterion(outputs, inputs):
             if (not self.plugin.pp_size > 1 and self.rank == 0) or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
+                raw_batch_reward_mean = sum(self.raw_train_batch_reward) / len(self.raw_train_batch_reward)
+                raw_batch_format_acc_mean = sum(self.raw_train_batch_format_acc) / len(
+                    self.raw_train_batch_format_acc
+                )
+                raw_batch_ans_acc_mean = sum(self.raw_train_batch_ans_acc) / len(self.raw_train_batch_ans_acc)
+                raw_batch_response_len_mean = sum(self.raw_train_batch_response_len) / len(
+                    self.raw_train_batch_response_len
+                )
+                self.raw_train_batch_reward = []
+                self.raw_train_batch_format_acc = []
+                self.raw_train_batch_ans_acc = []
+                self.raw_train_batch_response_len = []
                 to_log_msg = [
                     f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
-                    f"Reward: {self.accum_reward.item() / self.accum_count:.4f}",
-                    f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
-                    f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
+                    f"Reward: {raw_batch_reward_mean:.4f}",
+                    f"format Reward: {raw_batch_format_acc_mean:.4f}",
+                    f"Acc Reward: {raw_batch_ans_acc_mean:.4f}",
                     f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
-                    f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
+                    f"Response Length: {raw_batch_response_len_mean:.4f}",
                     f"Sample_utilization: {sample_utilization:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
-                    "metrics/reward": self.accum_reward.item() / self.accum_count,
-                    "metrics/format_acc": self.accum_format_acc.item() / self.accum_count,
-                    "metrics/ans_acc": self.accum_ans_acc.item() / self.accum_count,
-                    "metrics/response_length": self.accum_response_length.item() / self.accum_count,
+                    "metrics/reward": raw_batch_reward_mean,
+                    "metrics/format_acc": raw_batch_format_acc_mean,
+                    "metrics/ans_acc": raw_batch_ans_acc_mean,
+                    "metrics/response_length": raw_batch_response_len_mean,
                     "train/loss": self.accum_loss.item() / self.accum_count,
                     "train/advantages": self.accum_advantages.item() / self.accum_count,
                     "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
@@ -478,12 +490,8 @@ def _criterion(outputs, inputs):
                 if self.wandb_run is not None:
                     self.wandb_run.log(metrics)
                 self.accum_loss.zero_()
-                self.accum_reward.zero_()
-                self.accum_ans_acc.zero_()
-                self.accum_format_acc.zero_()
                 self.accum_kl.zero_()
                 self.accum_advantages.zero_()
-                self.accum_response_length.zero_()
                 self.accum_count = 0
                 return loss_scalar
             else: