@@ -263,6 +263,9 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                 input_ids_forward_micro_batch = data["input_ids"][
                     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
+                old_action_log_probs_micro_batch = old_action_log_probs[
+                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
+                ]
                 attention_mask_forward_micro_batch = data["attention_mask"][
                     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
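
For orientation, here is a toy sketch of the micro-batch slicing pattern this hunk extends to the rollout log-probabilities, so that `old_action_log_probs` stays row-aligned with `input_ids` and the other per-sequence tensors. All names, shapes, and values below are invented for illustration; only the slicing pattern mirrors the diff.

```python
import torch

# Pretend rollout buffers for one mini batch: 8 sequences, 16 action tokens, 32 total tokens.
old_action_log_probs = torch.randn(8, 16)
input_ids = torch.randint(0, 100, (8, 32))
train_microbatch_size = 4

for forward_micro_batch_start in range(0, input_ids.size(0), train_microbatch_size):
    # Slice every per-sequence tensor with the same window so the rollout
    # log-probs stay aligned with the sequences they were sampled for.
    window = slice(forward_micro_batch_start, forward_micro_batch_start + train_microbatch_size)
    input_ids_micro_batch = input_ids[window]
    old_action_log_probs_micro_batch = old_action_log_probs[window]
    assert old_action_log_probs_micro_batch.size(0) == input_ids_micro_batch.size(0)
```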
@@ -319,17 +322,22 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                         "action_mask": action_mask_forward_micro_batch,
                         "advantages": advantages_forward_micro_batch,
                         "loss_mask": loss_mask_forward_micro_batch,
+                        "old_action_log_probs": old_action_log_probs_micro_batch,
                         "source": self.rank,
                     }
                     if reference_action_log_probs is not None:
                         data_policy_forward["reference_action_log_probs"] = reference_action_log_probs

                     kl = []
-                    policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device)

                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
-                        policy_model_logits.copy_(action_logits)
+                        mini_batch_entropies.append(
+                            (
+                                ((entropy_from_logits(action_logits[:, -num_action:]) * inputs["action_mask"]).sum(-1))
+                                / inputs["action_mask"].sum(-1)
+                            ).detach()
+                        )
                         action_log_probs = memory_efficient_logprob(
                             action_logits / self.generate_config["temperature"],
                             inputs["input_ids"],
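
The entropy bookkeeping now happens inside `_criterion`, directly on the last `num_action` positions of the logits from the current forward pass, rather than being copied into a preallocated `policy_model_logits` buffer first. Below is a minimal standalone sketch of that masked per-sequence entropy; `entropy_from_logits_sketch` is a hypothetical stand-in derived from the softmax definition, not necessarily the repository's `entropy_from_logits` implementation.

```python
import torch

def entropy_from_logits_sketch(logits: torch.Tensor) -> torch.Tensor:
    # Per-token entropy of the categorical distribution given by `logits`:
    # H = logsumexp(logits) - sum(softmax(logits) * logits), reduced over the vocab dim.
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return torch.logsumexp(logits, dim=-1) - torch.sum(probs * logits, dim=-1)

# Toy shapes: 2 sequences, 4 action positions, vocab size 8 (all positions treated as actions).
action_logits = torch.randn(2, 4, 8)
action_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.float)

# Same reduction as in the diff: sum token entropies over valid action positions,
# then divide by the number of valid positions per sequence.
per_token_entropy = entropy_from_logits_sketch(action_logits)  # (2, 4)
per_sequence_entropy = (per_token_entropy * action_mask).sum(-1) / action_mask.sum(-1)  # (2,)
print(per_sequence_entropy)
```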
@@ -352,7 +360,7 @@ def _criterion(outputs, inputs):

                         loss, _ = self.policy_loss_fn(
                             action_log_probs,
-                            action_log_probs,
+                            inputs["old_action_log_probs"],
                             inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
                             per_token_kl,
                             inputs["action_mask"],
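
Passing `inputs["old_action_log_probs"]` instead of `action_log_probs` a second time matters because the clipped surrogate objective depends on the ratio between the current policy and the rollout policy; with identical tensors the ratio is identically 1 and clipping never engages. The sketch below shows the standard PPO-style clipped term this presumably corresponds to; it is not the repository's `policy_loss_fn`, and `clip_eps` is an assumed hyperparameter.

```python
import torch

def clipped_surrogate_sketch(
    log_probs: torch.Tensor,      # current-policy log-probs, shape (B, T)
    old_log_probs: torch.Tensor,  # rollout-time log-probs, shape (B, T)
    advantages: torch.Tensor,     # per-token advantages, shape (B, T)
    action_mask: torch.Tensor,    # 1.0 on valid action tokens, shape (B, T)
    clip_eps: float = 0.2,
) -> torch.Tensor:
    # Importance ratio between current and rollout policies, per token.
    ratio = torch.exp(log_probs - old_log_probs.detach())
    surr_unclipped = ratio * advantages
    surr_clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    per_token_loss = -torch.min(surr_unclipped, surr_clipped)
    # Masked mean over valid action tokens.
    return (per_token_loss * action_mask).sum() / action_mask.sum()

# Toy usage: 2 sequences, 5 action tokens each.
new_lp, old_lp = torch.randn(2, 5), torch.randn(2, 5)
adv, mask = torch.randn(2, 5), torch.ones(2, 5)
print(clipped_surrogate_sketch(new_lp, old_lp, adv, mask))
```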
@@ -376,20 +384,6 @@ def _criterion(outputs, inputs):
                             kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
                             mean_kl.append(kl)
                         mean_loss.append(all_reduce_mean(loss, self.plugin).data)
-                        mini_batch_entropies.append(
-                            all_reduce_mean(
-                                (
-                                    (
-                                        (
-                                            entropy_from_logits(policy_model_logits[:, -num_action:])
-                                            * action_mask_forward_micro_batch
-                                        ).sum(-1)
-                                    )
-                                    / action_mask_forward_micro_batch.sum(-1)
-                                ).detach(),
-                                self.plugin,
-                            )
-                        )
                 else:
                     policy_model_logits = self.policy_model(
                         input_ids=input_ids_forward_micro_batch,
@@ -428,7 +422,7 @@ def _criterion(outputs, inputs):

                     loss, _ = self.policy_loss_fn(
                         action_log_probs,
-                        old_action_log_probs,
+                        old_action_log_probs_micro_batch,
                         advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
                         per_token_kl,
                         action_mask_forward_micro_batch,
@@ -468,7 +462,7 @@ def _criterion(outputs, inputs):
                 ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
                 advantages = all_reduce_mean(advantages.mean(), self.plugin)
                 response_length = all_reduce_mean(response_length.mean(), self.plugin)
-                entropy = torch.cat(mini_batch_entropies, dim=0).mean()
+                entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin)
                 self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
                 self.accum_entropy.add_(entropy.data)
                 if self.policy_loss_fn.beta > 0:
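
The last hunk turns the logged entropy from a rank-local mean into a cross-rank average, matching how `loss`, `kl`, `ans_acc`, and the other metrics above are already reduced. The sketch below shows what that aggregation amounts to using `torch.distributed` directly; the repository's `all_reduce_mean(tensor, self.plugin)` helper is assumed to perform the equivalent reduction over the data-parallel group.

```python
import os
import torch
import torch.distributed as dist

def all_reduce_mean_sketch(tensor: torch.Tensor) -> torch.Tensor:
    # Average a metric across all data-parallel ranks: sum-reduce, then divide by world size.
    tensor = tensor.clone()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor / dist.get_world_size()

if __name__ == "__main__":
    # Single-process demo; in training this runs on every data-parallel rank.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    # One 1-D tensor of per-sequence entropies per micro batch (values invented).
    mini_batch_entropies = [torch.tensor([0.9, 1.1]), torch.tensor([1.0])]
    entropy = all_reduce_mean_sketch(torch.cat(mini_batch_entropies, dim=0).mean())
    print(entropy)  # rank-local mean of 1.0, averaged across ranks (here a single rank)

    dist.destroy_process_group()
```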