Changes from all commits
50 commits
2e44edf
polish(pu): use collect/eval_num_simulations for collect/eval phases
zjowowen Jul 31, 2025
c5a4b16
tmp polish(pu): polish longrun config
zjowowen Aug 1, 2025
ca30789
polish(pu): polish longrun config
zjowowen Aug 1, 2025
d4fbc51
Merge branch 'dev-longrun' of https://github.com/opendilab/LightZero …
zjowowen Aug 4, 2025
037918b
fix(pu): fix game_segment_idx in _sample_orig_reanalyze_batch
zjowowen Aug 5, 2025
bb89bc0
tmp
puyuan1996 Aug 5, 2025
cd1ad91
feature(pu): add different norm option for unizero encoder
Aug 5, 2025
d1bb16c
feature(pu): add gradient_scale option in unizero, polish unizero config
Aug 15, 2025
bb0845d
polish(pu): polish buffer remove method
Aug 18, 2025
50bd3c0
fix(pu): fix unizero reset_collect/eval kv_cache bug!!!
Aug 18, 2025
2eb6d05
feature(pu): add toy env to test unizero world_model
Aug 19, 2025
fc83e9a
polish(pu): polish unizero config
Aug 24, 2025
049883d
fix(pu): fix kv_shared_pool init and recur
Aug 24, 2025
22cba86
polish(pu): rm ununsed files
Aug 24, 2025
b7c7016
fix(pu): fix init_recur kv share pool index
Aug 25, 2025
961f3be
fix(pu): fix recur kv pool index compatibility
Aug 25, 2025
feb6a01
polish(pu): polish weight decay and add latent_norm_loss
Aug 26, 2025
a2d37f8
feature(pu): add PER for UniZero
Aug 28, 2025
83e5933
sync code
Aug 29, 2025
aa7b630
Merge branch 'dev-longrun' of https://github.com/opendilab/LightZero …
puyuan1996 Sep 1, 2025
122fea6
polish(pu): add ln in value/reward head
zjowowen Sep 2, 2025
b321f56
polish(pu): polish unizero episode config
zjowowen Sep 4, 2025
19315f7
tmp
zjowowen Sep 6, 2025
35a98c4
feature(pu): add muzero save_buffer and unizero load buffer option
zjowowen Sep 10, 2025
04ee4c5
test(pu): add tsne analyze utils, add sgd and adamw-mix-lr option
zjowowen Sep 10, 2025
38e3312
feature(pu): add encoder latent norm clip option
zjowowen Sep 10, 2025
04d15ec
tmp
zjowowen Sep 13, 2025
bf92973
polish(pu): add reinit_value_head option
zjowowen Sep 15, 2025
a855c5d
polish(pu): polish reinit_prediction_heads option
zjowowen Sep 15, 2025
4c8d8e1
feature(pu): add action_latent_norm option, add head_grad_norm and ba…
zjowowen Sep 15, 2025
a07ecb5
feature(pu): add head_logit_weight_clip
zjowowen Sep 16, 2025
7b9174c
feature(pu): add head_policy_logit_weight_clip
zjowowen Sep 16, 2025
9ee31bb
fix(pu): fix train_unizero_from_buffer
zjowowen Sep 16, 2025
8dd82f9
fix(pu): fix muzero save buffer from ckpt
zjowowen Sep 16, 2025
0c90b31
polish(pu): polish unizero config
zjowowen Sep 16, 2025
507a9d5
fix(pu): use muzero style head for unizero
zjowowen Sep 17, 2025
c25ca11
feature(pu): add use_temperature_scaling option
zjowowen Sep 18, 2025
6ae7191
fix(pu): fix collect/eval/compute_target in use_temperature_scaling o…
zjowowen Sep 18, 2025
0b6cffd
fix(pu): fix norm_type in encoder, add encoder-dinov2 option
zjowowen Sep 18, 2025
807cdcd
fix(pu): add value/reward policy label smoothgit add lzero zoo!
zjowowen Sep 18, 2025
9c84e94
fix(pu): fix policy label smooth
zjowowen Sep 18, 2025
2170b8b
feature(pu): add entropy_explorer.py
zjowowen Sep 18, 2025
b582814
feature(pu): add monitor_model_norms option, use different weight-decay
zjowowen Sep 18, 2025
45c069f
polish(pu): polish config
zjowowen Sep 19, 2025
cac4b2c
polish(pu): polish lr_cos_decay loss_weight
zjowowen Sep 20, 2025
8c3c467
feature(pu): add adaptive_entropy_alpha and encoder_clip_annealing op…
puyuan1996 Sep 23, 2025
86c1a11
fix(pu): fix adaptive_entropy_alpha
puyuan1996 Sep 23, 2025
1298037
polish(pu):polish target policy entropy option
puyuan1996 Sep 23, 2025
57a0e76
tmp
tAnGjIa520 Oct 2, 2025
26bdc0e
fix(pu): fix num_res_blocks, num_channels config bug
tAnGjIa520 Oct 11, 2025
6 changes: 6 additions & 0 deletions lzero/entry/__init__.py
@@ -5,9 +5,15 @@
from .train_alphazero import train_alphazero
from .train_muzero import train_muzero
from .train_muzero_segment import train_muzero_segment
from .train_muzero_segment_save_buffer import train_muzero_segment_save_buffer
from .train_muzero_segment_save_buffer_from_ckpt import train_muzero_segment_save_buffer_from_ckpt


from .train_muzero_with_gym_env import train_muzero_with_gym_env
from .train_muzero_with_reward_model import train_muzero_with_reward_model
from .train_rezero import train_rezero
from .train_unizero import train_unizero
from .train_unizero_segment import train_unizero_segment
from .train_unizero_segment_from_buffer import train_unizero_segment_from_buffer

from .utils import *
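For reference, a minimal sketch (not part of this diff) of how the newly exported entry points could be imported once this change is merged:

from lzero.entry import (
    train_muzero_segment_save_buffer,
    train_muzero_segment_save_buffer_from_ckpt,
    train_unizero_segment_from_buffer,
)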
10 changes: 9 additions & 1 deletion lzero/entry/train_muzero_segment.py
@@ -24,6 +24,8 @@
timer = EasyTimer()




def train_muzero_segment(
input_cfg: Tuple[dict, dict],
seed: int = 0,
@@ -88,7 +90,9 @@ def train_muzero_segment(

# load pretrained model
if model_path is not None:
logging.info(f'Loading model from {model_path} begin...')
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
logging.info(f'Loading model from {model_path} end!')

# Create worker components: learner, collector, evaluator, replay buffer, commander.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
@@ -144,6 +148,8 @@ def train_muzero_segment(
train_epoch = 0
reanalyze_batch_size = cfg.policy.reanalyze_batch_size



while True:
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
log_buffer_run_time(learner.train_iter, replay_buffer, tb_logger)
@@ -169,7 +175,9 @@
collect_kwargs['epsilon'] = 0.0

# Evaluate policy performance.
if evaluator.should_eval(learner.train_iter):
if learner.train_iter==0 or evaluator.should_eval(learner.train_iter):
# if evaluator.should_eval(learner.train_iter):

if cfg.policy.eval_offline:
eval_train_iter_list.append(learner.train_iter)
eval_train_envstep_list.append(collector.envstep)
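The substantive change in this hunk widens the evaluation gate so that the freshly initialized policy is also evaluated at train_iter == 0. A small illustrative sketch of the resulting behavior (the helper name is hypothetical, not part of the diff):

def should_run_eval(train_iter: int, evaluator) -> bool:
    # Evaluate once before any gradient updates, then fall back to the
    # regular eval_freq-based schedule implemented by evaluator.should_eval().
    return train_iter == 0 or evaluator.should_eval(train_iter)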
260 changes: 260 additions & 0 deletions lzero/entry/train_muzero_segment_orig.py
@@ -0,0 +1,260 @@
import logging
import os
from functools import partial
from typing import Optional, Tuple

import torch
from ding.config import compile_config
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.rl_utils import get_epsilon_greedy_fn
from ding.utils import EasyTimer
from ding.utils import set_pkg_seed, get_rank
from ding.worker import BaseLearner
from tensorboardX import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage, log_buffer_run_time
from lzero.policy import visit_count_temperature
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroSegmentCollector as Collector
from .utils import random_collect, calculate_update_per_collect

timer = EasyTimer()




def train_muzero_segment(
input_cfg: Tuple[dict, dict],
seed: int = 0,
model: Optional[torch.nn.Module] = None,
model_path: Optional[str] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
) -> 'Policy': # noqa
"""
Overview:
The train entry for MCTS+RL algorithms (with muzero_segment_collector and buffer reanalyze trick), including MuZero, EfficientZero, Sampled MuZero, Sampled EfficientZero, Gumbel MuZero, StochasticMuZero.
Arguments:
- input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- model_path (:obj:`Optional[str]`): The pretrained model path, which should
point to the ckpt file of the pretrained model, and an absolute path is recommended.
In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
- max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
- max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
Returns:
- policy (:obj:`Policy`): Converged policy.
"""

cfg, create_cfg = input_cfg
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_context', 'muzero_rnn_full_obs', 'sampled_efficientzero', 'sampled_muzero', 'gumbel_muzero', 'stochastic_muzero'], \
"train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'"

if create_cfg.policy.type in ['muzero', 'muzero_context', 'muzero_rnn_full_obs']:
from lzero.mcts import MuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'efficientzero':
from lzero.mcts import EfficientZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'sampled_efficientzero':
from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'sampled_muzero':
from lzero.mcts import SampledMuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'gumbel_muzero':
from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'stochastic_muzero':
from lzero.mcts import StochasticMuZeroGameBuffer as GameBuffer

if cfg.policy.cuda and torch.cuda.is_available():
cfg.policy.device = 'cuda'
else:
cfg.policy.device = 'cpu'

cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
# Create main components: env, policy
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

if cfg.policy.eval_offline:
cfg.policy.learn.learner.hook.save_ckpt_after_iter = cfg.policy.eval_freq

policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])

# load pretrained model
if model_path is not None:
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))

# Create worker components: learner, collector, evaluator, replay buffer, commander.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)

# ==============================================================
# MCTS+RL algorithms related core code
# ==============================================================
policy_config = cfg.policy
batch_size = policy_config.batch_size
# specific game buffer for MCTS+RL algorithms
replay_buffer = GameBuffer(policy_config)
collector = Collector(
env=collector_env,
policy=policy.collect_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name,
policy_config=policy_config
)
evaluator = Evaluator(
eval_freq=cfg.policy.eval_freq,
n_evaluator_episode=cfg.env.n_evaluator_episode,
stop_value=cfg.env.stop_value,
env=evaluator_env,
policy=policy.eval_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name,
policy_config=policy_config
)

# ==============================================================
# Main loop
# ==============================================================
# Learner's before_run hook.
learner.call_hook('before_run')

if cfg.policy.update_per_collect is not None:
update_per_collect = cfg.policy.update_per_collect

# The purpose of collecting random data before training:
# Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely.
# Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms.
if cfg.policy.random_collect_episode_num > 0:
random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
if cfg.policy.eval_offline:
eval_train_iter_list = []
eval_train_envstep_list = []

# Evaluate the random agent
# stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)

buffer_reanalyze_count = 0
train_epoch = 0
reanalyze_batch_size = cfg.policy.reanalyze_batch_size



while True:
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
log_buffer_run_time(learner.train_iter, replay_buffer, tb_logger)
collect_kwargs = {}
# set temperature for visit count distributions according to the train_iter,
# please refer to Appendix D in MuZero paper for details.
collect_kwargs['temperature'] = visit_count_temperature(
policy_config.manual_temperature_decay,
policy_config.fixed_temperature_value,
policy_config.threshold_training_steps_for_final_temperature,
trained_steps=learner.train_iter
)

if policy_config.eps.eps_greedy_exploration_in_collect:
epsilon_greedy_fn = get_epsilon_greedy_fn(
start=policy_config.eps.start,
end=policy_config.eps.end,
decay=policy_config.eps.decay,
type_=policy_config.eps.type
)
collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
else:
collect_kwargs['epsilon'] = 0.0

# Evaluate policy performance.
if evaluator.should_eval(learner.train_iter):
if cfg.policy.eval_offline:
eval_train_iter_list.append(learner.train_iter)
eval_train_envstep_list.append(collector.envstep)
else:
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break

# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

# Determine updates per collection
update_per_collect = calculate_update_per_collect(cfg, new_data)

# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
# remove the oldest data if the replay buffer is full.
replay_buffer.remove_oldest_data_to_fit()

# Periodically reanalyze buffer
if cfg.policy.buffer_reanalyze_freq >= 1:
# Reanalyze buffer <buffer_reanalyze_freq> times in one train_epoch
reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq
else:
# Reanalyze buffer each <1/buffer_reanalyze_freq> train_epoch
if train_epoch % (1//cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition):
with timer:
# Each reanalyze process will reanalyze <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence)
replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
buffer_reanalyze_count += 1
logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
logging.info(f'Buffer reanalyze time: {timer.value}')
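# Editor's sketch (illustrative comment, not a line of the committed file):
# with the logic above, buffer_reanalyze_freq = 4 and update_per_collect = 100
# give reanalyze_interval = 100 // 4 = 25, i.e. the buffer is reanalyzed 4 times
# per collect-train epoch inside the update loop below; buffer_reanalyze_freq = 0.5
# instead triggers one full reanalyze pass every 1 // 0.5 = 2 train epochs,
# provided the buffer holds more than reanalyze_batch_size / reanalyze_partition sequences.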

# Learn policy from collected data.
for i in range(update_per_collect):

if cfg.policy.buffer_reanalyze_freq >= 1:
# Reanalyze buffer <buffer_reanalyze_freq> times in one train_epoch
if i % reanalyze_interval == 0 and replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int(
reanalyze_batch_size / cfg.policy.reanalyze_partition):
# Each reanalyze process will reanalyze <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence)
replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
buffer_reanalyze_count += 1
logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')

# Learner will train ``update_per_collect`` times in one iteration.
if replay_buffer.get_num_of_transitions() > batch_size:
train_data = replay_buffer.sample(batch_size, policy)
else:
logging.warning(
f'The data in replay_buffer is not sufficient to sample a mini-batch: '
f'batch_size: {batch_size}, '
f'{replay_buffer} '
f'continue to collect now ....'
)
break

# The core train steps for MCTS+RL algorithms.
log_vars = learner.train(train_data, collector.envstep)

if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

train_epoch += 1

if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
if cfg.policy.eval_offline:
logging.info(f'eval offline beginning...')
ckpt_dirname = './{}/ckpt'.format(learner.exp_name)
# Evaluate the performance of the pretrained model.
for train_iter, collector_envstep in zip(eval_train_iter_list, eval_train_envstep_list):
ckpt_name = 'iteration_{}.pth.tar'.format(train_iter)
ckpt_path = os.path.join(ckpt_dirname, ckpt_name)
# load the ckpt of pretrained model
policy.learn_mode.load_state_dict(torch.load(ckpt_path, map_location=cfg.policy.device))
stop, reward = evaluator.eval(learner.save_checkpoint, train_iter, collector_envstep)
logging.info(
f'eval offline at train_iter: {train_iter}, collector_envstep: {collector_envstep}, reward: {reward}')
logging.info(f'eval offline finished!')
break

# Learner's after_run hook.
learner.call_hook('after_run')
return policy
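For context, a minimal usage sketch of this entry (the zoo config module path below is hypothetical; any LightZero config pair exposing a (main_config, create_config) tuple is used the same way):

from lzero.entry import train_muzero_segment
# Hypothetical config module; substitute the config of your environment.
from zoo.atari.config.atari_muzero_segment_config import main_config, create_config

if __name__ == "__main__":
    # input_cfg is the (user_config, create_cfg) tuple described in the docstring;
    # model_path may point to a checkpoint such as exp_name/ckpt/ckpt_best.pth.tar.
    policy = train_muzero_segment(
        input_cfg=(main_config, create_config),
        seed=0,
        model_path=None,
        max_env_step=int(1e6),
    )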