@@ -717,6 +717,11 @@ def get_llm_prior(
717717 """
718718 Get LLM prior scores for actions.
719719
720+ This function correctly applies temperature at the ACTION level (not token level):
721+ 1. Compute raw action scores (sum of token logprobs)
722+ 2. Apply temperature scaling across actions for the same state
723+ 3. Compute policy entropy (action-level entropy, not token-level)
724+
720725 Args:
721726 states: List of current state observations
722727 valid_actions_list: List of valid actions for each state
@@ -750,19 +755,96 @@ def get_llm_prior(
                 all_labels.append(action)
                 all_prefix_cots.append(prefix)

-        scores, old_action_logprob = self._score_labels_with_prompt_logprobs(all_prompts, all_labels, all_prefix_cots)
-        llm_prior_per_seq, llm_prior_per_tok, idx = [], [], 0
+        # Get raw scores (sum of token logprobs, no temperature/softmax yet)
+        raw_scores, old_action_logprob = self._score_labels_with_prompt_logprobs(all_prompts, all_labels, all_prefix_cots)
+
+        # Get current temperature from scheduler
+        current_temperature = self.temperature  # Default temperature
+        if self.prior_temp_scheduler is not None:
+            current_temperature = self.prior_temp_scheduler.get_temperature()
+
+        # Apply temperature and compute policy entropy at ACTION level
+        llm_prior_per_seq, llm_prior_per_tok = [], []
+        idx = 0
+        all_policy_entropies = []  # Track policy entropies for adaptive scheduling

         for prompt, actions, prefix in zip(prompt_list, valid_actions_list, prefix_cots):
             actions2 = actions if "go" in actions else (actions + ["go"])
+            num_actions = len(actions2)
+
+            # Get scores for this state's actions
+            state_raw_scores = raw_scores[idx : idx + num_actions]
+            state_old_logprobs = old_action_logprob[idx : idx + num_actions]
+
+            # Apply temperature scaling at ACTION level
+            if current_temperature != 1.0:
+                state_scores_tensor = torch.tensor(state_raw_scores, dtype=torch.float32)
+
+                # Apply temperature: logits / T
+                scaled_scores = state_scores_tensor / current_temperature
+
+                # Compute action probabilities via softmax
+                # This is the CORRECT level: softmax over actions, not tokens!
+                action_log_probs = torch.log_softmax(scaled_scores, dim=0)
+                action_probs = torch.exp(action_log_probs)
+
+                # Compute POLICY ENTROPY: H(π(·|s)) = -Σ π(a|s) log π(a|s)
+                # This measures the agent's uncertainty over action choices
+                policy_entropy = -(action_probs * action_log_probs).sum().item()
+                all_policy_entropies.append(policy_entropy)
+
+                # Use temperature-scaled scores for prior
+                state_scores = action_log_probs.tolist()
+            else:
+                # No temperature scaling: use raw scores
+                # Still need to normalize for proper probability distribution
+                state_scores_tensor = torch.tensor(state_raw_scores, dtype=torch.float32)
+                action_log_probs = torch.log_softmax(state_scores_tensor, dim=0)
+                action_probs = torch.exp(action_log_probs)
+
+                policy_entropy = -(action_probs * action_log_probs).sum().item()
+                all_policy_entropies.append(policy_entropy)
+
+                state_scores = action_log_probs.tolist()
+
+            # Build dictionaries for this state
             tmp_dict = {}
             tmp_dict2 = {}
-            for action in actions2:
-                tmp_dict[action] = scores[idx]
-                tmp_dict2[action] = old_action_logprob[idx]
-                idx = idx + 1
+            for i, action in enumerate(actions2):
+                tmp_dict[action] = state_scores[i]  # Temperature-scaled log prob
+                tmp_dict2[action] = state_old_logprobs[i]  # Original token logprobs for PPO
+
             llm_prior_per_seq.append(tmp_dict)
             llm_prior_per_tok.append(tmp_dict2)
+            idx += num_actions
+
+        # Update temperature scheduler with average POLICY entropy
+        if self.prior_temp_scheduler is not None and all_policy_entropies:
+            avg_policy_entropy = sum(all_policy_entropies) / len(all_policy_entropies)
+            new_temperature = self.prior_temp_scheduler.step(entropy=avg_policy_entropy)
+
+            # Log temperature and entropy statistics
+            if self.rank == 0 and self.prior_temp_scheduler.current_step % 10 == 0:
+                stats = self.prior_temp_scheduler.get_stats()
+                print(
+                    f"[Prior Temperature] step={stats['temperature_step']} | "
+                    f"temp={stats['prior_temperature']:.3f} | "
+                    f"policy_entropy={avg_policy_entropy:.3f} | "
+                    f"progress={stats['temperature_progress']:.2%}"
+                )
+
+                # Log to TensorBoard
+                if self.tb_logger is not None:
+                    step = stats['temperature_step']
+                    self.tb_logger.add_scalar("prior_temp/temperature", stats['prior_temperature'], step)
+                    self.tb_logger.add_scalar("prior_temp/policy_entropy", avg_policy_entropy, step)
+                    self.tb_logger.add_scalar("prior_temp/progress", stats['temperature_progress'], step)
+                    if 'prior_avg_entropy' in stats:
+                        self.tb_logger.add_scalar("prior_temp/avg_entropy", stats['prior_avg_entropy'], step)
+                    if 'prior_target_entropy' in stats:
+                        self.tb_logger.add_scalar("prior_temp/target_entropy", stats['prior_target_entropy'], step)
+                    if 'prior_entropy_gap' in stats:
+                        self.tb_logger.add_scalar("prior_temp/entropy_gap", stats['prior_entropy_gap'], step)

         if self.use_cot:
             self.episode_output.append({
@@ -777,13 +859,19 @@ def get_llm_prior(
         return llm_prior_per_seq, llm_prior_per_tok

     @torch.no_grad()
-    def _score_labels_with_prompt_logprobs(self, all_prompts: List[str], all_labels: List[str], all_prefix_cots: List[str]) -> List[float]:
-        assert len(all_prompts) == len(all_labels) == len(all_prefix_cots)
+    def _score_labels_with_prompt_logprobs(self, all_prompts: List[str], all_labels: List[str], all_prefix_cots: List[str]) -> Tuple[List[float], List[List[float]]]:
+        """
+        Compute raw log probabilities for action sequences.

-        # Get current temperature from scheduler
-        current_temperature = self.temperature  # Default temperature
-        if self.prior_temp_scheduler is not None:
-            current_temperature = self.prior_temp_scheduler.get_temperature()
+        This function computes the original sequence log probabilities by summing token-level logprobs.
+        It does NOT apply temperature scaling or softmax normalization at the token level.
+        Temperature and softmax should be applied at the ACTION level in get_llm_prior.
+
+        Returns:
+            scores: List of raw log probabilities (sum of token logprobs)
+            old_action_logprob: List of token-level logprobs for PPO training
+        """
+        assert len(all_prompts) == len(all_labels) == len(all_prefix_cots)

         sampling_params = SamplingParams(
             temperature=self.temperature,  # Keep original for generation
@@ -817,7 +905,6 @@ def _score_labels_with_prompt_logprobs(self, all_prompts: List[str], all_labels:

         scores = []
         old_action_logprob = []
-        all_entropies = []  # Track entropies for adaptive scheduling

         for out, ids, p_len, l_len, l_no_cots_len in zip(outs, full_ids, p_lens, l_lens, l_no_cots_lens):
             prompt_logprobs = getattr(out, "prompt_logprobs", None)
@@ -843,67 +930,17 @@ def _score_labels_with_prompt_logprobs(self, all_prompts: List[str], all_labels:
             else:
                 assert l_no_cots_len <= l_len

-            # IMPORTANT: Keep original logprobs for PPO training
-            original_token_lps = token_lps.copy()
-
-            # Apply temperature scaling ONLY for prior scores (MCTS root)
-            if current_temperature != 1.0:
-                import torch
-                token_lps_tensor = torch.tensor(token_lps, dtype=torch.float32)
-
-                # Apply temperature scaling
-                scaled_lps = token_lps_tensor / current_temperature
-
-                # Compute entropy before normalization (for adaptive scheduling)
-                probs = torch.exp(scaled_lps)
-                entropy = -(probs * scaled_lps).sum().item()
-                all_entropies.append(entropy)
-
-                # Renormalize to maintain valid probability distribution
-                scaled_lps = torch.log_softmax(scaled_lps, dim=0)
-
-                # Use scaled logprobs for prior scores
-                scaled_token_lps = scaled_lps.tolist()
-            else:
-                # No temperature scaling
-                scaled_token_lps = token_lps
-
-            # Prior scores: use temperature-scaled logprobs for MCTS exploration
+            # Compute raw sequence log probability (sum of token logprobs)
+            # This is the correct action-level score: log P(action) = sum_i log P(token_i)
             if self.llm_prior_with_cot:
-                scores.append(sum(scaled_token_lps) if self.reduction == "sum" else sum(scaled_token_lps) / l_len)
+                raw_score = sum(token_lps) if self.reduction == "sum" else sum(token_lps) / l_len
             else:
-                scores.append(sum(scaled_token_lps[-l_no_cots_len:]) if self.reduction == "sum" else sum(scaled_token_lps[-l_no_cots_len:]) / l_no_cots_len)
+                raw_score = sum(token_lps[-l_no_cots_len:]) if self.reduction == "sum" else sum(token_lps[-l_no_cots_len:]) / l_no_cots_len

-            # Old action logprob: use ORIGINAL logprobs for PPO training
-            old_action_logprob.append(original_token_lps)
+            scores.append(raw_score)

-        # Update temperature scheduler with average entropy
-        if self.prior_temp_scheduler is not None and all_entropies:
-            avg_entropy = sum(all_entropies) / len(all_entropies)
-            new_temperature = self.prior_temp_scheduler.step(entropy=avg_entropy)
-
-            # Log temperature and entropy statistics
-            if self.rank == 0 and self.prior_temp_scheduler.current_step % 10 == 0:
-                stats = self.prior_temp_scheduler.get_stats()
-                print(
-                    f"[Prior Temperature] step={stats['temperature_step']} | "
-                    f"temp={stats['prior_temperature']:.3f} | "
-                    f"entropy={avg_entropy:.3f} | "
-                    f"progress={stats['temperature_progress']:.2%}"
-                )
-
-                # Log to TensorBoard
-                if self.tb_logger is not None:
-                    step = stats['temperature_step']
-                    self.tb_logger.add_scalar("prior_temp/temperature", stats['prior_temperature'], step)
-                    self.tb_logger.add_scalar("prior_temp/entropy", avg_entropy, step)
-                    self.tb_logger.add_scalar("prior_temp/progress", stats['temperature_progress'], step)
-                    if 'prior_avg_entropy' in stats:
-                        self.tb_logger.add_scalar("prior_temp/avg_entropy", stats['prior_avg_entropy'], step)
-                    if 'prior_target_entropy' in stats:
-                        self.tb_logger.add_scalar("prior_temp/target_entropy", stats['prior_target_entropy'], step)
-                    if 'prior_entropy_gap' in stats:
-                        self.tb_logger.add_scalar("prior_temp/entropy_gap", stats['prior_entropy_gap'], step)
+            # Keep original token-level logprobs for PPO training
+            old_action_logprob.append(token_lps)

         return scores, old_action_logprob

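For a quick sanity check of the recipe this diff describes, the standalone sketch below is one way to reproduce it outside the repository: raw per-action scores are summed token logprobs, temperature divides those scores before a softmax over the actions of one state, and the resulting policy entropy drives an adaptive temperature update. The helper names (action_level_prior, EntropyGapScheduler) and the toy update rule are hypothetical illustrations, not this codebase's implementation; only the calling pattern (get_temperature(), step(entropy=...), current_step) mirrors what the diff actually uses.

# Minimal, self-contained sketch (assumed names, not the repository's code).
import torch


def action_level_prior(raw_scores, temperature=1.0):
    """Temperature-scaled prior over one state's actions from summed token logprobs."""
    scores = torch.tensor(raw_scores, dtype=torch.float32)
    log_probs = torch.log_softmax(scores / temperature, dim=0)  # softmax over actions, not tokens
    probs = log_probs.exp()
    policy_entropy = -(probs * log_probs).sum().item()          # H(pi(.|s)) = -sum pi log pi
    return log_probs.tolist(), policy_entropy


class EntropyGapScheduler:
    """Toy adaptive schedule: raise temperature when policy entropy falls below a target."""

    def __init__(self, init_temp=1.0, target_entropy=1.0, lr=0.1, min_temp=0.5, max_temp=2.0):
        self.temperature = init_temp
        self.target_entropy = target_entropy
        self.lr = lr
        self.min_temp, self.max_temp = min_temp, max_temp
        self.current_step = 0

    def get_temperature(self):
        return self.temperature

    def step(self, entropy):
        self.current_step += 1
        gap = self.target_entropy - entropy  # positive gap => prior too confident => heat up
        self.temperature = float(min(self.max_temp, max(self.min_temp, self.temperature + self.lr * gap)))
        return self.temperature


# Example: three valid actions whose raw scores are summed token logprobs.
raw = [-0.7, -2.1, -3.5]
scheduler = EntropyGapScheduler(init_temp=1.0, target_entropy=0.9)
prior, entropy = action_level_prior(raw, temperature=scheduler.get_temperature())
scheduler.step(entropy=entropy)
# A higher temperature flattens the prior and increases policy entropy.
_, entropy_hot = action_level_prior(raw, temperature=2.0)
assert entropy_hot > entropy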