@@ -263,6 +263,9 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
 input_ids_forward_micro_batch = data["input_ids"][
     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
 ]
+old_action_log_probs_micro_batch = old_action_log_probs[
+    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
+]
 attention_mask_forward_micro_batch = data["attention_mask"][
     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
 ]
@@ -319,17 +322,22 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
319322 "action_mask" : action_mask_forward_micro_batch ,
320323 "advantages" : advantages_forward_micro_batch ,
321324 "loss_mask" : loss_mask_forward_micro_batch ,
325+ "old_action_log_probs" : old_action_log_probs_micro_batch ,
322326 "source" : self .rank ,
323327 }
324328 if reference_action_log_probs is not None :
325329 data_policy_forward ["reference_action_log_probs" ] = reference_action_log_probs
326330
327331 kl = []
328- policy_model_logits = torch .empty_like (input_ids_forward_micro_batch , device = self .device )
329332
330333 def _criterion (outputs , inputs ):
331334 action_logits = outputs .logits
332- policy_model_logits .copy_ (action_logits )
335+ mini_batch_entropies .append (
336+ (
337+ ((entropy_from_logits (action_logits [:, - num_action :]) * inputs ["action_mask" ]).sum (- 1 ))
338+ / inputs ["action_mask" ].sum (- 1 )
339+ ).detach ()
340+ )
333341 action_log_probs = memory_efficient_logprob (
334342 action_logits / self .generate_config ["temperature" ],
335343 inputs ["input_ids" ],
@@ -352,7 +360,7 @@ def _criterion(outputs, inputs):

 loss, _ = self.policy_loss_fn(
     action_log_probs,
-    action_log_probs,
+    inputs["old_action_log_probs"],
     inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
     per_token_kl,
     inputs["action_mask"],
@@ -376,20 +384,6 @@ def _criterion(outputs, inputs):
         kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
         mean_kl.append(kl)
     mean_loss.append(all_reduce_mean(loss, self.plugin).data)
-    mini_batch_entropies.append(
-        all_reduce_mean(
-            (
-                (
-                    (
-                        entropy_from_logits(policy_model_logits[:, -num_action:])
-                        * action_mask_forward_micro_batch
-                    ).sum(-1)
-                )
-                / action_mask_forward_micro_batch.sum(-1)
-            ).detach(),
-            self.plugin,
-        )
-    )
 else:
     policy_model_logits = self.policy_model(
         input_ids=input_ids_forward_micro_batch,
@@ -428,7 +422,7 @@ def _criterion(outputs, inputs):

 loss, _ = self.policy_loss_fn(
     action_log_probs,
-    old_action_log_probs,
+    old_action_log_probs_micro_batch,
     advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
     per_token_kl,
     action_mask_forward_micro_batch,
@@ -468,7 +462,7 @@ def _criterion(outputs, inputs):
 ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
 advantages = all_reduce_mean(advantages.mean(), self.plugin)
 response_length = all_reduce_mean(response_length.mean(), self.plugin)
-entropy = torch.cat(mini_batch_entropies, dim=0).mean()
+entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin)
 self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
 self.accum_entropy.add_(entropy.data)
 if self.policy_loss_fn.beta > 0:
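Context on what the change does: slicing old_action_log_probs per forward micro-batch keeps it shape-aligned with the other micro-batched tensors, and passing it (rather than the current action_log_probs, or the full-batch tensor) to the policy loss restores the intended PPO importance ratio between the rollout policy and the updated policy; with identical arguments the ratio is identically 1, so clipping does nothing. The entropy metric is likewise now computed from the live logits inside the pipeline criterion (the removed path copied logits into a buffer shaped like input_ids, which cannot hold vocab-sized logits) and is all-reduced once after concatenation. The following is a minimal, self-contained sketch of those two computations, not the repository's API; tensor shapes and the clip range are illustrative assumptions.

# Sketch only: masked mean token entropy from logits, and a clipped PPO
# surrogate that uses the stored old log-probs for the current micro-batch.
import torch
import torch.nn.functional as F


def masked_mean_entropy(action_logits: torch.Tensor, action_mask: torch.Tensor) -> torch.Tensor:
    # action_logits: (B, num_action, vocab), action_mask: (B, num_action)
    probs = F.softmax(action_logits, dim=-1)
    token_entropy = -(probs * torch.log(probs.clamp_min(1e-12))).sum(-1)  # (B, num_action)
    # Average only over response tokens, detached because it is a metric, not a loss term.
    return ((token_entropy * action_mask).sum(-1) / action_mask.sum(-1)).detach()


def clipped_surrogate(
    action_log_probs: torch.Tensor,      # current policy log-probs, (B, num_action)
    old_action_log_probs: torch.Tensor,  # stored at rollout time, same micro-batch slice
    advantages: torch.Tensor,            # (B, num_action)
    action_mask: torch.Tensor,           # (B, num_action)
    clip_eps: float = 0.2,               # assumed clip range
) -> torch.Tensor:
    # Importance ratio between the updated policy and the behaviour policy.
    ratio = torch.exp(action_log_probs - old_action_log_probs.detach())
    surr = torch.min(ratio * advantages, ratio.clamp(1 - clip_eps, 1 + clip_eps) * advantages)
    return -(surr * action_mask).sum() / action_mask.sum()


if __name__ == "__main__":
    # Tiny smoke test with random tensors (hypothetical shapes).
    B, T, V = 2, 5, 11
    logits = torch.randn(B, T, V)
    mask = torch.ones(B, T)
    logp_new = torch.log_softmax(torch.randn(B, T, V), dim=-1).max(-1).values
    logp_old = logp_new + 0.05 * torch.randn(B, T)
    adv = torch.randn(B, 1).expand(B, T)
    print(masked_mean_entropy(logits, mask))
    print(clipped_surrogate(logp_new, logp_old, adv, mask))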