
Commit 47bc4b7

feature(pu): add entropy_loss_coef and prior_mixing_cfg
1 parent ebbc9d8 commit 47bc4b7

3 files changed: +167 -12 lines changed

zoo/jericho/priorzero/models/actor.py

Lines changed: 22 additions & 2 deletions
@@ -224,7 +224,8 @@ def train_batch(self, batch_data: Dict[str, torch.Tensor], kl_ctl: float, step_i
                 micro_batch['action_mask'],
                 attention_mask=micro_batch['attention_mask'],
                 return_output=True,
-                logits_to_keep=logits_to_keep,
+                return_entropy=self.args.entropy_loss_coef is not None,
+                logits_to_keep=logits_to_keep,
             )
             actor_loss, clipfrac, clip_ratio, approx_kl, vllm_kl = self.policy_loss(
                 action_log_probs,
@@ -242,8 +243,20 @@ def train_batch(self, batch_data: Dict[str, torch.Tensor], kl_ctl: float, step_i
                 kl_loss = masked_mean(kl, micro_batch["action_mask"])
             else:
                 kl_loss = 0.0
-
+
+            # Entropy loss for exploration bonus
+            if self.args.entropy_loss_coef is not None:
+                # Extract entropy for action tokens only
+                # Note: output.entropy is already [:, :-1] from Actor.forward (line 89)
+                # So we extract the last action_mask.shape[1] tokens
+                entropy = output.entropy[:, -micro_batch['action_mask'].shape[1]:]
+                entropy_loss = masked_mean(entropy, micro_batch['action_mask'])
+            else:
+                entropy_loss = 0.0
+
             loss = actor_loss + kl_loss * float(kl_ctl.value)
+            if self.args.entropy_loss_coef is not None and self.args.entropy_loss_coef != 0:
+                loss -= entropy_loss * self.args.entropy_loss_coef
 
             self.strategy.backward(loss, self.actor, self.actor_optim)
 
@@ -326,6 +339,13 @@ def train_batch(self, batch_data: Dict[str, torch.Tensor], kl_ctl: float, step_i
                 status["cur_refer_kl"] = kl_loss.detach().float().mean().item()
             else:
                 status["cur_refer_kl"] = float(kl_loss)
+
+            # Add entropy loss logging
+            if self.args.entropy_loss_coef is not None:
+                if isinstance(entropy_loss, torch.Tensor):
+                    status["entropy_loss"] = entropy_loss.detach().float().mean().item()
+                else:
+                    status["entropy_loss"] = float(entropy_loss)
 
             status = self.strategy.all_reduce(status)
             status_list.append(status)
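
Note on the hunk above: the new term is a standard entropy exploration bonus. Token-level entropy is averaged over action tokens via the action mask and subtracted from the loss, scaled by entropy_loss_coef. A minimal, self-contained sketch of the same sign convention, using dummy tensors and a local masked_mean helper rather than the repo's Actor and utilities:

import torch
import torch.nn.functional as F

def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Average over positions where mask == 1 (local helper for this sketch).
    return (x * mask).sum() / mask.sum().clamp(min=1)

torch.manual_seed(0)
logits = torch.randn(2, 5, 8)                       # [batch, action_tokens, vocab]
action_mask = torch.tensor([[1., 1., 1., 0., 0.],
                            [1., 1., 1., 1., 0.]])

log_probs = F.log_softmax(logits, dim=-1)
entropy = -(log_probs.exp() * log_probs).sum(-1)    # per-token entropy, [batch, action_tokens]
entropy_loss = masked_mean(entropy, action_mask)

entropy_loss_coef = 0.01                            # default added in priorzero_config.py
actor_loss = torch.tensor(1.0)                      # dummy stand-ins for the policy-loss terms
kl_loss, kl_ctl = torch.tensor(0.1), 0.1

loss = actor_loss + kl_loss * kl_ctl
if entropy_loss_coef:
    loss = loss - entropy_loss * entropy_loss_coef  # higher entropy => lower loss
print(float(entropy_loss), float(loss))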

zoo/jericho/priorzero/priorzero_config.py

Lines changed: 28 additions & 3 deletions
@@ -181,7 +181,25 @@ class PriorZeroLLMConfig:
     eps_clip_low_high: Tuple[float, float] = (0.2, 0.2)
     rft_kl_coef: float = 0.01
     kl_estimator: str = "k3"
-
+
+    # Entropy loss for exploration bonus
+    entropy_loss_coef: Optional[float] = 0.01  # None = disabled, typical values: 0.001-0.01
+    # entropy_loss_coef: Optional[float] = None  # None = disabled, typical values: 0.001-0.01
+
+    # LLM Prior Mixing Configuration
+    prior_mixing_cfg: Optional[EasyDict] = field(default_factory=lambda: EasyDict({
+        'enable_soft_mixing': True,    # Enable soft mixing instead of hard override
+        # 'mixing_alpha': 0.5,         # Weight for LLM prior (0=network only, 1=LLM only)
+        'mixing_alpha': 0.,            # Weight for LLM prior (0=network only, 1=LLM only)
+        # 'alpha_schedule': None,      # 'linear', 'cosine', 'exponential', or None (fixed)
+        # 'alpha_schedule': 'cosine',  # Smooth decay
+        'alpha_init': 0.8,             # Initial alpha (high LLM influence)
+        'alpha_final': 0.2,            # Final alpha (low LLM influence)
+        'alpha_decay_steps': 10000,    # Steps to decay from init to final
+        'enable_clip_prior': True,     # Enable clipping of LLM prior probabilities
+        'clip_prior_epsilon': 0.01,    # Minimum probability for each action (exploration)
+    }))
+
     train_llm_after_wm_warm_step: int = int(1e2)  # TODO
     value_norm_cfg: Optional[EasyDict] = field(default_factory=lambda: EasyDict({
         'enable_stability_optimizer': True,
@@ -327,7 +345,9 @@ def get_priorzero_config(
         n_episode=n_episode,
         train_start_after_envsteps=0,
         replay_buffer_size=replay_buffer_size,
-        eval_freq=int(3e4),
+        # eval_freq=int(3e4),
+        eval_freq=int(1e3),  # TODO
+        # eval_freq=int(2),  # TODO
         collector_env_num=collector_env_num,
         evaluator_env_num=evaluator_env_num,
         buffer_reanalyze_freq=1 / 1000000,
@@ -435,11 +455,16 @@ def get_priorzero_config(
 
     # Format reward info
     fmt_rew_str = "fmt" if llm_config.reward_func.format_reward else "nofmt"
+    # entropy_loss_coef =
 
     # Build exp_name
+    # exp_name = (
+    #     f"data_priorzero/pz_{env_id}_{model_key}_"
+    #     f"{cot_str}_{adv_type_short}_{prior_temp_str}_{fmt_rew_str}_pel{entropy_loss_coef}_llm-mix-0-true_seed{seed}"  # TODO
+    # )
     exp_name = (
         f"data_priorzero/pz_{env_id}_{model_key}_"
-        f"{cot_str}_{adv_type_short}_{prior_temp_str}_{fmt_rew_str}_seed{seed}"
+        f"{cot_str}_{adv_type_short}_{prior_temp_str}_{fmt_rew_str}_pel001_llm-mix-0-true_seed{seed}"  # TODO
     )
 
     # Update config with generated exp_name
zoo/jericho/priorzero/priorzero_policy.py

Lines changed: 117 additions & 7 deletions
@@ -26,8 +26,9 @@
 
 @POLICY_REGISTRY.register('priorzero', force_overwrite=True)
 class PriorZeroPolicy(OriginalUniZeroPolicy):
-    def __init__(self, cfg: Dict, model: torch.nn.Module = None, enable_field: List[str] = None, **kwargs):
+    def __init__(self, cfg: Dict, model: torch.nn.Module = None, enable_field: List[str] = None, **kwargs):
         super().__init__(cfg, model, enable_field)
+        self._mixing_step = 0  # Track steps for alpha scheduling
 
     def _init_learn(self) -> None:
         super()._init_learn()
@@ -283,7 +284,76 @@ def pad_to_fixed_length(self, data, target_len=55, pad_val=-1e9, dtype=torch.flo
             if L > 0:
                 out[i, :L] = torch.tensor(seq[:L], dtype=dtype)
         return out
-
+
+    def _clip_prior_probabilities(self, policy_logits: torch.Tensor, epsilon: float, action_mask: List[np.ndarray]) -> torch.Tensor:
+        """
+        Clip LLM prior probabilities to ensure minimum exploration.
+
+        Args:
+            policy_logits: Log probabilities from LLM [B, A]
+            epsilon: Minimum probability for each legal action
+            action_mask: List of action masks for each environment
+
+        Returns:
+            Clipped policy logits
+        """
+        # Convert logits to probabilities
+        policy_probs = F.softmax(policy_logits, dim=-1)
+
+        # Clip probabilities for legal actions
+        batch_size = policy_probs.shape[0]
+        for i in range(batch_size):
+            legal_actions = action_mask[i] == 1
+            num_legal = legal_actions.sum()
+
+            if num_legal > 0:
+                # Clip legal action probabilities to be at least epsilon
+                policy_probs[i, legal_actions] = torch.clamp(
+                    policy_probs[i, legal_actions],
+                    min=epsilon
+                )
+
+                # Renormalize to sum to 1
+                policy_probs[i, legal_actions] = policy_probs[i, legal_actions] / policy_probs[i, legal_actions].sum()
+
+        # Convert back to log probabilities
+        clipped_logits = torch.log(policy_probs + 1e-10)
+        return clipped_logits
+
+    def _compute_mixing_alpha(self, cfg: Dict) -> float:
+        """
+        Compute the mixing alpha based on schedule configuration.
+
+        Args:
+            cfg: Prior mixing configuration
+
+        Returns:
+            Current alpha value
+        """
+        if not cfg.get('alpha_schedule'):
+            # Fixed alpha
+            return cfg.get('mixing_alpha', 0.5)
+
+        schedule_type = cfg['alpha_schedule']
+        init_alpha = cfg.get('alpha_init', 0.8)
+        final_alpha = cfg.get('alpha_final', 0.2)
+        decay_steps = cfg.get('alpha_decay_steps', 10000)
+
+        # Compute progress
+        progress = min(self._mixing_step / decay_steps, 1.0)
+
+        if schedule_type == 'linear':
+            alpha = init_alpha + (final_alpha - init_alpha) * progress
+        elif schedule_type == 'cosine':
+            alpha = final_alpha + (init_alpha - final_alpha) * 0.5 * (1 + np.cos(np.pi * progress))
+        elif schedule_type == 'exponential':
+            decay_rate = cfg.get('alpha_decay_rate', 0.95)
+            alpha = final_alpha + (init_alpha - final_alpha) * (decay_rate ** self._mixing_step)
+        else:
+            alpha = cfg.get('mixing_alpha', 0.5)
+
+        return alpha
+
     def _forward_collect(
         self,
         data: torch.Tensor,
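
Note: _clip_prior_probabilities floors each legal action's prior probability at epsilon and renormalizes, so the LLM prior passed to MCTS never fully starves a legal action. A tiny standalone demo with made-up numbers (one environment, four legal actions), not a call into the policy class:

import torch

epsilon = 0.01                                          # clip_prior_epsilon from the config
probs = torch.tensor([[0.960, 0.030, 0.005, 0.005]])    # softmaxed LLM prior for one root

clipped = probs.clamp(min=epsilon)                      # floor every action at epsilon
clipped = clipped / clipped.sum(dim=-1, keepdim=True)   # renormalize to sum to 1
clipped_logits = torch.log(clipped + 1e-10)             # back to log space, as the method returns

print(clipped)          # tensor([[0.9505, 0.0297, 0.0099, 0.0099]])
print(clipped_logits)
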
@@ -324,12 +394,41 @@ def _forward_collect(
                 prior.append(llm_prior_logprob[env_id][action])
             policy_priors.append(prior)
         policy_priors = self.pad_to_fixed_length(data=policy_priors, target_len=self.cfg.model.action_space_size, pad_val=-1e9)
-
+
         with torch.no_grad():
             network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep)
             latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
 
-            network_output.policy_logits = policy_priors
+            # ======================================================================
+            # LLM Prior Mixing: Soft Mixing + Clip Prior
+            # ======================================================================
+            # Get mixing configuration from llm_config if available
+            from zoo.jericho.priorzero.priorzero_entry_sync_ddp import llm_config
+            mixing_cfg = llm_config.prior_mixing_cfg if hasattr(llm_config, 'prior_mixing_cfg') else {}
+
+            # Store original network policy for logging
+            network_policy_logits = policy_logits.clone()
+
+            # Apply clip prior if enabled
+            if mixing_cfg.get('enable_clip_prior', False):
+                epsilon = mixing_cfg.get('clip_prior_epsilon', 0.01)
+                policy_priors = self._clip_prior_probabilities(policy_priors, epsilon, action_mask)
+
+            # Apply soft mixing if enabled
+            if mixing_cfg.get('enable_soft_mixing', False):
+                alpha = self._compute_mixing_alpha(mixing_cfg)
+                # Soft mixing: (1 - alpha) * network + alpha * LLM
+                mixed_policy_logits = (1 - alpha) * policy_logits + alpha * policy_priors
+                final_policy_logits = mixed_policy_logits
+                self._mixing_step += 1  # Increment step for alpha scheduling
+            else:
+                # Hard override (original behavior)
+                final_policy_logits = policy_priors
+                alpha = 1.0  # For logging
+
+            # Update network output with final policy
+            network_output.policy_logits = final_policy_logits
+
         if not self._cfg.mcts_ctree:
             raise NotImplementedError("Python MCTS not supported for PriorZero")
 
@@ -338,7 +437,7 @@ def _forward_collect(
         # ======================================================================
         pred_values_np = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
         latent_state_roots_np = latent_state_roots.detach().cpu().numpy()
-        policy_logits = policy_priors.detach().cpu().numpy().tolist()
+        policy_logits_for_mcts = final_policy_logits.detach().cpu().numpy().tolist()
 
 
         legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)]
@@ -347,7 +446,7 @@ def _forward_collect(
             ).astype(np.float32).tolist() for j in range(active_collect_env_num)
         ]
         roots = MCTSCtree.roots(active_collect_env_num, legal_actions)
-        roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play)
+        roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits_for_mcts, to_play)
         self._mcts_collect.search(roots, self._collect_model, latent_state_roots_np, to_play, timestep=timestep)
 
         roots_visit_count = roots.get_distributions()
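
Note: the soft-mixing branch blends the network policy and the (optionally clipped) LLM prior in logit space before the result reaches roots.prepare above, so after softmax the mixed distribution behaves like a normalized weighted geometric interpolation of the two policies rather than an arithmetic average. A minimal sketch with dummy tensors, outside the policy class:

import torch
import torch.nn.functional as F

network_logits = torch.tensor([[1.0, 0.5, -0.5, -2.0]])    # stand-in for initial_inference output
llm_prior_logits = torch.tensor([[2.0, -1.0, 0.0, -3.0]])  # stand-in for the padded LLM prior
alpha = 0.5                                                 # stand-in for _compute_mixing_alpha()

mixed_logits = (1 - alpha) * network_logits + alpha * llm_prior_logits

print(F.softmax(network_logits, dim=-1))
print(F.softmax(llm_prior_logits, dim=-1))
print(F.softmax(mixed_logits, dim=-1))   # root prior handed to MCTS, between the two above
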
@@ -373,8 +472,19 @@ def _forward_collect(
                 'visit_count_distribution_entropy': visit_count_distribution_entropy,
                 'searched_value': value,
                 'predicted_value': pred_values_np[i],
-                'predicted_policy_logits': policy_logits[i],
+                'predicted_policy_logits': policy_logits_for_mcts[i],
                 'timestep': timestep[i],
+                # Add mixing metrics for logging
+                'mixing_alpha': alpha,
+                'network_policy_entropy': -torch.sum(
+                    F.softmax(network_policy_logits[i], dim=-1) * F.log_softmax(network_policy_logits[i], dim=-1)
+                ).item(),
+                'llm_policy_entropy': -torch.sum(
+                    F.softmax(policy_priors[i], dim=-1) * F.log_softmax(policy_priors[i], dim=-1)
+                ).item(),
+                'mixed_policy_entropy': -torch.sum(
+                    F.softmax(final_policy_logits[i], dim=-1) * F.log_softmax(final_policy_logits[i], dim=-1)
+                ).item(),
             }
             batch_action.append(action)
         self.last_batch_obs = data