 
 @POLICY_REGISTRY.register('priorzero', force_overwrite=True)
 class PriorZeroPolicy(OriginalUniZeroPolicy):
-    def __init__(self, cfg: Dict, model: torch.nn.Module = None, enable_field: List[str] = None, **kwargs):
+    def __init__(self, cfg: Dict, model: torch.nn.Module = None, enable_field: List[str] = None, **kwargs):
         super().__init__(cfg, model, enable_field)
+        self._mixing_step = 0  # Track steps for alpha scheduling
 
     def _init_learn(self) -> None:
         super()._init_learn()
@@ -283,7 +284,76 @@ def pad_to_fixed_length(self, data, target_len=55, pad_val=-1e9, dtype=torch.flo
             if L > 0:
                 out[i, :L] = torch.tensor(seq[:L], dtype=dtype)
         return out
-
+
+    def _clip_prior_probabilities(self, policy_logits: torch.Tensor, epsilon: float, action_mask: List[np.ndarray]) -> torch.Tensor:
+        """
+        Clip LLM prior probabilities to ensure minimum exploration.
+
+        Args:
+            policy_logits: Log probabilities from the LLM, shape [B, A].
+            epsilon: Minimum probability for each legal action.
+            action_mask: List of action masks, one per environment.
+
+        Returns:
+            Clipped policy logits.
+        """
+        # Convert logits to probabilities
+        policy_probs = F.softmax(policy_logits, dim=-1)
+
+        # Clip probabilities for legal actions
+        batch_size = policy_probs.shape[0]
+        for i in range(batch_size):
+            legal_actions = action_mask[i] == 1
+            num_legal = legal_actions.sum()
+
+            if num_legal > 0:
+                # Raise legal-action probabilities to at least epsilon
+                policy_probs[i, legal_actions] = torch.clamp(
+                    policy_probs[i, legal_actions],
+                    min=epsilon
+                )
+
+                # Renormalize to sum to 1 (clamped entries may dip slightly below epsilon again)
+                policy_probs[i, legal_actions] = policy_probs[i, legal_actions] / policy_probs[i, legal_actions].sum()
+
+        # Convert back to log probabilities
+        clipped_logits = torch.log(policy_probs + 1e-10)
+        return clipped_logits
+
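For intuition, here is a minimal standalone sketch of what the clipping does to a toy prior (a hypothetical 4-action space with one illegal action; the numbers are illustrative only):

```python
import torch
import torch.nn.functional as F

# Toy LLM prior: action 2 dominates, action 3 is illegal (padded to -1e9).
logits = torch.tensor([[-6.0, -4.0, 2.0, -1e9]])
probs = F.softmax(logits, dim=-1)        # ~[0.0003, 0.0025, 0.9972, 0.0000]

legal = torch.tensor([True, True, True, False])
epsilon = 0.01

clipped = probs.clone()
clipped[0, legal] = torch.clamp(clipped[0, legal], min=epsilon)
clipped[0, legal] = clipped[0, legal] / clipped[0, legal].sum()
# ~[0.0098, 0.0098, 0.9804, 0.0000]: every legal action keeps enough
# mass for MCTS root exploration, at a small cost to the dominant action.
```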
+    def _compute_mixing_alpha(self, cfg: Dict) -> float:
+        """
+        Compute the mixing alpha from the schedule configuration.
+
+        Args:
+            cfg: Prior-mixing configuration.
+
+        Returns:
+            Current alpha value.
+        """
+        if not cfg.get('alpha_schedule'):
+            # Fixed alpha
+            return cfg.get('mixing_alpha', 0.5)
+
+        schedule_type = cfg['alpha_schedule']
+        init_alpha = cfg.get('alpha_init', 0.8)
+        final_alpha = cfg.get('alpha_final', 0.2)
+        decay_steps = cfg.get('alpha_decay_steps', 10000)
+
+        # Fraction of the decay horizon completed, capped at 1.0
+        progress = min(self._mixing_step / decay_steps, 1.0)
+
+        if schedule_type == 'linear':
+            alpha = init_alpha + (final_alpha - init_alpha) * progress
+        elif schedule_type == 'cosine':
+            alpha = final_alpha + (init_alpha - final_alpha) * 0.5 * (1 + np.cos(np.pi * progress))
+        elif schedule_type == 'exponential':
+            decay_rate = cfg.get('alpha_decay_rate', 0.95)
+            alpha = final_alpha + (init_alpha - final_alpha) * (decay_rate ** self._mixing_step)
+        else:
+            alpha = cfg.get('mixing_alpha', 0.5)
+
+        return alpha
+
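To make the schedules concrete, a small self-contained sketch (defaults mirror the config keys read above; `alpha_at` is a hypothetical helper, not part of the policy):

```python
import numpy as np

def alpha_at(step, schedule='linear', init_alpha=0.8, final_alpha=0.2,
             decay_steps=10000, decay_rate=0.95):
    # Mirrors _compute_mixing_alpha for an explicit step counter.
    progress = min(step / decay_steps, 1.0)
    if schedule == 'linear':
        return init_alpha + (final_alpha - init_alpha) * progress
    if schedule == 'cosine':
        return final_alpha + (init_alpha - final_alpha) * 0.5 * (1 + np.cos(np.pi * progress))
    # 'exponential' decays with the raw step count; note it ignores decay_steps.
    return final_alpha + (init_alpha - final_alpha) * decay_rate ** step

for step in (0, 5000, 10000, 20000):
    print(step, round(alpha_at(step, 'linear'), 3), round(alpha_at(step, 'cosine'), 3))
# 0     0.8  0.8
# 5000  0.5  0.5
# 10000 0.2  0.2
# 20000 0.2  0.2
```

All three schedules move alpha from a high initial value (trust the LLM prior early) toward a low final value (trust the learned network late).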
     def _forward_collect(
             self,
             data: torch.Tensor,
@@ -324,12 +394,41 @@ def _forward_collect(
                 prior.append(llm_prior_logprob[env_id][action])
             policy_priors.append(prior)
         policy_priors = self.pad_to_fixed_length(data=policy_priors, target_len=self.cfg.model.action_space_size, pad_val=-1e9)
-
+
         with torch.no_grad():
             network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep)
             latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
 
-        network_output.policy_logits = policy_priors
+        # ======================================================================
+        # LLM Prior Mixing: Soft Mixing + Clip Prior
+        # ======================================================================
+        # Get the mixing configuration from llm_config if available
+        from zoo.jericho.priorzero.priorzero_entry_sync_ddp import llm_config
+        mixing_cfg = llm_config.prior_mixing_cfg if hasattr(llm_config, 'prior_mixing_cfg') else {}
+
+        # Store the original network policy for logging
+        network_policy_logits = policy_logits.clone()
+
+        # Apply the clip prior if enabled
+        if mixing_cfg.get('enable_clip_prior', False):
+            epsilon = mixing_cfg.get('clip_prior_epsilon', 0.01)
+            policy_priors = self._clip_prior_probabilities(policy_priors, epsilon, action_mask)
+
+        # Apply soft mixing if enabled
+        if mixing_cfg.get('enable_soft_mixing', False):
+            alpha = self._compute_mixing_alpha(mixing_cfg)
+            # Soft mixing in logit space: (1 - alpha) * network + alpha * LLM
+            mixed_policy_logits = (1 - alpha) * policy_logits + alpha * policy_priors
+            final_policy_logits = mixed_policy_logits
+            self._mixing_step += 1  # Advance the step counter for alpha scheduling
+        else:
+            # Hard override (original behavior)
+            final_policy_logits = policy_priors
+            alpha = 1.0  # For logging
+
+        # Update the network output with the final policy
+        network_output.policy_logits = final_policy_logits
+
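One subtlety worth flagging: the interpolation happens in logit space, so after the softmax it acts like a weighted geometric mean of the two distributions rather than an arithmetic one (and the -1e9 pads in `policy_priors` keep illegal actions suppressed for any alpha > 0). A toy comparison, with made-up logits:

```python
import torch
import torch.nn.functional as F

net = torch.tensor([2.0, 0.0, -2.0])   # network logits (illustrative)
llm = torch.tensor([-2.0, 0.0, 2.0])   # LLM prior logits (illustrative)
alpha = 0.5

# Logit-space mix (what the diff does): proportional to
# softmax(net)^(1-alpha) * softmax(llm)^alpha.
p_logit = F.softmax((1 - alpha) * net + alpha * llm, dim=-1)

# Probability-space mix: arithmetic mean of the two distributions.
p_prob = (1 - alpha) * F.softmax(net, dim=-1) + alpha * F.softmax(llm, dim=-1)

print(p_logit)  # ~[0.333, 0.333, 0.333] -- disagreements cancel out
print(p_prob)   # ~[0.441, 0.117, 0.441] -- both preferred actions survive
```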
         if not self._cfg.mcts_ctree:
             raise NotImplementedError("Python MCTS not supported for PriorZero")
 
@@ -338,7 +437,7 @@ def _forward_collect(
         # ======================================================================
         pred_values_np = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
         latent_state_roots_np = latent_state_roots.detach().cpu().numpy()
-        policy_logits = policy_priors.detach().cpu().numpy().tolist()
+        policy_logits_for_mcts = final_policy_logits.detach().cpu().numpy().tolist()
 
 
         legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)]
@@ -347,7 +446,7 @@ def _forward_collect(
             ).astype(np.float32).tolist() for j in range(active_collect_env_num)
         ]
         roots = MCTSCtree.roots(active_collect_env_num, legal_actions)
-        roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play)
+        roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits_for_mcts, to_play)
         self._mcts_collect.search(roots, self._collect_model, latent_state_roots_np, to_play, timestep=timestep)
 
         roots_visit_count = roots.get_distributions()
@@ -373,8 +472,19 @@ def _forward_collect(
                 'visit_count_distribution_entropy': visit_count_distribution_entropy,
                 'searched_value': value,
                 'predicted_value': pred_values_np[i],
-                'predicted_policy_logits': policy_logits[i],
+                'predicted_policy_logits': policy_logits_for_mcts[i],
                 'timestep': timestep[i],
+                # Mixing metrics for logging
+                'mixing_alpha': alpha,
+                'network_policy_entropy': -torch.sum(
+                    F.softmax(network_policy_logits[i], dim=-1) * F.log_softmax(network_policy_logits[i], dim=-1)
+                ).item(),
+                'llm_policy_entropy': -torch.sum(
+                    F.softmax(policy_priors[i], dim=-1) * F.log_softmax(policy_priors[i], dim=-1)
+                ).item(),
+                'mixed_policy_entropy': -torch.sum(
+                    F.softmax(final_policy_logits[i], dim=-1) * F.log_softmax(final_policy_logits[i], dim=-1)
+                ).item(),
             }
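The three entropy terms are plain Shannon entropies of the softmaxed logits, useful for watching whether mixing sharpens or flattens the root policy. A quick sanity check with a hypothetical helper:

```python
import torch
import torch.nn.functional as F

def policy_entropy(logits: torch.Tensor) -> float:
    # H(p) = -sum_a p(a) * log p(a), with p = softmax(logits)
    return -torch.sum(F.softmax(logits, dim=-1) * F.log_softmax(logits, dim=-1)).item()

print(policy_entropy(torch.zeros(4)))                        # uniform: log(4) ~ 1.386
print(policy_entropy(torch.tensor([10.0, 0.0, 0.0, 0.0])))   # near-deterministic: ~0.0015
```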
             batch_action.append(action)
         self.last_batch_obs = data
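For reference, a hypothetical `prior_mixing_cfg` that exercises every key the diff reads (the key names come from the code above; the values are illustrative, not tuned defaults):

```python
prior_mixing_cfg = dict(
    # --- soft mixing of network and LLM root policies ---
    enable_soft_mixing=True,
    mixing_alpha=0.5,           # used when alpha_schedule is unset or unrecognized
    alpha_schedule='linear',    # 'linear' | 'cosine' | 'exponential'
    alpha_init=0.8,
    alpha_final=0.2,
    alpha_decay_steps=10000,
    alpha_decay_rate=0.95,      # read only by the 'exponential' schedule
    # --- clipping of the LLM prior for minimum exploration ---
    enable_clip_prior=True,
    clip_prior_epsilon=0.01,    # floor probability per legal action (before renormalization)
)
```

With both flags off, the policy falls back to the original hard override, in which the LLM prior fully replaces the network policy at the MCTS root.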