@@ -250,6 +250,9 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                 input_ids_forward_micro_batch = data["input_ids"][
                     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
+                old_action_log_probs_micro_batch = old_action_log_probs[
+                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
+                ]
                 attention_mask_forward_micro_batch = data["attention_mask"][
                     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
@@ -306,17 +309,22 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                         "action_mask": action_mask_forward_micro_batch,
                         "advantages": advantages_forward_micro_batch,
                         "loss_mask": loss_mask_forward_micro_batch,
+                        "old_action_log_probs": old_action_log_probs_micro_batch,
                         "source": self.rank,
                     }
                     if reference_action_log_probs is not None:
                         data_policy_forward["reference_action_log_probs"] = reference_action_log_probs
 
                     kl = []
-                    policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device)
 
                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
-                        policy_model_logits.copy_(action_logits)
+                        mini_batch_entropies.append(
+                            (
+                                ((entropy_from_logits(action_logits[:, -num_action:]) * inputs["action_mask"]).sum(-1))
+                                / inputs["action_mask"].sum(-1)
+                            ).detach()
+                        )
                         action_log_probs = memory_efficient_logprob(
                             action_logits / self.generate_config["temperature"],
                             inputs["input_ids"],
@@ -339,7 +347,7 @@ def _criterion(outputs, inputs):
 
                         loss, _ = self.policy_loss_fn(
                             action_log_probs,
-                            action_log_probs,
+                            inputs["old_action_log_probs"],
                             inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
                             per_token_kl,
                             inputs["action_mask"],
@@ -363,20 +371,6 @@ def _criterion(outputs, inputs):
                             kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
                             mean_kl.append(kl)
                         mean_loss.append(all_reduce_mean(loss, self.plugin).data)
-                        mini_batch_entropies.append(
-                            all_reduce_mean(
-                                (
-                                    (
-                                        (
-                                            entropy_from_logits(policy_model_logits[:, -num_action:])
-                                            * action_mask_forward_micro_batch
-                                        ).sum(-1)
-                                    )
-                                    / action_mask_forward_micro_batch.sum(-1)
-                                ).detach(),
-                                self.plugin,
-                            )
-                        )
                 else:
                     policy_model_logits = self.policy_model(
                         input_ids=input_ids_forward_micro_batch,
@@ -415,7 +409,7 @@ def _criterion(outputs, inputs):
 
                     loss, _ = self.policy_loss_fn(
                         action_log_probs,
-                        old_action_log_probs,
+                        old_action_log_probs_micro_batch,
                         advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
                         per_token_kl,
                         action_mask_forward_micro_batch,
@@ -455,7 +449,7 @@ def _criterion(outputs, inputs):
                 ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
                 advantages = all_reduce_mean(advantages.mean(), self.plugin)
                 response_length = all_reduce_mean(response_length.mean(), self.plugin)
-                entropy = torch.cat(mini_batch_entropies, dim=0).mean()
+                entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin)
                 self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
                 self.accum_entropy.add_(entropy.data)
                 if self.policy_loss_fn.beta > 0:
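The final hunk wraps the aggregated entropy in `all_reduce_mean`, replacing the per-micro-batch reductions removed above: each rank first concatenates and averages its local per-sequence entropies, then synchronizes once across ranks. A toy sketch of that pattern, using a hypothetical standalone stand-in for the repository's `all_reduce_mean` helper:

```python
import torch
import torch.distributed as dist


def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in: average a tensor across all ranks (no-op without distributed init)."""
    if dist.is_initialized():
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        tensor /= dist.get_world_size()
    return tensor


# Per rank: concatenate the (B,) entropy vectors from every micro-batch,
# take the local mean, then perform a single cross-rank reduction.
mini_batch_entropies = [torch.rand(4), torch.rand(4)]  # placeholder local values
entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean())
```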