Skip to content

Commit 684d82b

Browse files
committed
polish(pu): polish exp-name and fix prior_mixing_cfg import bug
1 parent 47bc4b7 commit 684d82b

File tree

2 files changed

+38
-10
lines changed

2 files changed

+38
-10
lines changed

zoo/jericho/priorzero/priorzero_config.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ class PriorZeroLLMConfig:
191191
'enable_soft_mixing': True, # Enable soft mixing instead of hard override
192192
# 'mixing_alpha': 0.5, # Weight for LLM prior (0=network only, 1=LLM only)
193193
'mixing_alpha': 0., # Weight for LLM prior (0=network only, 1=LLM only)
194-
# 'alpha_schedule': None, # 'linear', 'cosine', 'exponential', or None (fixed)
194+
'alpha_schedule': None, # 'linear', 'cosine', 'exponential', or None (fixed)
195195
# 'alpha_schedule': 'cosine', # Smooth decay
196196
'alpha_init': 0.8, # Initial alpha (high LLM influence)
197197
'alpha_final': 0.2, # Final alpha (low LLM influence)
@@ -426,6 +426,9 @@ def get_priorzero_config(
426426
llm_config.vllm_tensor_parallel_size = model_config["vllm_tensor_parallel_size"]
427427
llm_config.gpu_memory_utilization = model_config["gpu_memory_utilization"]
428428

429+
# Add prior_mixing_cfg to policy config for access in policy
430+
main_config.policy.prior_mixing_cfg = llm_config.prior_mixing_cfg
431+
429432
print(f"[Config] Model configuration applied:")
430433
print(f" - Model: {model_key}")
431434
print(f" - Path: {llm_config.model_name_or_path}")
@@ -455,16 +458,39 @@ def get_priorzero_config(
455458

456459
# Format reward info
457460
fmt_rew_str = "fmt" if llm_config.reward_func.format_reward else "nofmt"
458-
# entropy_loss_coef =
461+
462+
# Entropy loss coefficient info
463+
entropy_coef = llm_config.entropy_loss_coef
464+
if entropy_coef is None:
465+
entropy_str = "ent-off"
466+
else:
467+
entropy_str = f"ent{entropy_coef:.3f}".replace("0.", "") # 0.01 -> ent01
468+
469+
# Prior mixing info
470+
mixing_cfg = llm_config.prior_mixing_cfg
471+
if mixing_cfg.get('enable_soft_mixing', False):
472+
alpha = mixing_cfg.get('mixing_alpha', 0.5)
473+
schedule = mixing_cfg.get('alpha_schedule', None)
474+
if schedule:
475+
schedule_short = {'linear': 'lin', 'cosine': 'cos', 'exponential': 'exp'}.get(schedule, schedule[:3])
476+
mixing_str = f"mix-{schedule_short}-{alpha:.1f}"
477+
else:
478+
mixing_str = f"mix-fix-{alpha:.1f}"
479+
else:
480+
mixing_str = "mix-hard"
481+
482+
# Clip prior info
483+
if mixing_cfg.get('enable_clip_prior', False):
484+
clip_eps = mixing_cfg.get('clip_prior_epsilon', 0.01)
485+
clip_str = f"clip{clip_eps:.2f}".replace("0.", "") # 0.01 -> clip01
486+
else:
487+
clip_str = "noclip"
459488

460489
# Build exp_name
461-
# exp_name = (
462-
# f"data_priorzero/pz_{env_id}_{model_key}_"
463-
# f"{cot_str}_{adv_type_short}_{prior_temp_str}_{fmt_rew_str}_pel{entropy_loss_coef}_llm-mix-0-true_seed{seed}" # TODO
464-
# )
465490
exp_name = (
466491
f"data_priorzero/pz_{env_id}_{model_key}_"
467-
f"{cot_str}_{adv_type_short}_{prior_temp_str}_{fmt_rew_str}_pel001_llm-mix-0-true_seed{seed}" # TODO
492+
f"{cot_str}_{adv_type_short}_{prior_temp_str}_{fmt_rew_str}_"
493+
f"{entropy_str}_{mixing_str}_{clip_str}_seed{seed}"
468494
)
469495

470496
# Update config with generated exp_name

zoo/jericho/priorzero/priorzero_policy.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -399,12 +399,14 @@ def _forward_collect(
399399
network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep)
400400
latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
401401

402+
# Move policy_priors to the same device as policy_logits
403+
policy_priors = policy_priors.to(policy_logits.device)
404+
402405
# ======================================================================
403406
# LLM Prior Mixing: Soft Mixing + Clip Prior
404407
# ======================================================================
405-
# Get mixing configuration from llm_config if available
406-
from zoo.jericho.priorzero.priorzero_entry_sync_ddp import llm_config
407-
mixing_cfg = llm_config.prior_mixing_cfg if hasattr(llm_config, 'prior_mixing_cfg') else {}
408+
# Get mixing configuration from policy config
409+
mixing_cfg = self._cfg.get('prior_mixing_cfg', {})
408410

409411
# Store original network policy for logging
410412
network_policy_logits = policy_logits.clone()

0 commit comments

Comments
 (0)