@@ -281,7 +281,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
281281
282282 if self.booster.plugin.stage_manager.is_last_stage():
283283 reference_action_log_probs = memory_efficient_logprob(
284- reference_model_outputs["outputs"]["logits"],
284+ reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"],
285285 input_ids_forward_micro_batch,
286286 num_action,
287287 shard_config=self.plugin.shard_config,
@@ -308,7 +308,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
308308 def _criterion(outputs, inputs):
309309 action_logits = outputs.logits
310310 action_log_probs = memory_efficient_logprob(
311- action_logits,
311+ action_logits / self.generate_config["temperature"],
312312 inputs["input_ids"],
313313 num_action,
314314 shard_config=self.plugin.shard_config,
@@ -375,7 +375,7 @@ def _criterion(outputs, inputs):
375375 reference_model_logits / self.generate_config["temperature"],
376376 input_ids_forward_micro_batch,
377377 num_action,
378- self.plugin.shard_config,
378+ shard_config=self.plugin.shard_config,
379379 )
380380 per_token_kl = (
381381 torch .exp (reference_action_log_probs - action_log_probs )
0 commit comments