@@ -843,7 +843,10 @@ def _score_labels_with_prompt_logprobs(self, all_prompts: List[str], all_labels:
         else:
             assert l_no_cots_len <= l_len
 
-        # Apply temperature scaling to logprobs
+        # IMPORTANT: Keep original logprobs for PPO training
+        original_token_lps = token_lps.copy()
+
+        # Apply temperature scaling ONLY for prior scores (MCTS root)
         if current_temperature != 1.0:
             import torch
             token_lps_tensor = torch.tensor(token_lps, dtype=torch.float32)
@@ -858,13 +861,21 @@ def _score_labels_with_prompt_logprobs(self, all_prompts: List[str], all_labels:
 
             # Renormalize to maintain valid probability distribution
             scaled_lps = torch.log_softmax(scaled_lps, dim=0)
-            token_lps = scaled_lps.tolist()
 
+            # Use scaled logprobs for prior scores
+            scaled_token_lps = scaled_lps.tolist()
+        else:
+            # No temperature scaling
+            scaled_token_lps = token_lps
+
+        # Prior scores: use temperature-scaled logprobs for MCTS exploration
         if self.llm_prior_with_cot:
-            scores.append(sum(token_lps) if self.reduction == "sum" else sum(token_lps) / l_len)
+            scores.append(sum(scaled_token_lps) if self.reduction == "sum" else sum(scaled_token_lps) / l_len)
         else:
-            scores.append(sum(token_lps[-l_no_cots_len:]) if self.reduction == "sum" else sum(token_lps[-l_no_cots_len:]) / l_no_cots_len)
-        old_action_logprob.append(token_lps)
+            scores.append(sum(scaled_token_lps[-l_no_cots_len:]) if self.reduction == "sum" else sum(scaled_token_lps[-l_no_cots_len:]) / l_no_cots_len)
+
+        # Old action logprob: use ORIGINAL logprobs for PPO training
+        old_action_logprob.append(original_token_lps)
 
         # Update temperature scheduler with average entropy
         if self.prior_temp_scheduler is not None and all_entropies:
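
The core of the change is separating the logprobs used for the MCTS prior score from the ones fed back into PPO as `old_action_logprob`. The following is a minimal standalone sketch of that idea; the helper name and signature are hypothetical (not part of this PR), and it assumes the scaling step divides the logprobs by the temperature before the `log_softmax` renormalization shown in the diff.

```python
from typing import List, Tuple

import torch


def scale_logprobs_for_prior(
    token_lps: List[float], temperature: float
) -> Tuple[List[float], List[float]]:
    """Return (prior_lps, ppo_lps): temperature-scaled logprobs for the MCTS
    prior score, and an untouched copy for PPO's old_action_logprob."""
    original_token_lps = token_lps.copy()  # kept verbatim for PPO training
    if temperature != 1.0:
        t = torch.tensor(token_lps, dtype=torch.float32)
        # Scale, then renormalize so the result is still a valid log-distribution
        scaled = torch.log_softmax(t / temperature, dim=0)
        return scaled.tolist(), original_token_lps
    return original_token_lps, original_token_lps


# Example: a higher temperature flattens the prior used for exploration,
# while the PPO logprobs stay exactly as the policy produced them.
prior_lps, ppo_lps = scale_logprobs_for_prior([-0.1, -2.3, -4.0], temperature=2.0)
```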