
Commit ebbc9d8 (parent: 3c13ec1)

fix(pu): fix prior_temp_schedule overwrite logprob bug

File tree: 1 file changed (+16, -5 lines)

zoo/jericho/priorzero/priorzero_datafactory.py

Lines changed: 16 additions & 5 deletions
@@ -843,7 +843,10 @@ def _score_labels_with_prompt_logprobs(self, all_prompts: List[str], all_labels:
         else:
             assert l_no_cots_len <= l_len

-        # Apply temperature scaling to logprobs
+        # IMPORTANT: Keep original logprobs for PPO training
+        original_token_lps = token_lps.copy()
+
+        # Apply temperature scaling ONLY for prior scores (MCTS root)
         if current_temperature != 1.0:
             import torch
             token_lps_tensor = torch.tensor(token_lps, dtype=torch.float32)
@@ -858,13 +861,21 @@ def _score_labels_with_prompt_logprobs(self, all_prompts: List[str], all_labels:

             # Renormalize to maintain valid probability distribution
             scaled_lps = torch.log_softmax(scaled_lps, dim=0)
-            token_lps = scaled_lps.tolist()

+            # Use scaled logprobs for prior scores
+            scaled_token_lps = scaled_lps.tolist()
+        else:
+            # No temperature scaling
+            scaled_token_lps = token_lps
+
+        # Prior scores: use temperature-scaled logprobs for MCTS exploration
         if self.llm_prior_with_cot:
-            scores.append(sum(token_lps) if self.reduction == "sum" else sum(token_lps) / l_len)
+            scores.append(sum(scaled_token_lps) if self.reduction == "sum" else sum(scaled_token_lps) / l_len)
         else:
-            scores.append(sum(token_lps[-l_no_cots_len:]) if self.reduction == "sum" else sum(token_lps[-l_no_cots_len:]) / l_no_cots_len)
-        old_action_logprob.append(token_lps)
+            scores.append(sum(scaled_token_lps[-l_no_cots_len:]) if self.reduction == "sum" else sum(scaled_token_lps[-l_no_cots_len:]) / l_no_cots_len)
+
+        # Old action logprob: use ORIGINAL logprobs for PPO training
+        old_action_logprob.append(original_token_lps)

         # Update temperature scheduler with average entropy
         if self.prior_temp_scheduler is not None and all_entropies:
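
What the fix addresses: the old code replaced token_lps with the temperature-scaled values, so the same scaled list was later appended to old_action_logprob and PPO's importance ratio was computed against a distorted behaviour policy. The sketch below isolates the separation the fix introduces. It is illustrative only: temperature, reduction, with_cot and no_cot_len are hypothetical stand-ins for current_temperature, self.reduction, self.llm_prior_with_cot and l_no_cots_len, and the scaling step elided between the two hunks is assumed to be a plain divide-by-temperature.

import torch

def prior_score_and_old_logprob(token_lps, temperature=1.0, reduction="mean",
                                with_cot=True, no_cot_len=None):
    # Hypothetical sketch of the post-fix flow; argument names stand in for
    # current_temperature, self.reduction, self.llm_prior_with_cot, l_no_cots_len.

    # Keep the original per-token logprobs untouched: these become the
    # "old action logprobs" that PPO needs for its importance ratio.
    original_token_lps = list(token_lps)

    if temperature != 1.0:
        lps = torch.tensor(token_lps, dtype=torch.float32)
        # Assumed scaling step (elided between the two hunks): divide by the
        # temperature, then renormalize with log_softmax so the scaled values
        # again form a valid log-probability distribution.
        scaled_token_lps = torch.log_softmax(lps / temperature, dim=0).tolist()
    else:
        scaled_token_lps = list(token_lps)

    # The MCTS root prior score is computed from the scaled logprobs only.
    picked = scaled_token_lps if with_cot else scaled_token_lps[-no_cot_len:]
    prior_score = sum(picked) if reduction == "sum" else sum(picked) / len(picked)

    return prior_score, original_token_lps

Raising the temperature flattens the prior and encourages exploration at the MCTS root, while original_token_lps still matches what the rollout policy actually assigned, which is the separation the commit title ("overwrite logprob bug") refers to.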
