@@ -101,6 +101,7 @@ def __init__(
            clip_eps_high=grpo_config.get("clip_eps_high", 0.2),
            beta=grpo_config.get("beta", 0.01),
            loss_variation=grpo_config.get("loss_variation", "sample_level"),
+           adv=grpo_config.get("algo"),
        )

        # Reference model is initialized from policy model.
@@ -137,6 +138,8 @@ def __init__(
            eta_min=0.1 * grpo_config.get("lr", 1e-6),
        )

+       self.adv = grpo_config.get("algo")
+
    def setup(self):
        super().setup()
        if (not self.plugin.pp_size > 1 and self.rank == 0) or (
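Both hunks above read a new `algo` key from `grpo_config`: it is forwarded to the policy loss as `adv=...` and cached on the consumer as `self.adv` for the advantage branching further down. A minimal sketch of what such a config could look like, restricted to the keys that appear in this diff (the values are illustrative, not prescribed):

```python
# Hypothetical grpo_config; only keys referenced in this diff are shown.
grpo_config = {
    "lr": 1e-6,                        # also sets eta_min = 0.1 * lr above
    "clip_eps_high": 0.2,
    "beta": 0.01,                      # KL penalty coefficient
    "loss_variation": "sample_level",
    "algo": "RLOO",                    # "GRPO", "DAPO", "REINFORCE_PPB", or "RLOO"
}
```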
@@ -204,9 +207,23 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
            # [minibatch_size x num_generations]
            reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)

-           reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
-           # [minibatch_size x num_generations]
-           advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
+           if self.adv == "GRPO" or self.adv == "DAPO":
+               reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
+               # [minibatch_size x num_generations]
+               advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
+           elif self.adv == "REINFORCE_PPB":
+               # [minibatch_size x num_generations]
+               advantages = (reward - reward_mean).unsqueeze(dim=-1)
+           elif self.adv == "RLOO":
+               advantages = (
+                   reward * self.num_generations / (self.num_generations - 1)
+                   - reward_mean * self.num_generations / (self.num_generations - 1)
+               ).unsqueeze(dim=-1)

            # [minibatch_size x num_of_generation]
            loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
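The three branches above differ only in how group rewards become advantages: GRPO/DAPO normalize by the per-group standard deviation, REINFORCE_PPB only centers the reward (whitening happens later per micro-batch), and RLOO subtracts a leave-one-out baseline. A standalone sketch with toy tensors (the variable names mirror the diff, but nothing here touches the trainer) that also checks the RLOO expression against the explicit leave-one-out mean:

```python
import torch

minibatch_size, num_generations = 2, 4
group_reward = torch.randn(minibatch_size, num_generations)          # [B, G]
reward = group_reward.flatten()                                      # [B * G]
reward_mean = group_reward.mean(dim=1).repeat_interleave(num_generations, dim=0)

# GRPO / DAPO: group-normalized reward.
reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
adv_grpo = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)

# REINFORCE_PPB: mean-centered only; whitening happens later per micro-batch.
adv_reinforce = (reward - reward_mean).unsqueeze(dim=-1)

# RLOO: r_i - mean_{j != i} r_j == G / (G - 1) * (r_i - mean_j r_j).
adv_rloo = (
    reward * num_generations / (num_generations - 1)
    - reward_mean * num_generations / (num_generations - 1)
).unsqueeze(dim=-1)

# Cross-check the leave-one-out identity directly.
loo_mean = (group_reward.sum(dim=1, keepdim=True) - group_reward) / (num_generations - 1)
assert torch.allclose(adv_rloo.squeeze(-1), (group_reward - loo_mean).flatten(), atol=1e-6)
```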
@@ -358,10 +375,34 @@ def _criterion(outputs, inputs):
                        per_token_kl = 0.0
                        kl.append(torch.tensor(0.0))

+                   inputs["advantages"] = inputs["advantages"].repeat_interleave(
+                       action_log_probs.size(-1), dim=-1
+                   )
+
+                   if self.adv == "REINFORCE_PPB":
+                       # Fold the KL penalty into the advantage, then whiten over the
+                       # valid tokens of this forward micro-batch.
+                       inputs["advantages"] = inputs["advantages"] - self.policy_loss_fn.beta * per_token_kl
+                       advantages_forward_micro_batch_mean = torch.sum(
+                           inputs["advantages"] * inputs["action_mask"]
+                       ) / (torch.sum(inputs["action_mask"]) + 1e-4)
+                       advantages_forward_micro_batch_std = torch.sqrt(
+                           torch.sum(
+                               (inputs["advantages"] - advantages_forward_micro_batch_mean) ** 2
+                               * inputs["action_mask"]
+                           )
+                           / (torch.sum(inputs["action_mask"]) + 1e-4)
+                           + 1e-8
+                       )
+                       inputs["advantages"] = (
+                           (inputs["advantages"] - advantages_forward_micro_batch_mean)
+                           * inputs["action_mask"]
+                           / advantages_forward_micro_batch_std
+                       )
+
+                       # The KL term now lives inside the advantage, so skip it in the loss.
+                       per_token_kl = 0.0
+
                    loss, _ = self.policy_loss_fn(
                        action_log_probs,
                        inputs["old_action_log_probs"],
-                       inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
+                       inputs["advantages"],
                        per_token_kl,
                        inputs["action_mask"],
                        loss_mask=inputs["loss_mask"],
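For `REINFORCE_PPB`, the hunk above first expands the advantages to per-token shape, folds the KL penalty into them, whitens them over the valid tokens of the forward micro-batch, and finally zeroes `per_token_kl` so the KL term is not applied a second time inside the policy loss. A self-contained sketch of that masked whitening, with made-up shapes and a hypothetical `masked_whiten` helper that is not part of the codebase:

```python
import torch

def masked_whiten(advantages: torch.Tensor, action_mask: torch.Tensor,
                  eps_count: float = 1e-4, eps_var: float = 1e-8) -> torch.Tensor:
    """Zero-mean / unit-std normalization over the tokens where action_mask == 1."""
    mean = torch.sum(advantages * action_mask) / (torch.sum(action_mask) + eps_count)
    var = torch.sum((advantages - mean) ** 2 * action_mask) / (torch.sum(action_mask) + eps_count)
    std = torch.sqrt(var + eps_var)
    return (advantages - mean) * action_mask / std

# Toy micro-batch: 2 sequences, 5 action tokens each, trailing tokens padded out.
advantages = torch.randn(2, 5)              # already expanded to per-token shape
per_token_kl = 0.1 * torch.rand(2, 5)       # stand-in for the KL estimate
action_mask = torch.tensor([[1, 1, 1, 1, 0],
                            [1, 1, 0, 0, 0]], dtype=torch.float)

beta = 0.01                                  # plays the role of self.policy_loss_fn.beta
advantages = masked_whiten(advantages - beta * per_token_kl, action_mask)
per_token_kl = 0.0                           # KL is already inside the advantage
```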
@@ -420,10 +461,39 @@ def _criterion(outputs, inputs):
                    per_token_kl = 0.0
                    kl = None

+               advantages_forward_micro_batch = advantages_forward_micro_batch.repeat_interleave(
+                   action_log_probs.size(-1), dim=-1
+               )
+
+               if self.adv == "REINFORCE_PPB":
+                   advantages_forward_micro_batch = (
+                       advantages_forward_micro_batch - self.policy_loss_fn.beta * per_token_kl
+                   )
+                   advantages_forward_micro_batch_mean = torch.sum(
+                       advantages_forward_micro_batch * action_mask_forward_micro_batch
+                   ) / (torch.sum(action_mask_forward_micro_batch) + 1e-4)
+                   advantages_forward_micro_batch_std = torch.sqrt(
+                       torch.sum(
+                           (advantages_forward_micro_batch - advantages_forward_micro_batch_mean) ** 2
+                           * action_mask_forward_micro_batch
+                       )
+                       / (torch.sum(action_mask_forward_micro_batch) + 1e-4)
+                       + 1e-8
+                   )
+                   advantages_forward_micro_batch = (
+                       (advantages_forward_micro_batch - advantages_forward_micro_batch_mean)
+                       * action_mask_forward_micro_batch
+                       / advantages_forward_micro_batch_std
+                   )
+
+                   # KL is already folded into the advantage above.
+                   per_token_kl = 0.0
+
                loss, _ = self.policy_loss_fn(
                    action_log_probs,
                    old_action_log_probs_micro_batch,
-                   advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
+                   advantages_forward_micro_batch,
                    per_token_kl,
                    action_mask_forward_micro_batch,
                    loss_mask=loss_mask_forward_micro_batch,