@@ -294,7 +294,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
294
294
295
295
if self .booster .plugin .stage_manager .is_last_stage ():
296
296
reference_action_log_probs = memory_efficient_logprob (
297
- reference_model_outputs ["outputs" ]["logits" ],
297
+ reference_model_outputs ["outputs" ]["logits" ] / self . generate_config [ "temperature" ] ,
298
298
input_ids_forward_micro_batch ,
299
299
num_action ,
300
300
shard_config = self .plugin .shard_config ,
@@ -321,7 +321,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
321
321
def _criterion (outputs , inputs ):
322
322
action_logits = outputs .logits
323
323
action_log_probs = memory_efficient_logprob (
324
- action_logits ,
324
+ action_logits / self . generate_config [ "temperature" ] ,
325
325
inputs ["input_ids" ],
326
326
num_action ,
327
327
shard_config = self .plugin .shard_config ,
@@ -388,7 +388,7 @@ def _criterion(outputs, inputs):
388
388
reference_model_logits / self .generate_config ["temperature" ],
389
389
input_ids_forward_micro_batch ,
390
390
num_action ,
391
- self .plugin .shard_config ,
391
+ shard_config = self .plugin .shard_config ,
392
392
)
393
393
per_token_kl = (
394
394
torch .exp (reference_action_log_probs - action_log_probs )
0 commit comments