Commit 9a20396

fix(pu): fix llm-prior-policy temperature and entropy bug
1 parent 684d82b commit 9a20396

2 files changed: +130, -76 lines


zoo/jericho/priorzero/priorzero_config.py

Lines changed: 23 additions & 6 deletions
@@ -170,7 +170,8 @@ class PriorZeroLLMConfig:
     reward_func: Optional[EasyDict] = field(default_factory=lambda: EasyDict({
         "format_reward": True,
         "format_param": EasyDict(
-            {"format_weight": 0.1, }
+            # {"format_weight": 0.1, }
+            {"format_weight": 0.5, }
         ),
     }))
     # advantage = target_value - pred_value
@@ -183,12 +184,28 @@ class PriorZeroLLMConfig:
     kl_estimator: str = "k3"

     # Entropy loss for exploration bonus
-    entropy_loss_coef: Optional[float] = 0.01 # None = disabled, typical values: 0.001-0.01
-    # entropy_loss_coef: Optional[float] = None # None = disabled, typical values: 0.001-0.01
+    # entropy_loss_coef: Optional[float] = 0.01 # None = disabled, typical values: 0.001-0.01
+    entropy_loss_coef: Optional[float] = None # None = disabled, typical values: 0.001-0.01

     # LLM Prior Mixing Configuration
+    # ===== baseline root policy-head-logits =====
+    # prior_mixing_cfg: Optional[EasyDict] = field(default_factory=lambda: EasyDict({
+    #     'enable_soft_mixing': True, # Enable soft mixing instead of hard override
+    #     # 'mixing_alpha': 0.5, # Weight for LLM prior (0=network only, 1=LLM only)
+    #     'mixing_alpha': 0., # Weight for LLM prior (0=network only, 1=LLM only)
+    #     'alpha_schedule': None, # 'linear', 'cosine', 'exponential', or None (fixed)
+    #     # 'alpha_schedule': 'cosine', # Smooth decay
+    #     'alpha_init': 0.8, # Initial alpha (high LLM influence)
+    #     'alpha_final': 0.2, # Final alpha (low LLM influence)
+    #     'alpha_decay_steps': 10000, # Steps to decay from init to final
+    #     'enable_clip_prior': True, # Enable clipping of LLM prior probabilities
+    #     'clip_prior_epsilon': 0.01, # Minimum probability for each action (exploration)
+    # }))
+
+    # =====llm prior as root policy-logits =====
     prior_mixing_cfg: Optional[EasyDict] = field(default_factory=lambda: EasyDict({
-        'enable_soft_mixing': True, # Enable soft mixing instead of hard override
+        # 'enable_soft_mixing': True, # Enable soft mixing instead of hard override
+        'enable_soft_mixing': False, # Enable soft mixing instead of hard override
         # 'mixing_alpha': 0.5, # Weight for LLM prior (0=network only, 1=LLM only)
         'mixing_alpha': 0., # Weight for LLM prior (0=network only, 1=LLM only)
         'alpha_schedule': None, # 'linear', 'cosine', 'exponential', or None (fixed)
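
Editor's note on the hunk above: the code that consumes prior_mixing_cfg is not part of this commit, so the Python sketch below only illustrates what the keys suggest about blending the LLM prior with the policy-head distribution at the MCTS root. The function mix_root_prior, its arguments, and the cosine-decay details are assumptions for illustration, not the repository's actual implementation.

import numpy as np
from easydict import EasyDict

def mix_root_prior(policy_probs, llm_prior_probs, cfg, train_step):
    """Hypothetical sketch: combine the network policy with an LLM prior at the MCTS root."""
    if not cfg.enable_soft_mixing:
        # Hard override: the LLM prior replaces the policy-head distribution.
        mixed = np.asarray(llm_prior_probs, dtype=np.float64)
    else:
        if cfg.alpha_schedule == 'cosine':
            # Cosine decay of alpha from alpha_init to alpha_final over alpha_decay_steps.
            progress = min(train_step / cfg.alpha_decay_steps, 1.0)
            alpha = cfg.alpha_final + 0.5 * (cfg.alpha_init - cfg.alpha_final) * (1 + np.cos(np.pi * progress))
        else:
            alpha = cfg.mixing_alpha  # fixed weight when no schedule is configured
        mixed = alpha * np.asarray(llm_prior_probs) + (1 - alpha) * np.asarray(policy_probs)
    if cfg.enable_clip_prior:
        # Floor each action's probability so MCTS keeps a minimum amount of exploration.
        mixed = np.clip(mixed, cfg.clip_prior_epsilon, None)
    return mixed / mixed.sum()

cfg = EasyDict({'enable_soft_mixing': False, 'mixing_alpha': 0., 'alpha_schedule': None,
                'alpha_init': 0.8, 'alpha_final': 0.2, 'alpha_decay_steps': 10000,
                'enable_clip_prior': True, 'clip_prior_epsilon': 0.01})
root_prior = mix_root_prior([0.7, 0.2, 0.1], [0.5, 0.4, 0.1], cfg, train_step=0)

With the values checked in here (enable_soft_mixing: False, clipping enabled), such a sketch reduces to using the clipped, renormalized LLM prior directly as the root policy, which is consistent with the "llm prior as root policy-logits" comment.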
@@ -488,9 +505,9 @@ def get_priorzero_config(

     # Build exp_name
     exp_name = (
-        f"data_priorzero/pz_{env_id}_{model_key}_"
+        f"data_priorzero/0204/pz_{env_id}_{model_key}_"
         f"{cot_str}_{adv_type_short}_{prior_temp_str}_{fmt_rew_str}_"
-        f"{entropy_str}_{mixing_str}_{clip_str}_seed{seed}"
+        f"{entropy_str}_{mixing_str}_{clip_str}_frw05_seed{seed}"
     )

     # Update config with generated exp_name

zoo/jericho/priorzero/priorzero_datafactory.py

Lines changed: 107 additions & 70 deletions
@@ -717,6 +717,11 @@ def get_llm_prior(
         """
         Get LLM prior scores for actions.

+        This function correctly applies temperature at the ACTION level (not token level):
+        1. Compute raw action scores (sum of token logprobs)
+        2. Apply temperature scaling across actions for the same state
+        3. Compute policy entropy (action-level entropy, not token-level)
+
         Args:
             states: List of current state observations
             valid_actions_list: List of valid actions for each state
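
To make the "action level, not token level" distinction in the new docstring concrete, here is a small self-contained sketch with made-up numbers: the raw per-action scores of one state are temperature-scaled, softmaxed across actions, and the policy entropy is taken over that action distribution.

import torch

# Toy numbers: raw scores (sum of token logprobs) for three valid actions of one state.
raw_scores = torch.tensor([-2.1, -3.4, -0.7])
temperature = 1.5

# Action-level temperature scaling and normalization.
action_log_probs = torch.log_softmax(raw_scores / temperature, dim=0)
action_probs = action_log_probs.exp()

# Policy entropy H(pi(.|s)) = -sum_a pi(a|s) log pi(a|s), measured over actions.
policy_entropy = -(action_probs * action_log_probs).sum()
print(action_probs, policy_entropy)

Raising the temperature flattens the action distribution and increases this entropy; the token-level logprobs themselves are left untouched.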
@@ -750,19 +755,96 @@ def get_llm_prior(
                 all_labels.append(action)
                 all_prefix_cots.append(prefix)

-        scores, old_action_logprob = self._score_labels_with_prompt_logprobs(all_prompts, all_labels, all_prefix_cots)
-        llm_prior_per_seq, llm_prior_per_tok, idx = [],[], 0
+        # Get raw scores (sum of token logprobs, no temperature/softmax yet)
+        raw_scores, old_action_logprob = self._score_labels_with_prompt_logprobs(all_prompts, all_labels, all_prefix_cots)
+
+        # Get current temperature from scheduler
+        current_temperature = self.temperature # Default temperature
+        if self.prior_temp_scheduler is not None:
+            current_temperature = self.prior_temp_scheduler.get_temperature()
+
+        # Apply temperature and compute policy entropy at ACTION level
+        llm_prior_per_seq, llm_prior_per_tok = [], []
+        idx = 0
+        all_policy_entropies = [] # Track policy entropies for adaptive scheduling

         for prompt, actions, prefix in zip(prompt_list, valid_actions_list, prefix_cots):
             actions2 = actions if "go" in actions else (actions + ["go"])
+            num_actions = len(actions2)
+
+            # Get scores for this state's actions
+            state_raw_scores = raw_scores[idx : idx + num_actions]
+            state_old_logprobs = old_action_logprob[idx : idx + num_actions]
+
+            # Apply temperature scaling at ACTION level
+            if current_temperature != 1.0:
+                state_scores_tensor = torch.tensor(state_raw_scores, dtype=torch.float32)
+
+                # Apply temperature: logits / T
+                scaled_scores = state_scores_tensor / current_temperature
+
+                # Compute action probabilities via softmax
+                # This is the CORRECT level: softmax over actions, not tokens!
+                action_log_probs = torch.log_softmax(scaled_scores, dim=0)
+                action_probs = torch.exp(action_log_probs)
+
+                # Compute POLICY ENTROPY: H(π(·|s)) = -Σ π(a|s) log π(a|s)
+                # This measures the agent's uncertainty over action choices
+                policy_entropy = -(action_probs * action_log_probs).sum().item()
+                all_policy_entropies.append(policy_entropy)
+
+                # Use temperature-scaled scores for prior
+                state_scores = action_log_probs.tolist()
+            else:
+                # No temperature scaling: use raw scores
+                # Still need to normalize for proper probability distribution
+                state_scores_tensor = torch.tensor(state_raw_scores, dtype=torch.float32)
+                action_log_probs = torch.log_softmax(state_scores_tensor, dim=0)
+                action_probs = torch.exp(action_log_probs)
+
+                policy_entropy = -(action_probs * action_log_probs).sum().item()
+                all_policy_entropies.append(policy_entropy)
+
+                state_scores = action_log_probs.tolist()
+
+            # Build dictionaries for this state
             tmp_dict = {}
             tmp_dict2 = {}
-            for action in actions2:
-                tmp_dict[action] = scores[idx]
-                tmp_dict2[action] = old_action_logprob[idx]
-                idx = idx + 1
+            for i, action in enumerate(actions2):
+                tmp_dict[action] = state_scores[i] # Temperature-scaled log prob
+                tmp_dict2[action] = state_old_logprobs[i] # Original token logprobs for PPO
+
             llm_prior_per_seq.append(tmp_dict)
             llm_prior_per_tok.append(tmp_dict2)
+            idx += num_actions
+
+        # Update temperature scheduler with average POLICY entropy
+        if self.prior_temp_scheduler is not None and all_policy_entropies:
+            avg_policy_entropy = sum(all_policy_entropies) / len(all_policy_entropies)
+            new_temperature = self.prior_temp_scheduler.step(entropy=avg_policy_entropy)
+
+            # Log temperature and entropy statistics
+            if self.rank == 0 and self.prior_temp_scheduler.current_step % 10 == 0:
+                stats = self.prior_temp_scheduler.get_stats()
+                print(
+                    f"[Prior Temperature] step={stats['temperature_step']} | "
+                    f"temp={stats['prior_temperature']:.3f} | "
+                    f"policy_entropy={avg_policy_entropy:.3f} | "
+                    f"progress={stats['temperature_progress']:.2%}"
+                )

+                # Log to TensorBoard
+                if self.tb_logger is not None:
+                    step = stats['temperature_step']
+                    self.tb_logger.add_scalar("prior_temp/temperature", stats['prior_temperature'], step)
+                    self.tb_logger.add_scalar("prior_temp/policy_entropy", avg_policy_entropy, step)
+                    self.tb_logger.add_scalar("prior_temp/progress", stats['temperature_progress'], step)
+                    if 'prior_avg_entropy' in stats:
+                        self.tb_logger.add_scalar("prior_temp/avg_entropy", stats['prior_avg_entropy'], step)
+                    if 'prior_target_entropy' in stats:
+                        self.tb_logger.add_scalar("prior_temp/target_entropy", stats['prior_target_entropy'], step)
+                    if 'prior_entropy_gap' in stats:
+                        self.tb_logger.add_scalar("prior_temp/entropy_gap", stats['prior_entropy_gap'], step)

         if self.use_cot:
             self.episode_output.append({
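
The scheduler update in the hunk above assumes a prior_temp_scheduler exposing get_temperature(), step(entropy=...), get_stats(), and current_step; its implementation is outside this diff. A minimal entropy-adaptive scheduler matching that interface could look like the sketch below (the class name, constants, and adaptation rule are assumptions, not the repository's scheduler).

class SimplePriorTempScheduler:
    """Hypothetical entropy-adaptive temperature scheduler; interface inferred from the calling code."""

    def __init__(self, init_temp=1.0, target_entropy=1.0, lr=0.05, t_min=0.5, t_max=2.0, total_steps=10000):
        self.temperature = init_temp
        self.target_entropy = target_entropy
        self.lr = lr
        self.t_min, self.t_max = t_min, t_max
        self.total_steps = total_steps
        self.current_step = 0
        self.last_entropy = None

    def get_temperature(self):
        return self.temperature

    def step(self, entropy):
        # Raise T when the prior is too peaked (entropy below target), lower it otherwise.
        self.temperature += self.lr * (self.target_entropy - entropy)
        self.temperature = min(max(self.temperature, self.t_min), self.t_max)
        self.current_step += 1
        self.last_entropy = entropy
        return self.temperature

    def get_stats(self):
        return {
            'temperature_step': self.current_step,
            'prior_temperature': self.temperature,
            'temperature_progress': min(self.current_step / self.total_steps, 1.0),
            'prior_avg_entropy': self.last_entropy,
            'prior_target_entropy': self.target_entropy,
            'prior_entropy_gap': (self.target_entropy - self.last_entropy) if self.last_entropy is not None else 0.0,
        }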
@@ -777,13 +859,19 @@ def get_llm_prior(
         return llm_prior_per_seq, llm_prior_per_tok

     @torch.no_grad()
-    def _score_labels_with_prompt_logprobs(self, all_prompts: List[str], all_labels: List[str], all_prefix_cots: List[str]) -> List[float]:
-        assert len(all_prompts) == len(all_labels) == len(all_prefix_cots)
+    def _score_labels_with_prompt_logprobs(self, all_prompts: List[str], all_labels: List[str], all_prefix_cots: List[str]) -> Tuple[List[float], List[List[float]]]:
+        """
+        Compute raw log probabilities for action sequences.

-        # Get current temperature from scheduler
-        current_temperature = self.temperature # Default temperature
-        if self.prior_temp_scheduler is not None:
-            current_temperature = self.prior_temp_scheduler.get_temperature()
+        This function computes the original sequence log probabilities by summing token-level logprobs.
+        It does NOT apply temperature scaling or softmax normalization at the token level.
+        Temperature and softmax should be applied at the ACTION level in get_llm_prior.
+
+        Returns:
+            scores: List of raw log probabilities (sum of token logprobs)
+            old_action_logprob: List of token-level logprobs for PPO training
+        """
+        assert len(all_prompts) == len(all_labels) == len(all_prefix_cots)

         sampling_params = SamplingParams(
             temperature=self.temperature, # Keep original for generation
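
As a toy illustration of the reduction behavior documented in the new docstring (numbers are made up), the raw score is either the plain sum of the action's token logprobs or the length-normalized mean:

token_lps = [-0.2, -1.3, -0.5]                    # token logprobs of one action under the LLM
raw_score_sum = sum(token_lps)                    # reduction == "sum": log P(action) = sum of token logprobs
raw_score_mean = sum(token_lps) / len(token_lps)  # length-normalized variant

In the no-CoT case, only the last l_no_cots_len token logprobs (the action tokens themselves) enter this sum.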
@@ -817,7 +905,6 @@ def _score_labels_with_prompt_logprobs(self, all_prompts: List[str], all_labels:

         scores = []
         old_action_logprob = []
-        all_entropies = [] # Track entropies for adaptive scheduling

         for out, ids, p_len, l_len, l_no_cots_len in zip(outs, full_ids, p_lens, l_lens, l_no_cots_lens):
             prompt_logprobs = getattr(out, "prompt_logprobs", None)
@@ -843,67 +930,17 @@ def _score_labels_with_prompt_logprobs(self, all_prompts: List[str], all_labels:
             else:
                 assert l_no_cots_len <= l_len

-            # IMPORTANT: Keep original logprobs for PPO training
-            original_token_lps = token_lps.copy()
-
-            # Apply temperature scaling ONLY for prior scores (MCTS root)
-            if current_temperature != 1.0:
-                import torch
-                token_lps_tensor = torch.tensor(token_lps, dtype=torch.float32)
-
-                # Apply temperature scaling
-                scaled_lps = token_lps_tensor / current_temperature
-
-                # Compute entropy before normalization (for adaptive scheduling)
-                probs = torch.exp(scaled_lps)
-                entropy = -(probs * scaled_lps).sum().item()
-                all_entropies.append(entropy)
-
-                # Renormalize to maintain valid probability distribution
-                scaled_lps = torch.log_softmax(scaled_lps, dim=0)
-
-                # Use scaled logprobs for prior scores
-                scaled_token_lps = scaled_lps.tolist()
-            else:
-                # No temperature scaling
-                scaled_token_lps = token_lps
-
-            # Prior scores: use temperature-scaled logprobs for MCTS exploration
+            # Compute raw sequence log probability (sum of token logprobs)
+            # This is the correct action-level score: log P(action) = sum_i log P(token_i)
             if self.llm_prior_with_cot:
-                scores.append(sum(scaled_token_lps) if self.reduction == "sum" else sum(scaled_token_lps) / l_len)
+                raw_score = sum(token_lps) if self.reduction == "sum" else sum(token_lps) / l_len
             else:
-                scores.append(sum(scaled_token_lps[-l_no_cots_len:]) if self.reduction == "sum" else sum(scaled_token_lps[-l_no_cots_len:]) / l_no_cots_len)
+                raw_score = sum(token_lps[-l_no_cots_len:]) if self.reduction == "sum" else sum(token_lps[-l_no_cots_len:]) / l_no_cots_len

-            # Old action logprob: use ORIGINAL logprobs for PPO training
-            old_action_logprob.append(original_token_lps)
+            scores.append(raw_score)

-        # Update temperature scheduler with average entropy
-        if self.prior_temp_scheduler is not None and all_entropies:
-            avg_entropy = sum(all_entropies) / len(all_entropies)
-            new_temperature = self.prior_temp_scheduler.step(entropy=avg_entropy)
-
-            # Log temperature and entropy statistics
-            if self.rank == 0 and self.prior_temp_scheduler.current_step % 10 == 0:
-                stats = self.prior_temp_scheduler.get_stats()
-                print(
-                    f"[Prior Temperature] step={stats['temperature_step']} | "
-                    f"temp={stats['prior_temperature']:.3f} | "
-                    f"entropy={avg_entropy:.3f} | "
-                    f"progress={stats['temperature_progress']:.2%}"
-                )
-
-                # Log to TensorBoard
-                if self.tb_logger is not None:
-                    step = stats['temperature_step']
-                    self.tb_logger.add_scalar("prior_temp/temperature", stats['prior_temperature'], step)
-                    self.tb_logger.add_scalar("prior_temp/entropy", avg_entropy, step)
-                    self.tb_logger.add_scalar("prior_temp/progress", stats['temperature_progress'], step)
-                    if 'prior_avg_entropy' in stats:
-                        self.tb_logger.add_scalar("prior_temp/avg_entropy", stats['prior_avg_entropy'], step)
-                    if 'prior_target_entropy' in stats:
-                        self.tb_logger.add_scalar("prior_temp/target_entropy", stats['prior_target_entropy'], step)
-                    if 'prior_entropy_gap' in stats:
-                        self.tb_logger.add_scalar("prior_temp/entropy_gap", stats['prior_entropy_gap'], step)
+            # Keep original token-level logprobs for PPO training
+            old_action_logprob.append(token_lps)

         return scores, old_action_logprob
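
For contrast with the removed branch above, which scaled and renormalized the logprobs of the tokens within a single action, the following made-up example shows why a token-position distribution says nothing about how peaked the action prior is, while the action-level entropy does:

import torch

temperature = 1.5

# Removed behavior (token level): softmax over the token positions of ONE action.
token_lps = torch.tensor([-0.2, -1.3, -0.5])
token_dist = torch.softmax(token_lps / temperature, dim=0)
token_entropy = -(token_dist * token_dist.log()).sum()    # entropy over token positions, not actions

# Fixed behavior (action level): softmax over the candidate actions of one state.
action_scores = torch.tensor([-2.0, -2.0, -2.0])           # three equally scored actions
action_dist = torch.softmax(action_scores / temperature, dim=0)
action_entropy = -(action_dist * action_dist.log()).sum()  # ln(3): the prior is uniform over actions

print(float(token_entropy), float(action_entropy))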
