@@ -281,7 +281,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
281
281
282
282
if self .booster .plugin .stage_manager .is_last_stage ():
283
283
reference_action_log_probs = memory_efficient_logprob (
284
- reference_model_outputs ["outputs" ]["logits" ],
284
+ reference_model_outputs ["outputs" ]["logits" ] / self . generate_config [ "temperature" ] ,
285
285
input_ids_forward_micro_batch ,
286
286
num_action ,
287
287
shard_config = self .plugin .shard_config ,
@@ -308,7 +308,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
308
308
def _criterion (outputs , inputs ):
309
309
action_logits = outputs .logits
310
310
action_log_probs = memory_efficient_logprob (
311
- action_logits ,
311
+ action_logits / self . generate_config [ "temperature" ] ,
312
312
inputs ["input_ids" ],
313
313
num_action ,
314
314
shard_config = self .plugin .shard_config ,
@@ -375,7 +375,7 @@ def _criterion(outputs, inputs):
375
375
reference_model_logits / self .generate_config ["temperature" ],
376
376
input_ids_forward_micro_batch ,
377
377
num_action ,
378
- self .plugin .shard_config ,
378
+ shard_config = self .plugin .shard_config ,
379
379
)
380
380
per_token_kl = (
381
381
torch .exp (reference_action_log_probs - action_log_probs )
0 commit comments