1 change: 1 addition & 0 deletions lzero/entry/__init__.py
@@ -10,4 +10,5 @@
from .train_rezero import train_rezero
from .train_unizero import train_unizero
from .train_unizero_segment import train_unizero_segment
from .train_unizero_segment_with_reward_model import train_unizero_segment_with_reward_model
from .utils import *
258 changes: 258 additions & 0 deletions lzero/entry/train_unizero_segment_with_reward_model.py
@@ -0,0 +1,258 @@
import logging
import os
from functools import partial
from typing import Tuple, Optional

import torch
import wandb
from ding.config import compile_config
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.rl_utils import get_epsilon_greedy_fn
from ding.utils import EasyTimer
from ding.utils import set_pkg_seed, get_rank, get_world_size
from ding.worker import BaseLearner
from tensorboardX import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage
from lzero.policy import visit_count_temperature
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroSegmentCollector as Collector
from lzero.reward_model.rnd_reward_model import RNDRewardModel
from .utils import random_collect, calculate_update_per_collect

timer = EasyTimer()

def train_unizero_segment_with_reward_model(
input_cfg: Tuple[dict, dict],
seed: int = 0,
model: Optional[torch.nn.Module] = None,
model_path: Optional[str] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
) -> 'Policy':
"""
Overview:
The train entry for UniZero (with muzero_segment_collector, the buffer reanalyze trick, and an RND-based intrinsic reward model). UniZero is proposed in our paper UniZero: Generalized and Efficient Planning with Scalable Latent World Models.
UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing the limitations found in MuZero-style algorithms,
particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667.
Arguments:
- input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- model_path (:obj:`Optional[str]`): The pretrained model path, which should
point to the ckpt file of the pretrained model, and an absolute path is recommended.
In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
- max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
- max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
Returns:
- policy (:obj:`Policy`): Converged policy.
"""

cfg, create_cfg = input_cfg

# Ensure the specified policy type is supported
assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "The train_unizero_segment_with_reward_model entry currently only supports the following algorithms: 'unizero', 'sampled_unizero'"
assert cfg.policy.use_rnd_model, "cfg.policy.use_rnd_model must be True to use RND reward model"

# Import the correct GameBuffer class based on the policy type
game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'}

GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]),
game_buffer_classes[create_cfg.policy.type])

# Set device based on CUDA availability
cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
logging.info(f'cfg.policy.device: {cfg.policy.device}')

# Compile the configuration
cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)

# Create main components: env, policy
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available())

policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])

# Load pretrained model if specified
if model_path is not None:
logging.info(f'Loading model from {model_path} begin...')
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
logging.info(f'Loading model from {model_path} end!')

# Create worker components: learner, collector, evaluator, replay buffer, commander
tb_logger = None
if get_rank() == 0:
tb_logger = SummaryWriter(os.path.join(f'./{cfg.exp_name}/log/', 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)

# MCTS+RL algorithms related core code
policy_config = cfg.policy
replay_buffer = GameBuffer(policy_config)
collector = Collector(env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=cfg.exp_name,
policy_config=policy_config)
evaluator = Evaluator(eval_freq=cfg.policy.eval_freq, n_evaluator_episode=cfg.env.n_evaluator_episode,
stop_value=cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode,
tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=policy_config)



# ==============================================================
# New: initialize the RND reward model.
# RNDRewardModel needs a representation network from the policy model (used as the predictor)
# and a target representation network (used as the fixed target).
# For UniZero, the tokenizer plays the role of the representation network.
# ==============================================================
reward_model = RNDRewardModel(
config=cfg.reward_model,
device=policy.collect_mode.get_attribute('device'),
tb_logger=tb_logger,
exp_name=cfg.exp_name,
representation_network=policy._learn_model.representation_network,
target_representation_network=policy._target_model_for_intrinsic_reward.representation_network,
use_momentum_representation_network=cfg.policy.use_momentum_representation_network,
bp_update_sync=cfg.policy.bp_update_sync,
multi_gpu=cfg.policy.multi_gpu,
)
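# A rough conceptual sketch of the generic RND bonus (illustrative only; not necessarily the
# exact formulation inside RNDRewardModel): a trainable predictor network chases a fixed
# target network, and the intrinsic reward of an observation is the prediction error, e.g.
#     r_intrinsic(obs) = || predictor(repr(obs)) - target(repr(obs)) ||^2
# Rarely visited observations are predicted poorly and therefore receive a larger bonus,
# which is then fused with the environment reward during training.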


# Learner's before_run hook
learner.call_hook('before_run')

if cfg.policy.use_wandb and get_rank() == 0:
policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

# Collect random data before training
if cfg.policy.random_collect_data:
random_data = random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
try:
reward_model.warmup_with_random_segments(random_data)
except Exception as e:
logging.exception(f"Failed to warm up RND normalization using random data: {e}")
raise
batch_size = policy._cfg.batch_size

buffer_reanalyze_count = 0
train_epoch = 0
reanalyze_batch_size = cfg.policy.reanalyze_batch_size

if cfg.policy.multi_gpu:
# Get current world size and rank
world_size = get_world_size()
rank = get_rank()
else:
world_size = 1
rank = 0

while True:
# Log buffer memory usage
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)

# Set temperature for visit count distributions
collect_kwargs = {
'temperature': visit_count_temperature(
policy_config.manual_temperature_decay,
policy_config.fixed_temperature_value,
policy_config.threshold_training_steps_for_final_temperature,
trained_steps=learner.train_iter
),
'epsilon': 0.0 # Default epsilon value
}

# Configure epsilon for epsilon-greedy exploration
if policy_config.eps.eps_greedy_exploration_in_collect:
epsilon_greedy_fn = get_epsilon_greedy_fn(
start=policy_config.eps.start,
end=policy_config.eps.end,
decay=policy_config.eps.decay,
type_=policy_config.eps.type
)
collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)

# Evaluate policy performance
if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep, reward_model=reward_model)
if stop:
break

# Collect new data
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

# Determine updates per collection
update_per_collect = calculate_update_per_collect(cfg, new_data, world_size)

# Update replay buffer
replay_buffer.push_game_segments(new_data)
replay_buffer.remove_oldest_data_to_fit()

# Periodically reanalyze buffer
if cfg.policy.buffer_reanalyze_freq >= 1:
# Reanalyze buffer <buffer_reanalyze_freq> times in one train_epoch
reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq
else:
# Reanalyze buffer each <1/buffer_reanalyze_freq> train_epoch
if train_epoch > 0 and train_epoch % int(1/cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition):
with timer:
# Each reanalyze process will reanalyze <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence)
replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
buffer_reanalyze_count += 1
logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
logging.info(f'Buffer reanalyze time: {timer.value}')
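# Worked example for the two branches above (illustrative numbers): with update_per_collect=100
# and buffer_reanalyze_freq=2, reanalyze_interval = 100 // 2 = 50, so the buffer is reanalyzed
# twice per collect-train epoch (at i=0 and i=50 in the training loop below); with
# buffer_reanalyze_freq=0.5, the else-branch instead triggers once every int(1/0.5) = 2 train epochs.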

# Train the policy if sufficient data is available
if collector.envstep > cfg.policy.train_start_after_envsteps:
if cfg.policy.sample_type == 'episode':
data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size
else:
data_sufficient = replay_buffer.get_num_of_transitions() > batch_size
if not data_sufficient:
logging.warning(
f'The data in replay_buffer is insufficient to sample a mini-batch: '
f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continuing to collect data ...'
)
continue

for i in range(update_per_collect):
if cfg.policy.buffer_reanalyze_freq >= 1:
# Reanalyze buffer <buffer_reanalyze_freq> times in one train_epoch
if i % reanalyze_interval == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition):
with timer:
# Each reanalyze process will reanalyze <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence)
replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
buffer_reanalyze_count += 1
logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
logging.info(f'Buffer reanalyze time: {timer.value}')

train_data = replay_buffer.sample(batch_size, policy)
if cfg.policy.use_wandb:
policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

train_data_augmented = reward_model.estimate(train_data)
train_data_augmented.append(learner.train_iter)

log_vars = learner.train(train_data_augmented, collector.envstep)
reward_model.train_with_policy_batch(train_data)
Collaborator @puyuan1996 commented on Nov 10, 2025:

Shouldn't the reward_model first be trained for some iterations, and only then should UniZero use the trained RND network to estimate the fused reward and update its own networks? In the current version the fused reward changes at every iteration, which seems too unstable for UniZero's learning.

Collaborator Author replied:

Right. The adaptive weighting parameter we discussed earlier has now been added: it starts at 0 and slowly ramps up after a while, so in the initial phase only the RND network is effectively being trained and the intrinsic reward is not yet used.

Collaborator replied:

Are all of the new runs using this method?
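A minimal sketch of the warm-up schedule described in the reply above. All names, default values, and the linear ramp are illustrative assumptions, not code from this PR: the intrinsic-reward weight stays at 0 for a while (so only the RND networks are trained), then slowly ramps up.

def intrinsic_reward_weight(train_iter: int, warmup_iters: int = 10_000,
                            ramp_iters: int = 50_000, final_weight: float = 0.1) -> float:
    # During warm-up only the RND networks are trained; the intrinsic reward is not used.
    if train_iter < warmup_iters:
        return 0.0
    # Afterwards, ramp the weight linearly up to final_weight.
    progress = min(1.0, (train_iter - warmup_iters) / ramp_iters)
    return final_weight * progress

The fused reward would then be something like extrinsic + intrinsic_reward_weight(train_iter) * intrinsic.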

logging.info(f'[{i}/{update_per_collect}]: learner and reward_model finished a training step.')

if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

train_epoch += 1
policy.recompute_pos_emb_diff_and_clear_cache()

# Check stopping criteria
if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
break

learner.call_hook('after_run')
if cfg.policy.use_wandb:
wandb.finish()
return policy
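A minimal usage sketch of the new entry. The config module path below is a hypothetical example; in LightZero this entry would normally be invoked from a zoo config script that defines main_config and create_config, with cfg.policy.use_rnd_model=True and a cfg.reward_model section present.

from lzero.entry import train_unizero_segment_with_reward_model
from zoo.atari.config.atari_unizero_segment_config import main_config, create_config  # hypothetical config module

if __name__ == "__main__":
    train_unizero_segment_with_reward_model([main_config, create_config], seed=0, max_env_step=int(1e6))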
7 changes: 4 additions & 3 deletions lzero/entry/utils.py
@@ -147,8 +147,8 @@ def random_collect(
collector_env: 'BaseEnvManager', # noqa
replay_buffer: 'IBuffer', # noqa
postprocess_data_fn: Optional[Callable] = None
) -> None: # noqa
assert policy_cfg.random_collect_episode_num > 0
) -> list: # noqa
assert policy_cfg.random_collect_data, "random_collect_data should be True."

random_policy = RandomPolicy(cfg=policy_cfg, action_space=collector_env.env_ref.action_space)
# set the policy to random policy
@@ -159,7 +159,7 @@ def random_collect(
collect_kwargs = {'temperature': 1, 'epsilon': 0.0}

# Collect data by default config n_sample/n_episode.
new_data = collector.collect(n_episode=policy_cfg.random_collect_episode_num, train_iter=0,
new_data = collector.collect(train_iter=0,
policy_kwargs=collect_kwargs)

if postprocess_data_fn is not None:
@@ -172,6 +172,7 @@

# restore the policy
collector.reset_policy(policy.collect_mode)
return new_data


def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None:
31 changes: 31 additions & 0 deletions lzero/mcts/buffer/game_buffer_unizero.py
@@ -630,3 +630,34 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
batch_target_values = np.asarray(batch_target_values)

return batch_rewards, batch_target_values

def update_priority(self, train_data: List[np.ndarray], batch_priorities: np.ndarray) -> None:
"""
Overview:
Update the priorities of the sampled training data in the replay buffer.
Arguments:
- train_data (:obj:`List[np.ndarray]`): The training data whose priorities should be updated.
- batch_priorities (:obj:`np.ndarray`): The new priority values.
NOTE:
train_data = [current_batch, target_batch]
current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list]
"""
# NOTE: index -4 of current_batch is batch_index_list.
indices = train_data[0][-4]
metas = {'make_time': train_data[0][-1], 'batch_priorities': batch_priorities}
# Only update the priorities of data that is still in the replay buffer.
for i in range(len(indices)):
# make_time stores one timestamp per transition in the segment; compare the first
# timestamp against clear_time to avoid a ValueError from comparing an array with a scalar.
first_transition_time = metas['make_time'][i][0]
if first_transition_time > self.clear_time:
# The stored batch indices are floats; cast to int before using them as an index.
idx = int(indices[i])
prio = metas['batch_priorities'][i]
self.game_pos_priorities[idx] = prio
7 changes: 3 additions & 4 deletions lzero/model/common.py
@@ -641,11 +641,11 @@ def __init__(
self.embedding_dim = embedding_dim

if self.observation_shape[1] == 64:
self.last_linear = nn.Linear(64 * 8 * 8, self.embedding_dim, bias=False)
self.last_linear = nn.Linear(num_channels * 8 * 8, self.embedding_dim, bias=False)

elif self.observation_shape[1] in [84, 96]:
self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False)

self.last_linear = nn.Linear(num_channels * 6 * 6, self.embedding_dim, bias=False)
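# Shape bookkeeping behind this change (the 8x8 / 6x6 spatial sizes come from the branches
# above): the conv stack outputs a feature map of shape (B, num_channels, H', W'), so the
# flattened feature size is num_channels * H' * W'; the previous hardcoded 64 was only
# correct when num_channels happened to equal 64.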
self.final_norm_option_in_encoder = final_norm_option_in_encoder
if self.final_norm_option_in_encoder == 'LayerNorm':
self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5)
Expand Down Expand Up @@ -678,7 +678,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

x = x.view(-1, self.embedding_dim)

# NOTE: very important for training stability.
x = self.final_norm(x)

return x