diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py index f17126527..18f92e062 100644 --- a/lzero/entry/__init__.py +++ b/lzero/entry/__init__.py @@ -5,9 +5,15 @@ from .train_alphazero import train_alphazero from .train_muzero import train_muzero from .train_muzero_segment import train_muzero_segment +from .train_muzero_segment_save_buffer import train_muzero_segment_save_buffer +from .train_muzero_segment_save_buffer_from_ckpt import train_muzero_segment_save_buffer_from_ckpt + + from .train_muzero_with_gym_env import train_muzero_with_gym_env from .train_muzero_with_reward_model import train_muzero_with_reward_model from .train_rezero import train_rezero from .train_unizero import train_unizero from .train_unizero_segment import train_unizero_segment +from .train_unizero_segment_from_buffer import train_unizero_segment_from_buffer + from .utils import * diff --git a/lzero/entry/train_muzero_segment.py b/lzero/entry/train_muzero_segment.py index 4e9809e05..9f0b5b29d 100644 --- a/lzero/entry/train_muzero_segment.py +++ b/lzero/entry/train_muzero_segment.py @@ -24,6 +24,8 @@ timer = EasyTimer() + + def train_muzero_segment( input_cfg: Tuple[dict, dict], seed: int = 0, @@ -88,7 +90,9 @@ def train_muzero_segment( # load pretrained model if model_path is not None: + logging.info(f'Loading model from {model_path} begin...') policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'Loading model from {model_path} end!') # Create worker components: learner, collector, evaluator, replay buffer, commander. tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None @@ -144,6 +148,8 @@ def train_muzero_segment( train_epoch = 0 reanalyze_batch_size = cfg.policy.reanalyze_batch_size + + while True: log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) log_buffer_run_time(learner.train_iter, replay_buffer, tb_logger) @@ -169,7 +175,9 @@ def train_muzero_segment( collect_kwargs['epsilon'] = 0.0 # Evaluate policy performance. 
- if evaluator.should_eval(learner.train_iter): + if learner.train_iter==0 or evaluator.should_eval(learner.train_iter): + # if evaluator.should_eval(learner.train_iter): + if cfg.policy.eval_offline: eval_train_iter_list.append(learner.train_iter) eval_train_envstep_list.append(collector.envstep) diff --git a/lzero/entry/train_muzero_segment_orig.py b/lzero/entry/train_muzero_segment_orig.py new file mode 100644 index 000000000..b3f408d08 --- /dev/null +++ b/lzero/entry/train_muzero_segment_orig.py @@ -0,0 +1,260 @@ +import logging +import os +from functools import partial +from typing import Optional, Tuple + +import torch +from ding.config import compile_config +from ding.envs import create_env_manager +from ding.envs import get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import EasyTimer +from ding.utils import set_pkg_seed, get_rank +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage, log_buffer_run_time +from lzero.policy import visit_count_temperature +from lzero.policy.random_policy import LightZeroRandomPolicy +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector +from .utils import random_collect, calculate_update_per_collect + +timer = EasyTimer() + + + + +def train_muzero_segment( + input_cfg: Tuple[dict, dict], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': # noqa + """ + Overview: + The train entry for MCTS+RL algorithms (with muzero_segment_collector and buffer reanalyze trick), including MuZero, EfficientZero, Sampled MuZero, Sampled EfficientZero, Gumbel MuZero, StochasticMuZero. + Arguments: + - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type. + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): The pretrained model path, which should + point to the ckpt file of the pretrained model, and an absolute path is recommended. + In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. + - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. + - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps. + Returns: + - policy (:obj:`Policy`): Converged policy. 
+ """ + + cfg, create_cfg = input_cfg + assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_context', 'muzero_rnn_full_obs', 'sampled_efficientzero', 'sampled_muzero', 'gumbel_muzero', 'stochastic_muzero'], \ + "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'" + + if create_cfg.policy.type in ['muzero', 'muzero_context', 'muzero_rnn_full_obs']: + from lzero.mcts import MuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'efficientzero': + from lzero.mcts import EfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'sampled_efficientzero': + from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'sampled_muzero': + from lzero.mcts import SampledMuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'gumbel_muzero': + from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'stochastic_muzero': + from lzero.mcts import StochasticMuZeroGameBuffer as GameBuffer + + if cfg.policy.cuda and torch.cuda.is_available(): + cfg.policy.device = 'cuda' + else: + cfg.policy.device = 'cpu' + + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create main components: env, policy + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + if cfg.policy.eval_offline: + cfg.policy.learn.learner.hook.save_ckpt_after_iter = cfg.policy.eval_freq + + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # load pretrained model + if model_path is not None: + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = cfg.policy + batch_size = policy_config.batch_size + # specific game buffer for MCTS+RL algorithms + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + + # ============================================================== + # Main loop + # ============================================================== + # Learner's before_run hook. 
+ learner.call_hook('before_run') + + if cfg.policy.update_per_collect is not None: + update_per_collect = cfg.policy.update_per_collect + + # The purpose of collecting random data before training: + # Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely. + # Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms. + if cfg.policy.random_collect_episode_num > 0: + random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer) + if cfg.policy.eval_offline: + eval_train_iter_list = [] + eval_train_envstep_list = [] + + # Evaluate the random agent + # stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + + + + while True: + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + log_buffer_run_time(learner.train_iter, replay_buffer, tb_logger) + collect_kwargs = {} + # set temperature for visit count distributions according to the train_iter, + # please refer to Appendix D in MuZero paper for details. + collect_kwargs['temperature'] = visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ) + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + else: + collect_kwargs['epsilon'] = 0.0 + + # Evaluate policy performance. + if evaluator.should_eval(learner.train_iter): + if cfg.policy.eval_offline: + eval_train_iter_list.append(learner.train_iter) + eval_train_envstep_list.append(collector.envstep) + else: + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + # Collect data by default config n_sample/n_episode. + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # Determine updates per collection + update_per_collect = calculate_update_per_collect(cfg, new_data) + + # save returned new_data collected by the collector + replay_buffer.push_game_segments(new_data) + # remove the oldest data if the replay buffer is full. + replay_buffer.remove_oldest_data_to_fit() + + # Periodically reanalyze buffer + if cfg.policy.buffer_reanalyze_freq >= 1: + # Reanalyze buffer times in one train_epoch + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + # Reanalyze buffer each <1/buffer_reanalyze_freq> train_epoch + if train_epoch % (1//cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition): + with timer: + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + logging.info(f'Buffer reanalyze time: {timer.value}') + + # Learn policy from collected data. 
+ for i in range(update_per_collect): + + if cfg.policy.buffer_reanalyze_freq >= 1: + # Reanalyze buffer times in one train_epoch + if i % reanalyze_interval == 0 and replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + + # Learner will train ``update_per_collect`` times in one iteration. + if replay_buffer.get_num_of_transitions() > batch_size: + train_data = replay_buffer.sample(batch_size, policy) + else: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, ' + f'{replay_buffer} ' + f'continue to collect now ....' + ) + break + + # The core train steps for MCTS+RL algorithms. + log_vars = learner.train(train_data, collector.envstep) + + if cfg.policy.use_priority: + replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + + train_epoch += 1 + + if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: + if cfg.policy.eval_offline: + logging.info(f'eval offline beginning...') + ckpt_dirname = './{}/ckpt'.format(learner.exp_name) + # Evaluate the performance of the pretrained model. + for train_iter, collector_envstep in zip(eval_train_iter_list, eval_train_envstep_list): + ckpt_name = 'iteration_{}.pth.tar'.format(train_iter) + ckpt_path = os.path.join(ckpt_dirname, ckpt_name) + # load the ckpt of pretrained model + policy.learn_mode.load_state_dict(torch.load(ckpt_path, map_location=cfg.policy.device)) + stop, reward = evaluator.eval(learner.save_checkpoint, train_iter, collector_envstep) + logging.info( + f'eval offline at train_iter: {train_iter}, collector_envstep: {collector_envstep}, reward: {reward}') + logging.info(f'eval offline finished!') + break + + # Learner's after_run hook. 
+    learner.call_hook('after_run')
+    return policy
diff --git a/lzero/entry/train_muzero_segment_save_buffer.py b/lzero/entry/train_muzero_segment_save_buffer.py
new file mode 100644
index 000000000..a206d4631
--- /dev/null
+++ b/lzero/entry/train_muzero_segment_save_buffer.py
@@ -0,0 +1,333 @@
+import logging
+import os
+from functools import partial
+from typing import Optional, Tuple, Dict, Any
+
+import torch
+from ding.config import compile_config
+from ding.envs import create_env_manager
+from ding.envs import get_vec_env_setting
+from ding.policy import create_policy
+from ding.rl_utils import get_epsilon_greedy_fn
+from ding.utils import EasyTimer
+from ding.utils import set_pkg_seed, get_rank
+from ding.worker import BaseLearner
+from tensorboardX import SummaryWriter
+
+from lzero.entry.utils import log_buffer_memory_usage, log_buffer_run_time
+from lzero.policy import visit_count_temperature
+from lzero.policy.random_policy import LightZeroRandomPolicy
+from lzero.worker import MuZeroEvaluator as Evaluator
+from lzero.worker import MuZeroSegmentCollector as Collector
+from .utils import random_collect, calculate_update_per_collect
+from easydict import EasyDict
+timer = EasyTimer()
+
+# # ==============================================================
+# # Begin: helper that converts a GameSegment into a pure-NumPy dict
+# # ==============================================================
+# def convert_game_segment_to_numpy(game_segment: Any) -> Dict[str, Any]:
+#     """
+#     Convert a GameSegment object into a dict that contains only basic Python types and NumPy arrays.
+#     This removes all PyTorch tensors so the data can be serialized safely across versions.
+#     """
+#     numpy_dict = {}
+#     for attr, value in game_segment.__dict__.items():
+#         if isinstance(value, torch.Tensor):
+#             # Convert the tensor to a NumPy array on the CPU
+#             numpy_dict[attr] = value.cpu().numpy()
+#         elif isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor):
+#             # Handle lists of tensors
+#             numpy_dict[attr] = [v.cpu().numpy() for v in value]
+#         else:
+#             # Copy all other types as-is
+#             numpy_dict[attr] = value
+#     return numpy_dict
+# # ==============================================================
+# # End: helper
+# # ==============================================================
+
+# ==============================================================
+# Begin: recursive sanitizer that converts arbitrary data into a fully serializable form
+# ==============================================================
+def deep_to_serializable(data: Any) -> Any:
+    """
+    Recursively convert a complex data structure into a fully serializable format.
+    - torch.Tensor -> numpy.ndarray
+    - easydict.EasyDict -> dict
+    - any custom object with a ``__dict__`` attribute (e.g. GameSegment) -> dict
+    - the contents of list, tuple and dict are processed recursively.
+    """
+    if isinstance(data, torch.Tensor):
+        return data.cpu().numpy()
+
+    if isinstance(data, (dict, EasyDict)):
+        return {k: deep_to_serializable(v) for k, v in data.items()}
+
+    if isinstance(data, (list, tuple)):
+        return type(data)(deep_to_serializable(item) for item in data)
+
+    # ==================== newly added and most important branch ====================
+    # If the object is an instance of a custom class (not a basic type, but it has a __dict__),
+    # convert it to its attribute dict and sanitize that dict recursively.
+    if hasattr(data, '__dict__'):
+        return deep_to_serializable(data.__dict__)
+    # ================================================================================
+
+    # All other basic types (int, float, str, bool, None, numpy.ndarray) are returned as-is.
+    return data
+
+def train_muzero_segment_save_buffer(
+    input_cfg: Tuple[dict, dict],
+    seed: int = 0,
+    model: Optional[torch.nn.Module] = None,
+    model_path: Optional[str] = None,
+    max_train_iter: Optional[int] = int(1e10),
+    max_env_step: Optional[int] = int(1e10),
+) -> 'Policy':  # noqa
+    """
+    Overview:
+        The train entry for MCTS+RL algorithms (with muzero_segment_collector and buffer reanalyze
trick), including MuZero, EfficientZero, Sampled MuZero, Sampled EfficientZero, Gumbel MuZero, StochasticMuZero. + Arguments: + - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type. + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): The pretrained model path, which should + point to the ckpt file of the pretrained model, and an absolute path is recommended. + In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. + - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. + - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps. + Returns: + - policy (:obj:`Policy`): Converged policy. + """ + + cfg, create_cfg = input_cfg + assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_context', 'muzero_rnn_full_obs', 'sampled_efficientzero', 'sampled_muzero', 'gumbel_muzero', 'stochastic_muzero'], \ + "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'" + + GameBuffer = None + if create_cfg.policy.type in ['muzero', 'muzero_context', 'muzero_rnn_full_obs']: + from lzero.mcts import MuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'efficientzero': + from lzero.mcts import EfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'sampled_efficientzero': + from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'sampled_muzero': + from lzero.mcts import SampledMuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'gumbel_muzero': + from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'stochastic_muzero': + from lzero.mcts import StochasticMuZeroGameBuffer as GameBuffer + + if cfg.policy.cuda and torch.cuda.is_available(): + cfg.policy.device = 'cuda' + else: + cfg.policy.device = 'cpu' + + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + if cfg.policy.eval_offline: + cfg.policy.learn.learner.hook.save_ckpt_after_iter = cfg.policy.eval_freq + + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + if model_path is not None: + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + policy_config = cfg.policy + batch_size = policy_config.batch_size + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + 
stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + + learner.call_hook('before_run') + + if cfg.policy.update_per_collect is not None: + update_per_collect = cfg.policy.update_per_collect + + if cfg.policy.random_collect_episode_num > 0: + random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer) + if cfg.policy.eval_offline: + eval_train_iter_list = [] + eval_train_envstep_list = [] + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + + save_buffer_interval = 100000 # TODO: 100k + # save_buffer_interval = 2 + + last_save_iter = 0 + + while True: + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + log_buffer_run_time(learner.train_iter, replay_buffer, tb_logger) + collect_kwargs = {} + collect_kwargs['temperature'] = visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ) + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + else: + collect_kwargs['epsilon'] = 0.0 + + if evaluator.should_eval(learner.train_iter): + if cfg.policy.eval_offline: + eval_train_iter_list.append(learner.train_iter) + eval_train_envstep_list.append(collector.envstep) + else: + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + update_per_collect = calculate_update_per_collect(cfg, new_data) + + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + if cfg.policy.buffer_reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + if train_epoch > 0 and train_epoch % (1//cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + logging.info(f'Buffer reanalyze time: {timer.value}') + + for i in range(update_per_collect): + if cfg.policy.buffer_reanalyze_freq >= 1: + if i > 0 and i % reanalyze_interval == 0 and replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + + if replay_buffer.get_num_of_transitions() > batch_size: + train_data = replay_buffer.sample(batch_size, policy) + else: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, ' + f'{replay_buffer} ' + f'continue to collect now ....' 
+                    )
+                    break
+
+                log_vars = learner.train(train_data, collector.envstep)
+
+                if cfg.policy.use_priority:
+                    replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
+
+        # ==============================================================
+        # Begin: [final fix] save fully sanitized buffer data via the recursive sanitizer
+        # ==============================================================
+        current_iter = learner.train_iter
+        if current_iter // save_buffer_interval > last_save_iter // save_buffer_interval:
+            current_milestone = (current_iter // save_buffer_interval) * save_buffer_interval
+
+            buffer_save_dir = os.path.join(cfg.exp_name, 'game_buffers')
+            os.makedirs(buffer_save_dir, exist_ok=True)
+            file_path = os.path.join(buffer_save_dir, f'muzero_game_buffer_iter_{current_milestone}.pth')
+
+            logging.info(f"Reached training iteration {current_milestone}, deep-sanitizing and saving the game buffer...")
+
+            try:
+                # 1. Build a raw dict containing all of the buffer's core state.
+                #    Note: no manual per-field conversion is needed here any more.
+                buffer_data_to_save_raw = {
+                    'cfg': replay_buffer._cfg,
+                    'game_segment_buffer': replay_buffer.game_segment_buffer,
+                    'game_pos_priorities': replay_buffer.game_pos_priorities,
+                    'game_segment_game_pos_look_up': replay_buffer.game_segment_game_pos_look_up,
+                    'num_of_collected_episodes': replay_buffer.num_of_collected_episodes,
+                    'base_idx': replay_buffer.base_idx,
+                    'clear_time': replay_buffer.clear_time,
+                }
+
+                # 2. Run the whole structure through the recursive sanitizer.
+                fully_serializable_data = deep_to_serializable(buffer_data_to_save_raw)
+
+                # 3. Save the fully sanitized dict.
+                # torch.save(fully_serializable_data, file_path)
+                # logging.info(f"Sanitized game buffer data saved to: {file_path}")
+
+                # Robust save logic: write to a temporary file first, then rename atomically.
+                temp_file_path = file_path + ".tmp"
+                try:
+                    torch.save(fully_serializable_data, temp_file_path)
+                    # On some file systems an explicit sync to disk can be forced here.
+                    # os.sync()
+                    os.rename(temp_file_path, file_path)
+                    logging.info(f"Sanitized game buffer data saved to: {file_path}")
+                except Exception as e:
+                    logging.error(f"Saving failed: {e}")
+                    if os.path.exists(temp_file_path):
+                        os.remove(temp_file_path)  # clean up the temporary file
+
+            except Exception as e:
+                logging.error(f"Failed to save sanitized game buffer data at iteration {current_milestone}. Error: {e}", exc_info=True)
+
+            last_save_iter = current_iter
+        # ==============================================================
+        # End: save logic
+        # ==============================================================
+
+        train_epoch += 1
+
+        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+            if cfg.policy.eval_offline:
+                logging.info(f'eval offline beginning...')
+                ckpt_dirname = './{}/ckpt'.format(learner.exp_name)
+                for train_iter, collector_envstep in zip(eval_train_iter_list, eval_train_envstep_list):
+                    ckpt_name = 'iteration_{}.pth.tar'.format(train_iter)
+                    ckpt_path = os.path.join(ckpt_dirname, ckpt_name)
+                    policy.learn_mode.load_state_dict(torch.load(ckpt_path, map_location=cfg.policy.device))
+                    stop, reward = evaluator.eval(learner.save_checkpoint, train_iter, collector_envstep)
+                    logging.info(
+                        f'eval offline at train_iter: {train_iter}, collector_envstep: {collector_envstep}, reward: {reward}')
+                logging.info(f'eval offline finished!')
+            break
+
+    learner.call_hook('after_run')
+    return policy
\ No newline at end of file
diff --git a/lzero/entry/train_muzero_segment_save_buffer_v0.py b/lzero/entry/train_muzero_segment_save_buffer_v0.py
new file mode 100644
index 000000000..5af6a2ef2
--- /dev/null
+++ b/lzero/entry/train_muzero_segment_save_buffer_v0.py
@@ -0,0 +1,262 @@
+import logging
+import os
+from functools import partial
+from typing import Optional, Tuple
+
+import torch
+from ding.config import compile_config
+from ding.envs import create_env_manager
+from ding.envs import
get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import EasyTimer +from ding.utils import set_pkg_seed, get_rank +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage, log_buffer_run_time +from lzero.policy import visit_count_temperature +from lzero.policy.random_policy import LightZeroRandomPolicy +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector +from .utils import random_collect, calculate_update_per_collect + +timer = EasyTimer() + +def train_muzero_segment( + input_cfg: Tuple[dict, dict], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': # noqa + """ + Overview: + The train entry for MCTS+RL algorithms (with muzero_segment_collector and buffer reanalyze trick), including MuZero, EfficientZero, Sampled MuZero, Sampled EfficientZero, Gumbel MuZero, StochasticMuZero. + Arguments: + - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type. + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): The pretrained model path, which should + point to the ckpt file of the pretrained model, and an absolute path is recommended. + In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. + - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. + - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps. + Returns: + - policy (:obj:`Policy`): Converged policy. 
+ """ + + cfg, create_cfg = input_cfg + assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_context', 'muzero_rnn_full_obs', 'sampled_efficientzero', 'sampled_muzero', 'gumbel_muzero', 'stochastic_muzero'], \ + "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'" + + GameBuffer = None + if create_cfg.policy.type in ['muzero', 'muzero_context', 'muzero_rnn_full_obs']: + from lzero.mcts import MuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'efficientzero': + from lzero.mcts import EfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'sampled_efficientzero': + from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'sampled_muzero': + from lzero.mcts import SampledMuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'gumbel_muzero': + from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'stochastic_muzero': + from lzero.mcts import StochasticMuZeroGameBuffer as GameBuffer + + if cfg.policy.cuda and torch.cuda.is_available(): + cfg.policy.device = 'cuda' + else: + cfg.policy.device = 'cpu' + + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + if cfg.policy.eval_offline: + cfg.policy.learn.learner.hook.save_ckpt_after_iter = cfg.policy.eval_freq + + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + if model_path is not None: + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + policy_config = cfg.policy + batch_size = policy_config.batch_size + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + + learner.call_hook('before_run') + + if cfg.policy.update_per_collect is not None: + update_per_collect = cfg.policy.update_per_collect + + if cfg.policy.random_collect_episode_num > 0: + random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer) + if cfg.policy.eval_offline: + eval_train_iter_list = [] + eval_train_envstep_list = [] + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + + save_buffer_interval = 50000 + save_buffer_interval = 2 + + last_save_iter = 0 + + while True: + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + 
log_buffer_run_time(learner.train_iter, replay_buffer, tb_logger) + collect_kwargs = {} + collect_kwargs['temperature'] = visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ) + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + else: + collect_kwargs['epsilon'] = 0.0 + + if evaluator.should_eval(learner.train_iter): + if cfg.policy.eval_offline: + eval_train_iter_list.append(learner.train_iter) + eval_train_envstep_list.append(collector.envstep) + else: + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + update_per_collect = calculate_update_per_collect(cfg, new_data) + + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + if cfg.policy.buffer_reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + if train_epoch > 0 and train_epoch % (1//cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + logging.info(f'Buffer reanalyze time: {timer.value}') + + for i in range(update_per_collect): + if cfg.policy.buffer_reanalyze_freq >= 1: + if i > 0 and i % reanalyze_interval == 0 and replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + + if replay_buffer.get_num_of_transitions() > batch_size: + train_data = replay_buffer.sample(batch_size, policy) + else: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, ' + f'{replay_buffer} ' + f'continue to collect now ....' 
+                    )
+                    break
+
+                log_vars = learner.train(train_data, collector.envstep)
+
+                if cfg.policy.use_priority:
+                    replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
+
+        # ==============================================================
+        # Begin: [fixed] periodically save the game buffer's core data
+        # ==============================================================
+        current_iter = learner.train_iter
+        if current_iter // save_buffer_interval > last_save_iter // save_buffer_interval:
+            current_milestone = (current_iter // save_buffer_interval) * save_buffer_interval
+
+            buffer_save_dir = os.path.join(cfg.exp_name, 'game_buffers')
+            os.makedirs(buffer_save_dir, exist_ok=True)
+
+            file_path = os.path.join(buffer_save_dir, f'muzero_game_buffer_iter_{current_milestone}.pth')
+
+            logging.info(f"Reached training iteration {current_milestone}, saving the game buffer's core state...")
+
+            try:
+                # Extract every core state attribute exactly as defined in the MuZeroGameBuffer source.
+                buffer_data_to_save = {
+                    'cfg': replay_buffer._cfg,
+                    'game_segment_buffer': replay_buffer.game_segment_buffer,
+                    'game_pos_priorities': replay_buffer.game_pos_priorities,
+                    'game_segment_game_pos_look_up': replay_buffer.game_segment_game_pos_look_up,
+                    'num_of_collected_episodes': replay_buffer.num_of_collected_episodes,
+                    'base_idx': replay_buffer.base_idx,
+                    'clear_time': replay_buffer.clear_time,
+                    'keep_ratio': replay_buffer.keep_ratio,
+                    'model_update_interval': replay_buffer.model_update_interval,
+                }
+
+                torch.save(buffer_data_to_save, file_path)
+                logging.info(f"Game buffer core state saved to: {file_path}")
+            except Exception as e:
+                logging.error(f"Failed to save game buffer core state at iteration {current_milestone}. Error: {e}")
+
+            last_save_iter = current_iter
+        # ==============================================================
+        # End: save logic
+        # ==============================================================
+
+        train_epoch += 1
+
+        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+            if cfg.policy.eval_offline:
+                logging.info(f'eval offline beginning...')
+                ckpt_dirname = './{}/ckpt'.format(learner.exp_name)
+                for train_iter, collector_envstep in zip(eval_train_iter_list, eval_train_envstep_list):
+                    ckpt_name = 'iteration_{}.pth.tar'.format(train_iter)
+                    ckpt_path = os.path.join(ckpt_dirname, ckpt_name)
+                    policy.learn_mode.load_state_dict(torch.load(ckpt_path, map_location=cfg.policy.device))
+                    stop, reward = evaluator.eval(learner.save_checkpoint, train_iter, collector_envstep)
+                    logging.info(
+                        f'eval offline at train_iter: {train_iter}, collector_envstep: {collector_envstep}, reward: {reward}')
+                logging.info(f'eval offline finished!')
+            break
+
+    learner.call_hook('after_run')
+    return policy
\ No newline at end of file
diff --git a/lzero/entry/train_unizero_segment.py b/lzero/entry/train_unizero_segment.py
index c1ed74b16..d925f10de 100644
--- a/lzero/entry/train_unizero_segment.py
+++ b/lzero/entry/train_unizero_segment.py
@@ -76,7 +76,6 @@ def train_unizero_segment(
     evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
 
     collector_env.seed(cfg.seed)
-    # collector_env.seed(cfg.seed, dynamic_seed=False)
     evaluator_env.seed(cfg.seed, dynamic_seed=False)
     set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available())
 
@@ -88,6 +87,19 @@
         policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
         logging.info(f'Loading model from {model_path} end!')
 
+    # ==================== New modification (currently disabled) ====================
+    # Per the user's request, re-initialize the value head after loading the pretrained model.
+    # This helps rule out the possibility that the value head is stuck in a saturated regime.
+    # logging.info("Re-initializing the value head as requested...")
+    # # Access the underlying world model through the policy's learn_mode.
+    # # In the LightZero structure this is usually `policy.learn_mode._model`.
+    # if hasattr(policy.learn_mode.get_attribute("learn_model").world_model, 'reinit_prediction_heads'):
+    #     policy.learn_mode.get_attribute("learn_model").world_model.reinit_prediction_heads(heads_to_reinit=['value'])
+    #     logging.info("Value head re-initialized successfully.")
+    # else:
+    #     logging.warning("Could not find the 'reinit_prediction_heads' method. Please check the model structure. Skipping re-initialization.")
+    # # ==========================================================
+
     # Create worker components: learner, collector, evaluator, replay buffer, commander
     tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
     learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
@@ -154,6 +166,7 @@
             collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
 
         # Evaluate policy performance
+        # if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter):
         if evaluator.should_eval(learner.train_iter):
             stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
             if stop:
                 break
@@ -183,7 +196,9 @@
                     logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
                     logging.info(f'Buffer reanalyze time: {timer.value}')
 
+
         # Train the policy if sufficient data is available
+        # cfg.policy.train_start_after_envsteps=8000
         if collector.envstep > cfg.policy.train_start_after_envsteps:
             if cfg.policy.sample_type == 'episode':
                 data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size
@@ -216,6 +231,14 @@
                 if cfg.policy.use_priority:
                     replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
 
+
+                # if learner.train_iter % 50000 == 0:  # every 50k iters (~10k envsteps)  # TODO
+                #     if hasattr(policy.learn_mode.get_attribute("learn_model").world_model, 'reinit_prediction_heads'):
+                #         policy.learn_mode.get_attribute("learn_model").world_model.reinit_prediction_heads(heads_to_reinit=['value', "reward", "policy"])
+                #         logging.info("Value/reward/policy heads re-initialized successfully.")
+                #     else:
+                #         logging.warning("Could not find the 'reinit_prediction_heads' method. Please check the model structure. Skipping re-initialization.")
+
             train_epoch += 1
             policy.recompute_pos_emb_diff_and_clear_cache()
 
diff --git a/lzero/entry/train_unizero_segment_from_buffer.py b/lzero/entry/train_unizero_segment_from_buffer.py
new file mode 100644
index 000000000..9d3ad5f58
--- /dev/null
+++ b/lzero/entry/train_unizero_segment_from_buffer.py
@@ -0,0 +1,297 @@
+import logging
+import os
+from functools import partial
+from typing import Tuple, Optional
+
+import torch
+import wandb
+from ding.config import compile_config
+from ding.envs import create_env_manager
+from ding.envs import get_vec_env_setting
+from ding.policy import create_policy
+from ding.rl_utils import get_epsilon_greedy_fn
+from ding.utils import EasyTimer
+from ding.utils import set_pkg_seed, get_rank, get_world_size
+from ding.worker import BaseLearner
+from tensorboardX import SummaryWriter
+from torch.utils.tensorboard import SummaryWriter
+
+from lzero.entry.utils import log_buffer_memory_usage
+from lzero.policy import visit_count_temperature
+from lzero.policy.random_policy import LightZeroRandomPolicy
+from lzero.worker import MuZeroEvaluator as Evaluator
+from lzero.worker import MuZeroSegmentCollector as Collector
+from .utils import random_collect, calculate_update_per_collect
+from lzero.mcts import GameSegment
+import numpy as np
+timer = EasyTimer()
+
+def
train_unizero_segment_from_buffer( + input_cfg: Tuple[dict, dict], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), + expert_buffer_path: Optional[str] = None, # <--- 新增参数 +) -> 'Policy': + """ + Overview: + The train entry for UniZero (with muzero_segment_collector and buffer reanalyze trick), proposed in our paper UniZero: Generalized and Efficient Planning with Scalable Latent World Models. + UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing the limitations found in MuZero-style algorithms, + particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667. + Arguments: + - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type. + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): The pretrained model path, which should + point to the ckpt file of the pretrained model, and an absolute path is recommended. + In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. + - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. + - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps. + Returns: + - policy (:obj:`Policy`): Converged policy. + """ + + cfg, create_cfg = input_cfg + + # Ensure the specified policy type is supported + assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero entry now only supports the following algo.: 'unizero', 'sampled_unizero'" + + # Import the correct GameBuffer class based on the policy type + game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'} + + GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]), + game_buffer_classes[create_cfg.policy.type]) + + # Set device based on CUDA availability + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f'cfg.policy.device: {cfg.policy.device}') + + # Compile the configuration + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + + # Create main components: env, policy + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + collector_env.seed(cfg.seed) + # collector_env.seed(cfg.seed, dynamic_seed=False) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available()) + + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # Load pretrained model if specified + if model_path is not None: + logging.info(f'Loading model from {model_path} begin...') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'Loading model from {model_path} end!') + + # Create worker components: learner, collector, evaluator, replay buffer, commander + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if 
get_rank() == 0 else None
+    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+
+    # MCTS+RL algorithms related core code
+    policy_config = cfg.policy
+
+    # TODO(pu)
+    # ==============================================================
+    # Begin: [final fix] load the expert game buffer from sanitized data
+    # ==============================================================
+    # 1. First, create an empty UniZero buffer instance.
+    #    `GameBuffer` here is the UniZeroGameBuffer resolved at the top of this entry.
+    replay_buffer = GameBuffer(policy_config)
+
+    if expert_buffer_path and os.path.exists(expert_buffer_path):
+        logging.info(f"Loading expert game buffer data from {expert_buffer_path} ...")
+        try:
+            # import ipdb;ipdb.set_trace()
+            # Load the saved dict. `weights_only=True` would be safer because the file contains no executable code.
+            # If that causes problems, temporarily fall back to False, but True is the recommended setting.
+            loaded_data = torch.load(expert_buffer_path, map_location=cfg.policy.device, weights_only=False)
+
+            # 2. Rebuild the list of GameSegment objects from the list of NumPy dicts.
+            reconstructed_game_segments = []
+            for numpy_dict in loaded_data['game_segment_buffer']:
+                # Create an empty GameSegment instance.
+                # Note: if GameSegment.__init__ requires arguments, they must be provided accordingly.
+                game_segment = GameSegment(action_space=numpy_dict['action_space_size'], game_segment_length=len(numpy_dict['action_segment']), config=policy_config)
+
+                # Iterate over the dict and fill the instance with the loaded data.
+                for attr, value in numpy_dict.items():
+                    if isinstance(value, np.ndarray):
+                        # Convert the NumPy array back to a PyTorch tensor and move it to the target device.
+                        # setattr(game_segment, attr, torch.from_numpy(value).to(cfg.policy.device))
+                        # Convert the NumPy array back to a PyTorch tensor and keep it on the CPU.
+                        setattr(game_segment, attr, torch.from_numpy(value))
+                    else:
+                        setattr(game_segment, attr, value)
+                reconstructed_game_segments.append(game_segment)
+
+            # 3. Populate the new replay_buffer instance with the reconstructed data.
+            replay_buffer.game_segment_buffer = reconstructed_game_segments
+            replay_buffer.game_pos_priorities = loaded_data.get('game_pos_priorities', [])
+            replay_buffer.game_segment_game_pos_look_up = loaded_data.get('game_segment_game_pos_look_up', [])
+            replay_buffer.num_of_collected_episodes = loaded_data.get('num_of_collected_episodes', 0)
+            replay_buffer.base_idx = loaded_data.get('base_idx', 0)
+            replay_buffer.clear_time = loaded_data.get('clear_time', 0)
+
+            logging.info(f"Expert game buffer data loaded and rebuilt successfully. Buffer state: {replay_buffer}")
+
+            # Optional: switch to offline learning mode.
+            if cfg.policy.get('offline_learn', False):
+                cfg.policy.random_collect_episode_num = 0
+                max_env_step = 0
+                logging.info("Configured for offline learning; data collection is disabled.")
+
+        except Exception as e:
+            logging.error(f"Failed to load expert buffer data from {expert_buffer_path}; falling back to a new empty buffer. Error: {e}", exc_info=True)
+    else:
+        if expert_buffer_path:
+            logging.warning(f"The given expert buffer path does not exist: {expert_buffer_path}. A new empty buffer will be used.")
+    # ==============================================================
+    # End: loading logic
+    # ==============================================================
+
+    collector = Collector(env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=cfg.exp_name,
+                          policy_config=policy_config)
+    evaluator = Evaluator(eval_freq=cfg.policy.eval_freq, n_evaluator_episode=cfg.env.n_evaluator_episode,
+                          stop_value=cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode,
+                          tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=policy_config)
+
+    # Learner's before_run hook
+    learner.call_hook('before_run')
+
+    if cfg.policy.use_wandb:
+        policy.set_train_iter_env_step(learner.train_iter, collector.envstep)
+
+    # Collect random data before training
+    if cfg.policy.random_collect_episode_num > 0:
+        random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
+
+    batch_size =
policy._cfg.batch_size + + # TODO: for visualize + # stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + + if cfg.policy.multi_gpu: + # Get current world size and rank + world_size = get_world_size() + rank = get_rank() + else: + world_size = 1 + rank = 0 + + while True: + # Log buffer memory usage + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + + # Set temperature for visit count distributions + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # Default epsilon value + } + + # Configure epsilon for epsilon-greedy exploration + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + # Evaluate policy performance + if learner.train_iter ==0 or evaluator.should_eval(learner.train_iter): + # if evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + # Collect new data + # new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # Determine updates per collection + # update_per_collect = calculate_update_per_collect(cfg, new_data, world_size) + + # update_per_collect = 10000 # TODO + update_per_collect = 1000 # TODO + + # Update replay buffer + # replay_buffer.push_game_segments(new_data) + # replay_buffer.remove_oldest_data_to_fit() + + # Periodically reanalyze buffer + # if cfg.policy.buffer_reanalyze_freq >= 1: + # # Reanalyze buffer times in one train_epoch + # reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + # else: + # # Reanalyze buffer each <1/buffer_reanalyze_freq> train_epoch + # if train_epoch > 0 and train_epoch % int(1/cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition): + # with timer: + # # Each reanalyze process will reanalyze sequences ( transitions per sequence) + # replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + # buffer_reanalyze_count += 1 + # logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + # logging.info(f'Buffer reanalyze time: {timer.value}') + + # Train the policy if sufficient data is available + # if collector.envstep > cfg.policy.train_start_after_envsteps: + if cfg.policy.sample_type == 'episode': + data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size + else: + data_sufficient = replay_buffer.get_num_of_transitions() > batch_size + if not data_sufficient: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect now ....' 
+ ) + continue + + for i in range(update_per_collect): + # learner._last_iter.add(1) # target net更新需要依靠这个变量 ============ + # if cfg.policy.buffer_reanalyze_freq >= 1: + # # Reanalyze buffer times in one train_epoch + # if i % reanalyze_interval == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition): + # with timer: + # # Each reanalyze process will reanalyze sequences ( transitions per sequence) + # replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + # buffer_reanalyze_count += 1 + # logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + # logging.info(f'Buffer reanalyze time: {timer.value}') + + # import ipdb;ipdb.set_trace() + train_data = replay_buffer.sample(batch_size, policy) + if cfg.policy.use_wandb: + policy.set_train_iter_env_step(learner.train_iter, collector.envstep) + + train_data.append(learner.train_iter) + log_vars = learner.train(train_data, collector.envstep) + + if cfg.policy.use_priority: + replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + # Check stopping criteria + if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: + break + + learner.call_hook('after_run') + if cfg.policy.use_wandb: + wandb.finish() + return policy diff --git a/lzero/mcts/buffer/__init__.py b/lzero/mcts/buffer/__init__.py index d7ccb0678..30e24e94b 100644 --- a/lzero/mcts/buffer/__init__.py +++ b/lzero/mcts/buffer/__init__.py @@ -8,3 +8,4 @@ from .game_buffer_stochastic_muzero import StochasticMuZeroGameBuffer from .game_buffer_rezero_mz import ReZeroMZGameBuffer from .game_buffer_rezero_ez import ReZeroEZGameBuffer +from .game_segment import GameSegment diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index 61ba751a9..093fa163d 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -7,7 +7,7 @@ from ding.torch_utils.data_helper import to_list from ding.utils import BUFFER_REGISTRY from easydict import EasyDict - +import datetime if TYPE_CHECKING: from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy, GumbelMuZeroPolicy @@ -173,109 +173,136 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: return orig_data def _sample_orig_reanalyze_batch(self, batch_size: int) -> Tuple: - """ - Overview: - This function samples a batch of game segments for reanalysis from the replay buffer. - It uses priority sampling based on the `reanalyze_time` of each game segment, with segments - that have been reanalyzed more frequently receiving lower priority. - - The function returns a tuple containing information about the sampled game segments, - including their positions within each segment and the time the batch was created. - Arguments: - - batch_size (:obj:`int`): - The number of samples to draw in this batch. - - Returns: - - Tuple: - A tuple containing the following elements: - - game_segment_list: A list of the sampled game segments. - - pos_in_game_segment_list: A list of indices representing the position of each transition - within its corresponding game segment. - - batch_index_list: The indices of the sampled game segments in the replay buffer. - - make_time: A list of timestamps (set to `0` in this implementation) indicating when - the batch was created. - - Key Details: - 1. 
**Priority Sampling**: - Game segments are sampled based on a probability distribution calculated using - the `reanalyze_time` of each segment. Segments that have been reanalyzed more frequently - are less likely to be selected. - 2. **Segment Slicing**: - Each selected game segment is sampled at regular intervals determined by the - `num_unroll_steps` parameter. Up to `samples_per_segment` transitions are sampled - from each selected segment. - 3. **Handling Extra Samples**: - If the `batch_size` is not perfectly divisible by the number of samples per segment, - additional segments are sampled to make up the difference. - 4. **Reanalyze Time Update**: - The `reanalyze_time` attribute of each sampled game segment is incremented to reflect - that it has been selected for reanalysis again. - Raises: - - ValueError: - If the `game_segment_length` is too small to accommodate the `num_unroll_steps`. - """ - train_sample_num = len(self.game_segment_buffer) - assert self._cfg.reanalyze_partition <= 0.75, "The reanalyze partition should be less than 0.75." - valid_sample_num = int(train_sample_num * self._cfg.reanalyze_partition) - - # Calculate the number of samples per segment - samples_per_segment = self._cfg.game_segment_length // self._cfg.num_unroll_steps - - # Make sure that the batch size can be divided by the number of samples per segment - if samples_per_segment == 0: - raise ValueError("The game segment length is too small for num_unroll_steps.") - - # Calculate the number of samples per segment - batch_size_per_segment = batch_size // samples_per_segment - - # If the batch size cannot be divided, process the remainder part - extra_samples = batch_size % samples_per_segment - - # We use the reanalyze_time in the game_segment_buffer to generate weights - reanalyze_times = np.array([segment.reanalyze_time for segment in self.game_segment_buffer[:valid_sample_num]]) - - # Calculate weights: the larger the reanalyze_time, the smaller the weight (use exp(-reanalyze_time)) - base_decay_rate = 100 - decay_rate = base_decay_rate / valid_sample_num - weights = np.exp(-decay_rate * reanalyze_times) - - # Normalize the weights to a probability distribution - probabilities = weights / np.sum(weights) - - # Sample game segments according to the probabilities - selected_game_segments = np.random.choice(valid_sample_num, batch_size_per_segment, replace=False, - p=probabilities) - - # If there are extra samples to be allocated, randomly select some game segments and sample again - if extra_samples > 0: - extra_game_segments = np.random.choice(valid_sample_num, extra_samples, replace=False, p=probabilities) - selected_game_segments = np.concatenate((selected_game_segments, extra_game_segments)) - - game_segment_list = [] - pos_in_game_segment_list = [] - batch_index_list = [] - - for game_segment_idx in selected_game_segments: - game_segment_idx -= self.base_idx - game_segment = self.game_segment_buffer[game_segment_idx] - - # Update reanalyze_time only once - game_segment.reanalyze_time += 1 - - # The sampling position should be 0, 0 + num_unroll_steps, ... 
(integer multiples of num_unroll_steps) - for i in range(samples_per_segment): - game_segment_list.append(game_segment) - pos_in_game_segment = i * self._cfg.num_unroll_steps - if pos_in_game_segment >= len(game_segment): - pos_in_game_segment = np.random.choice(len(game_segment), 1).item() - pos_in_game_segment_list.append(pos_in_game_segment) - batch_index_list.append(game_segment_idx) - - # Set the make_time for each sample (set to 0 for now, but can be the actual time if needed). - make_time = [0. for _ in range(len(batch_index_list))] - - orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, [], make_time) - return orig_data + """ + Overview: + This function samples a batch of game segments for reanalysis from the replay buffer. + It uses priority sampling based on the `reanalyze_time` of each game segment, with segments + that have been reanalyzed more frequently receiving lower priority. + + The function returns a tuple containing information about the sampled game segments, + including their positions within each segment and the time the batch was created. + Arguments: + - batch_size (:obj:`int`): + The number of samples to draw in this batch. + + Returns: + - Tuple: + A tuple containing the following elements: + - game_segment_list: A list of the sampled game segments. + - pos_in_game_segment_list: A list of indices representing the position of each transition + within its corresponding game segment. + - batch_index_list: The indices of the sampled game segments in the replay buffer. + - make_time: A list of timestamps (set to `0` in this implementation) indicating when + the batch was created. + + Key Details: + 1. **Priority Sampling**: + Game segments are sampled based on a probability distribution calculated using + the `reanalyze_time` of each segment. Segments that have been reanalyzed more frequently + are less likely to be selected. + 2. **Segment Slicing**: + Each selected game segment is sampled at regular intervals determined by the + `num_unroll_steps` parameter. Up to `samples_per_segment` transitions are sampled + from each selected segment. + 3. **Handling Extra Samples**: + If the `batch_size` is not perfectly divisible by the number of samples per segment, + additional segments are sampled to make up the difference. + 4. **Reanalyze Time Update**: + The `reanalyze_time` attribute of each sampled game segment is incremented to reflect + that it has been selected for reanalysis again. + Raises: + - ValueError: + If the `game_segment_length` is too small to accommodate the `num_unroll_steps`. + """ + train_sample_num = len(self.game_segment_buffer) + assert self._cfg.reanalyze_partition <= 0.75, "The reanalyze partition should be less than 0.75." 
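
# Standalone numpy sketch of the reanalyze-time weighting described in the docstring above:
# segments that have been reanalyzed more often get exponentially smaller sampling weights,
# and each selected segment is sliced at multiples of num_unroll_steps. All sizes are
# illustrative, not taken from a real config.
import numpy as np

reanalyze_times = np.array([0, 0, 1, 3, 0, 2])          # per-segment reanalyze counters
valid_sample_num = len(reanalyze_times)
base_decay_rate = 100
decay_rate = base_decay_rate / (valid_sample_num + 1e-6)

weights = np.exp(-decay_rate * reanalyze_times)
probabilities = weights / weights.sum()

game_segment_length, num_unroll_steps = 20, 5
samples_per_segment = game_segment_length // num_unroll_steps   # 4 positions per segment
selected = np.random.choice(valid_sample_num, 3, replace=False, p=probabilities)
positions = [(seg, i * num_unroll_steps) for seg in selected for i in range(samples_per_segment)]
print(positions)   # e.g. [(0, 0), (0, 5), (0, 10), (0, 15), (4, 0), ...]
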
+ valid_sample_num = int(train_sample_num * self._cfg.reanalyze_partition) + + # Calculate the number of samples per segment + samples_per_segment = self._cfg.game_segment_length // self._cfg.num_unroll_steps + + # Make sure that the batch size can be divided by the number of samples per segment + if samples_per_segment == 0: + raise ValueError("The game segment length is too small for num_unroll_steps.") + + # Calculate the number of samples per segment + batch_size_per_segment = batch_size // samples_per_segment + + # If the batch size cannot be divided, process the remainder part + extra_samples = batch_size % samples_per_segment + + # We use the reanalyze_time in the game_segment_buffer to generate weights + reanalyze_times = np.array([segment.reanalyze_time for segment in self.game_segment_buffer[:valid_sample_num]]) + + # Calculate weights: the larger the reanalyze_time, the smaller the weight (use exp(-reanalyze_time)) + base_decay_rate = 100 + # Add a small epsilon to avoid division by zero if valid_sample_num is 0 + decay_rate = base_decay_rate / (valid_sample_num + 1e-6) + weights = np.exp(-decay_rate * reanalyze_times) + + # Normalize the weights to a probability distribution, handle case where sum is zero + sum_weights = np.sum(weights) + if sum_weights > 0: + probabilities = weights / sum_weights + else: + # If all weights are zero, use a uniform distribution + probabilities = np.ones(valid_sample_num) / valid_sample_num + + # Sample game segments according to the probabilities + # Ensure valid_sample_num is not zero before sampling + if valid_sample_num == 0: + return ([], [], [], [], []) + + selected_game_segments = np.random.choice(valid_sample_num, batch_size_per_segment, replace=False, + p=probabilities) + + # If there are extra samples to be allocated, randomly select some game segments and sample again + if extra_samples > 0: + # We need to handle the case where we might sample the same segment again. + # A simple way is to allow replacement for extra samples or sample from remaining ones. + # For simplicity, let's stick to the original logic but ensure it's safe. + remaining_segments = np.setdiff1d(np.arange(valid_sample_num), selected_game_segments) + if len(remaining_segments) < extra_samples: + # If not enough unique segments left, sample with replacement from all valid segments + extra_game_segments = np.random.choice(valid_sample_num, extra_samples, replace=True, p=probabilities) + else: + # Sample from the remaining unique segments + remaining_probs = probabilities[remaining_segments] + remaining_probs /= np.sum(remaining_probs) + extra_game_segments = np.random.choice(remaining_segments, extra_samples, replace=False, p=remaining_probs) + + selected_game_segments = np.concatenate((selected_game_segments, extra_game_segments)) + + game_segment_list = [] + pos_in_game_segment_list = [] + batch_index_list = [] + print(f"selected_game_segments:{selected_game_segments}") + for game_segment_idx in selected_game_segments: + # ========================================================================= + # FIX: The line below is the source of the error and has been removed. + # `game_segment_idx` is already a valid physical index for `game_segment_buffer`. + # game_segment_idx -= self.base_idx + # ========================================================================= + game_segment = self.game_segment_buffer[game_segment_idx] + + # Update reanalyze_time only once + game_segment.reanalyze_time += 1 + + # The sampling position should be 0, 0 + num_unroll_steps, ... 
(integer multiples of num_unroll_steps) + for i in range(samples_per_segment): + game_segment_list.append(game_segment) + pos_in_game_segment = i * self._cfg.num_unroll_steps + if pos_in_game_segment >= len(game_segment): + pos_in_game_segment = np.random.choice(len(game_segment), 1).item() + pos_in_game_segment_list.append(pos_in_game_segment) + # NOTE: We should append the physical index here, as it corresponds to the sampled segment. + batch_index_list.append(game_segment_idx) + + # Set the make_time for each sample (set to 0 for now, but can be the actual time if needed). + make_time = [0. for _ in range(len(batch_index_list))] + + orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, [], make_time) + return orig_data def _sample_orig_reanalyze_data(self, batch_size: int) -> Tuple: """ @@ -589,52 +616,114 @@ def _push_game_segment(self, data: Any, meta: Optional[dict] = None) -> None: # print(f'num of transitions is {len(self.game_segment_game_pos_look_up)}') def remove_oldest_data_to_fit(self) -> None: - """ - Overview: - remove some oldest data if the replay buffer is full. - """ - assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size" + assert self.replay_buffer_size > self._cfg.batch_size, "Replay buffer size should be larger than batch size" nums_of_game_segments = self.get_num_of_game_segments() total_transition = self.get_num_of_transitions() if total_transition > self.replay_buffer_size: - index = 0 + index = -1 + current_transitions = total_transition for i in range(nums_of_game_segments): - length_data = len(self.game_segment_buffer[i].action_segment) if len(self.game_segment_buffer[i].action_segment)= self._cfg.batch_size: + + if index != -1 and total_transition >= self._cfg.batch_size: self._remove(index + 1) - def _remove(self, excess_game_segment_index: List[int]) -> None: + def _remove(self, excess_game_segment_index: int) -> None: """ Overview: - delete game segments in index [0: excess_game_segment_index] + Delete game segments in index [0: excess_game_segment_index] and log the operation. Arguments: - - excess_game_segment_index (:obj:`List[str]`): Index of data. - """ - excess_game_positions = sum( - [len(game_segment) for game_segment in self.game_segment_buffer[:excess_game_segment_index]] - ) + - excess_game_segment_index (:obj:`int`): The number of game segments to remove from the beginning. + """ + if excess_game_segment_index <= 0: + return + + # --- Start of logging modification --- + + # 1. Gather information BEFORE removal + timestamp = datetime.datetime.now() + base_idx_before = self.base_idx + segments_before = self.get_num_of_game_segments() + transitions_before = self.get_num_of_transitions() + + # Calculate the exact number of transitions to be removed + excess_game_positions = 0 + for i in range(excess_game_segment_index): + try: + length_data = len(self.game_segment_buffer[i].action_segment) + except AttributeError: + length_data = len(self.game_segment_buffer[i]) + excess_game_positions += min(length_data, self._cfg.game_segment_length) + + # 2. Perform the removal operations del self.game_segment_buffer[:excess_game_segment_index] self.game_pos_priorities = self.game_pos_priorities[excess_game_positions:] del self.game_segment_game_pos_look_up[:excess_game_positions] self.base_idx += excess_game_segment_index self.clear_time = time.time() + # 3. 
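
# Standalone sketch of the base_idx bookkeeping that _remove() maintains above: the buffer
# drops its oldest segments but keeps handing out monotonically increasing logical indices,
# so physical_index = logical_index - base_idx. The real buffer stores (segment_idx, step_pos)
# pairs per transition; this toy version keeps only segment indices for clarity.
segments = ['seg0', 'seg1', 'seg2', 'seg3', 'seg4']
lookup = list(range(len(segments)))      # logical segment indices
base_idx = 0

def remove_oldest(k: int) -> None:
    global base_idx
    del segments[:k]
    del lookup[:k]
    base_idx += k

remove_oldest(2)
logical_idx = lookup[0]                  # still 2: logical indices never shift
physical_idx = logical_idx - base_idx    # 2 - 2 = 0
assert segments[physical_idx] == 'seg2'
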
Gather information AFTER removal + base_idx_after = self.base_idx + segments_after = self.get_num_of_game_segments() + transitions_after = self.get_num_of_transitions() + + # 4. Format the log message + log_message = ( + f"--- GameBuffer Removal Log ---\n" + f"Timestamp: {timestamp.strftime('%Y-%m-%d %H:%M:%S.%f')}\n" + f"Total collected episodes: {self.num_of_collected_episodes}\n" + f"\n" + f"Action: Removing {excess_game_segment_index} oldest game segment(s).\n" + f" This corresponds to removing {excess_game_positions} transitions.\n" + f"\n" + f"State BEFORE removal:\n" + f" - Segments: {segments_before}\n" + f" - Transitions: {transitions_before}\n" + f" - base_idx: {base_idx_before}\n" + f"\n" + f"State AFTER removal:\n" + f" - Segments: {segments_after}\n" + f" - Transitions: {transitions_after}\n" + f" - base_idx: {base_idx_after}\n" + f"------------------------------\n\n" + ) + # TODO + # 5. Print to console and write to file + # print(log_message) + + # log_filename = f"game_buffer_remove_log_{timestamp.strftime('%Y%m%d_%H%M%S')}.txt" + # try: + # with open(log_filename, 'a', encoding='utf-8') as f: + # f.write(log_message) + # except Exception as e: + # print(f"[ERROR] Failed to write to log file {log_filename}: {e}") + + # --- End of logging modification --- + def get_num_of_episodes(self) -> int: - # number of collected episodes return self.num_of_collected_episodes def get_num_of_game_segments(self) -> int: - # num of game segments return len(self.game_segment_buffer) def get_num_of_transitions(self) -> int: - # total number of transitions return len(self.game_segment_game_pos_look_up) def __repr__(self): - return f'current buffer statistics is: num_of_all_collected_episodes: {self.num_of_collected_episodes}, num of game segments: {len(self.game_segment_buffer)}, number of transitions: {len(self.game_segment_game_pos_look_up)}' + return ( + f'GameBuffer Statistics:\n' + f' - All collected episodes: {self.num_of_collected_episodes}\n' + f' - Current game segments: {self.get_num_of_game_segments()}\n' + f' - Current transitions: {self.get_num_of_transitions()}\n' + f' - base_idx (offset): {self.base_idx}' + ) \ No newline at end of file diff --git a/lzero/mcts/buffer/game_buffer_bkp20250818.py b/lzero/mcts/buffer/game_buffer_bkp20250818.py new file mode 100644 index 000000000..decc8bda8 --- /dev/null +++ b/lzero/mcts/buffer/game_buffer_bkp20250818.py @@ -0,0 +1,668 @@ +import copy +import time +from abc import ABC, abstractmethod +from typing import Any, List, Tuple, Optional, Union, TYPE_CHECKING + +import numpy as np +from ding.torch_utils.data_helper import to_list +from ding.utils import BUFFER_REGISTRY +from easydict import EasyDict + +if TYPE_CHECKING: + from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy, GumbelMuZeroPolicy + + +@BUFFER_REGISTRY.register('game_buffer') +class GameBuffer(ABC, object): + """ + Overview: + The base game buffer class for MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy, GumbelMuZeroPolicy. + """ + + @classmethod + def default_config(cls: type) -> EasyDict: + cfg = EasyDict(copy.deepcopy(cls.config)) + cfg.cfg_type = cls.__name__ + 'Dict' + return cfg + + # Default configuration for GameBuffer. + config = dict( + # (int) The size/capacity of the replay buffer in terms of transitions. + replay_buffer_size=int(1e6), + # (float) The ratio of experiences required for the reanalyzing part in a minibatch. + reanalyze_ratio=0, + # (bool) Whether to consider outdated experiences for reanalyzing. 
If True, we first sort the data in the minibatch by the time it was produced + # and only reanalyze the oldest ``reanalyze_ratio`` fraction. + reanalyze_outdated=True, + # (bool) Whether to use the root value in the reanalyzing part. Please refer to EfficientZero paper for details. + use_root_value=False, + # (int) The number of samples required for mini inference. + mini_infer_size=10240, + # (str) The type of sampled data. The default is 'transition'. Options: 'transition', 'episode'. + sample_type='transition', + ) + + def __init__(self, cfg: dict): + super().__init__() + """ + Overview: + Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key + in the default configuration, the user-provided value will override the default configuration. Otherwise, + the default configuration will be used. + """ + default_config = self.default_config() + default_config.update(cfg) + self._cfg = default_config + self._cfg = cfg + assert self._cfg.env_type in ['not_board_games', 'board_games'] + assert self._cfg.action_type in ['fixed_action_space', 'varied_action_space'] + + self.replay_buffer_size = self._cfg.replay_buffer_size + self.batch_size = self._cfg.batch_size + self._alpha = self._cfg.priority_prob_alpha + self._beta = self._cfg.priority_prob_beta + + self.game_segment_buffer = [] + self.game_pos_priorities = [] + self.game_segment_game_pos_look_up = [] + + self.keep_ratio = 1 + self.num_of_collected_episodes = 0 + self.base_idx = 0 + self.clear_time = 0 + + @abstractmethod + def sample( + self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy", "GumbelMuZeroPolicy"] + ) -> List[Any]: + """ + Overview: + sample data from ``GameBuffer`` and prepare the current and target batch for training. + Arguments: + - batch_size (:obj:`int`): batch size. + - policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy", "GumbelMuZeroPolicy"]`): policy. + Returns: + - train_data (:obj:`List`): List of train data, including current_batch and target_batch. 
+ """ + + @abstractmethod + def _make_batch(self, orig_data: Any, reanalyze_ratio: float) -> Tuple[Any]: + """ + Overview: + prepare the context of a batch + reward_value_context: the context of reanalyzed value targets + policy_re_context: the context of reanalyzed policy targets + policy_non_re_context: the context of non-reanalyzed policy targets + current_batch: the inputs of batch + Arguments: + orig_data: Any batch context from replay buffer + reanalyze_ratio: float ratio of reanalyzed policy (value is 100% reanalyzed) + Returns: + - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch + """ + pass + + def _sample_orig_data(self, batch_size: int) -> Tuple: + """ + Overview: + sample orig_data that contains: + game_segment_list: a list of game segments + pos_in_game_segment_list: transition index in game (relative index) + batch_index_list: the index of start transition of sampled minibatch in replay buffer + weights_list: the weight concerning the priority + make_time: the time the batch is made (for correctly updating replay buffer when data is deleted) + Arguments: + - batch_size (:obj:`int`): batch size + - beta: float the parameter in PER for calculating the priority + """ + assert self._beta > 0 + num_of_transitions = self.get_num_of_transitions() + if self._cfg.use_priority is False: + self.game_pos_priorities = np.ones_like(self.game_pos_priorities) + + # +1e-6 for numerical stability + probs = self.game_pos_priorities ** self._alpha + 1e-6 + probs /= probs.sum() + + # sample according to transition index + batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False) + + if self._cfg.reanalyze_outdated is True: + # NOTE: used in reanalyze part + batch_index_list.sort() + + weights_list = (num_of_transitions * probs[batch_index_list]) ** (-self._beta) + weights_list /= weights_list.max() + + game_segment_list = [] + pos_in_game_segment_list = [] + + for idx in batch_index_list: + game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx] + game_segment_idx -= self.base_idx + game_segment = self.game_segment_buffer[game_segment_idx] + + game_segment_list.append(game_segment) + + # print(f'len(game_segment)=:len(game_segment.action_segment): {len(game_segment)}') + # print(f'len(game_segment.obs_segment): {game_segment.obs_segment.shape[0]}') + + # In the reanalysis phase, `pos_in_game_segment` should be a multiple of `num_unroll_steps`. + # Indices exceeding `game_segment_length` are padded with the next segment and are not updated + # in the current implementation. Therefore, we need to sample `pos_in_game_segment` within + # [0, game_segment_length - num_unroll_steps] to avoid padded data. + + if self._cfg.action_type == 'varied_action_space': + # For some environments (e.g., Jericho), the action space size may be different. + # To ensure we can always unroll `num_unroll_steps` steps starting from the sampled position (without exceeding segment length), + # we avoid sampling from the last `num_unroll_steps` steps of the game segment. + if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps: + pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps, 1).item() + else: + # For environments with a fixed action space (e.g., Atari), + # we can safely sample from the entire game segment range. 
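
# Small sketch of the start-position clamping described in the comments above: the sampled
# position is re-drawn so that num_unroll_steps transitions can always be unrolled without
# stepping into padded data. Numbers are illustrative.
import numpy as np

game_segment_length, num_unroll_steps = 400, 10

def clamp_start_pos(pos: int, varied_action_space: bool) -> int:
    if varied_action_space:
        if pos >= game_segment_length - num_unroll_steps:
            pos = np.random.choice(game_segment_length - num_unroll_steps, 1).item()
    elif pos >= game_segment_length:
        pos = np.random.choice(game_segment_length, 1).item()
    return pos

print(clamp_start_pos(398, varied_action_space=True))   # re-drawn into [0, 390)
print(clamp_start_pos(123, varied_action_space=True))   # unchanged
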
+ if pos_in_game_segment >= self._cfg.game_segment_length: + pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item() + + pos_in_game_segment_list.append(pos_in_game_segment) + + + make_time = [time.time() for _ in range(len(batch_index_list))] + + orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) + return orig_data + + def _sample_orig_reanalyze_batch(self, batch_size: int) -> Tuple: + """ + Overview: + This function samples a batch of game segments for reanalysis from the replay buffer. + It uses priority sampling based on the `reanalyze_time` of each game segment, with segments + that have been reanalyzed more frequently receiving lower priority. + + The function returns a tuple containing information about the sampled game segments, + including their positions within each segment and the time the batch was created. + Arguments: + - batch_size (:obj:`int`): + The number of samples to draw in this batch. + + Returns: + - Tuple: + A tuple containing the following elements: + - game_segment_list: A list of the sampled game segments. + - pos_in_game_segment_list: A list of indices representing the position of each transition + within its corresponding game segment. + - batch_index_list: The indices of the sampled game segments in the replay buffer. + - make_time: A list of timestamps (set to `0` in this implementation) indicating when + the batch was created. + + Key Details: + 1. **Priority Sampling**: + Game segments are sampled based on a probability distribution calculated using + the `reanalyze_time` of each segment. Segments that have been reanalyzed more frequently + are less likely to be selected. + 2. **Segment Slicing**: + Each selected game segment is sampled at regular intervals determined by the + `num_unroll_steps` parameter. Up to `samples_per_segment` transitions are sampled + from each selected segment. + 3. **Handling Extra Samples**: + If the `batch_size` is not perfectly divisible by the number of samples per segment, + additional segments are sampled to make up the difference. + 4. **Reanalyze Time Update**: + The `reanalyze_time` attribute of each sampled game segment is incremented to reflect + that it has been selected for reanalysis again. + Raises: + - ValueError: + If the `game_segment_length` is too small to accommodate the `num_unroll_steps`. + """ + train_sample_num = len(self.game_segment_buffer) + assert self._cfg.reanalyze_partition <= 0.75, "The reanalyze partition should be less than 0.75." 
+ valid_sample_num = int(train_sample_num * self._cfg.reanalyze_partition) + + # Calculate the number of samples per segment + samples_per_segment = self._cfg.game_segment_length // self._cfg.num_unroll_steps + + # Make sure that the batch size can be divided by the number of samples per segment + if samples_per_segment == 0: + raise ValueError("The game segment length is too small for num_unroll_steps.") + + # Calculate the number of samples per segment + batch_size_per_segment = batch_size // samples_per_segment + + # If the batch size cannot be divided, process the remainder part + extra_samples = batch_size % samples_per_segment + + # We use the reanalyze_time in the game_segment_buffer to generate weights + reanalyze_times = np.array([segment.reanalyze_time for segment in self.game_segment_buffer[:valid_sample_num]]) + + # Calculate weights: the larger the reanalyze_time, the smaller the weight (use exp(-reanalyze_time)) + base_decay_rate = 100 + # Add a small epsilon to avoid division by zero if valid_sample_num is 0 + decay_rate = base_decay_rate / (valid_sample_num + 1e-6) + weights = np.exp(-decay_rate * reanalyze_times) + + # Normalize the weights to a probability distribution, handle case where sum is zero + sum_weights = np.sum(weights) + if sum_weights > 0: + probabilities = weights / sum_weights + else: + # If all weights are zero, use a uniform distribution + probabilities = np.ones(valid_sample_num) / valid_sample_num + + # Sample game segments according to the probabilities + # Ensure valid_sample_num is not zero before sampling + if valid_sample_num == 0: + return ([], [], [], [], []) + + selected_game_segments = np.random.choice(valid_sample_num, batch_size_per_segment, replace=False, + p=probabilities) + + # If there are extra samples to be allocated, randomly select some game segments and sample again + if extra_samples > 0: + # We need to handle the case where we might sample the same segment again. + # A simple way is to allow replacement for extra samples or sample from remaining ones. + # For simplicity, let's stick to the original logic but ensure it's safe. + remaining_segments = np.setdiff1d(np.arange(valid_sample_num), selected_game_segments) + if len(remaining_segments) < extra_samples: + # If not enough unique segments left, sample with replacement from all valid segments + extra_game_segments = np.random.choice(valid_sample_num, extra_samples, replace=True, p=probabilities) + else: + # Sample from the remaining unique segments + remaining_probs = probabilities[remaining_segments] + remaining_probs /= np.sum(remaining_probs) + extra_game_segments = np.random.choice(remaining_segments, extra_samples, replace=False, p=remaining_probs) + + selected_game_segments = np.concatenate((selected_game_segments, extra_game_segments)) + + game_segment_list = [] + pos_in_game_segment_list = [] + batch_index_list = [] + + for game_segment_idx in selected_game_segments: + # ========================================================================= + # FIX: The line below is the source of the error and has been removed. + # `game_segment_idx` is already a valid physical index for `game_segment_buffer`. + # game_segment_idx -= self.base_idx + # ========================================================================= + game_segment = self.game_segment_buffer[game_segment_idx] + + # Update reanalyze_time only once + game_segment.reanalyze_time += 1 + + # The sampling position should be 0, 0 + num_unroll_steps, ... 
(integer multiples of num_unroll_steps) + for i in range(samples_per_segment): + game_segment_list.append(game_segment) + pos_in_game_segment = i * self._cfg.num_unroll_steps + if pos_in_game_segment >= len(game_segment): + pos_in_game_segment = np.random.choice(len(game_segment), 1).item() + pos_in_game_segment_list.append(pos_in_game_segment) + # NOTE: We should append the physical index here, as it corresponds to the sampled segment. + batch_index_list.append(game_segment_idx) + + # Set the make_time for each sample (set to 0 for now, but can be the actual time if needed). + make_time = [0. for _ in range(len(batch_index_list))] + + orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, [], make_time) + return orig_data + + def _sample_orig_reanalyze_data(self, batch_size: int) -> Tuple: + """ + Overview: + sample orig_data that contains: + game_segment_list: a list of game segments + pos_in_game_segment_list: transition index in game (relative index) + batch_index_list: the index of start transition of sampled minibatch in replay buffer + weights_list: the weight concerning the priority + make_time: the time the batch is made (for correctly updating replay buffer when data is deleted) + Arguments: + - batch_size (:obj:`int`): batch size + - beta: float the parameter in PER for calculating the priority + """ + segment_length = (self.get_num_of_transitions()//2000) + assert self._beta > 0 + num_of_transitions = self.get_num_of_transitions() + sample_points = num_of_transitions // segment_length + + batch_index_list = np.random.choice(2000, batch_size, replace=False) + + if self._cfg.reanalyze_outdated is True: + # NOTE: used in reanalyze part + batch_index_list.sort() + + # TODO(xcy): use weighted sample + game_segment_list = [] + pos_in_game_segment_list = [] + + for idx in batch_index_list: + game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx*segment_length] + game_segment_idx -= self.base_idx + game_segment = self.game_segment_buffer[game_segment_idx] + + game_segment_list.append(game_segment) + pos_in_game_segment_list.append(pos_in_game_segment) + + make_time = [time.time() for _ in range(len(batch_index_list))] + + orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, [], make_time) + return orig_data + + def _sample_orig_data_episode(self, batch_size: int) -> Tuple: + """ + Overview: + Sample original data for a training batch, which includes: + - game_segment_list: A list of game segments. + - pos_in_game_segment_list: Indices of transitions within the game segments. + - batch_index_list: Indices of the start transitions of the sampled mini-batch in the replay buffer. + - weights_list: Weights for each sampled transition, used for prioritization. + - make_time: Timestamps indicating when the batch was created (useful for managing replay buffer updates). + Arguments: + - batch_size (:obj:`int`): The number of samples to draw for the batch. + - beta (:obj:`float`): Parameter for Prioritized Experience Replay (PER) that adjusts the importance of samples. + """ + assert self._beta > 0, "Beta must be greater than zero." 
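
# Standalone sketch of the prioritized sampling used by the samplers above: probabilities
# follow priority ** alpha and the importance-sampling correction is (N * p) ** (-beta),
# normalized by its maximum. The priorities, alpha and beta below are illustrative.
import numpy as np

priorities = np.array([0.5, 1.0, 2.0, 4.0])
alpha, beta = 0.6, 0.4

probs = priorities ** alpha + 1e-6
probs /= probs.sum()

N = len(priorities)
batch_index = np.random.choice(N, 2, p=probs, replace=False)
weights = (N * probs[batch_index]) ** (-beta)
weights /= weights.max()
print(batch_index, weights)
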
+ + num_of_transitions = self.get_num_of_transitions() + + if not self._cfg.use_priority: + self.game_pos_priorities = np.ones_like(self.game_pos_priorities) + + # Add a small constant for numerical stability + probs = self.game_pos_priorities ** self._alpha + 1e-6 + probs /= probs.sum() + + # Sample game segment indices + num_of_game_segments = self.get_num_of_game_segments() + batch_episode_index_list = np.random.choice(num_of_game_segments, batch_size, replace=False) + + if self._cfg.reanalyze_outdated: + # Sort for consistency when reanalyzing + batch_episode_index_list.sort() + + batch_index_list = batch_episode_index_list * self._cfg.game_segment_length + + # Calculate weights for the sampled transitions + weights_list = (num_of_transitions * probs[batch_index_list]) ** (-self._beta) + weights_list /= weights_list.max() + + game_segment_list = [] + pos_in_game_segment_list = [] + + # Collect game segments and their initial positions + for episode_index in batch_episode_index_list: + game_segment = self.game_segment_buffer[episode_index] + game_segment_list.append(game_segment) + pos_in_game_segment_list.append(0) # Starting position in game segments + + # Record the time when the batch is created + make_time = [time.time() for _ in range(len(batch_episode_index_list))] + + orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) + return orig_data + + def _preprocess_to_play_and_action_mask( + self, game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list, unroll_steps = None + ): + """ + Overview: + prepare the to_play and action_mask for the target obs in ``value_obs_list`` + - to_play: {list: game_segment_batch_size * (num_unroll_steps+1)} + - action_mask: {list: game_segment_batch_size * (num_unroll_steps+1)} + """ + unroll_steps = unroll_steps if unroll_steps is not None else self._cfg.num_unroll_steps + + to_play = [] + for bs in range(game_segment_batch_size): + to_play_tmp = list( + to_play_segment[bs][pos_in_game_segment_list[bs]:pos_in_game_segment_list[bs] + + unroll_steps + 1] + ) + if len(to_play_tmp) < unroll_steps + 1: + # NOTE: the effective to play index is {1,2}, for null padding data, we set to_play=-1 + to_play_tmp += [-1 for _ in range(unroll_steps + 1 - len(to_play_tmp))] + to_play.append(to_play_tmp) + to_play = sum(to_play, []) + + if self._cfg.model.continuous_action_space is True: + # when the action space of the environment is continuous, action_mask[:] is None. + return to_play, None + + action_mask = [] + for bs in range(game_segment_batch_size): + action_mask_tmp = list( + action_mask_segment[bs][pos_in_game_segment_list[bs]:pos_in_game_segment_list[bs] + + unroll_steps + 1] + ) + if len(action_mask_tmp) < unroll_steps + 1: + action_mask_tmp += [ + list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) + for _ in range(unroll_steps + 1 - len(action_mask_tmp)) + ] + action_mask.append(action_mask_tmp) + action_mask = to_list(action_mask) + action_mask = sum(action_mask, []) + + return to_play, action_mask + + @abstractmethod + def _prepare_reward_value_context( + self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[Any], + total_transitions: int + ) -> List[Any]: + """ + Overview: + prepare the context of rewards and values for calculating TD value target in reanalyzing part. 
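
# Sketch of the null-padding convention used by _preprocess_to_play_and_action_mask above:
# when a slice runs past the end of a segment, to_play is padded with -1 and action_mask
# with all-ones vectors. Shapes and values are illustrative.
import numpy as np

num_unroll_steps, action_space_size = 5, 4
to_play_slice = [1, 2, 1]                                   # only 3 real steps available
action_mask_slice = [list(np.ones(action_space_size, dtype=np.int8)) for _ in range(3)]

pad = num_unroll_steps + 1 - len(to_play_slice)
to_play_slice += [-1] * pad
action_mask_slice += [list(np.ones(action_space_size, dtype=np.int8)) for _ in range(pad)]

assert len(to_play_slice) == len(action_mask_slice) == num_unroll_steps + 1
print(to_play_slice)   # [1, 2, 1, -1, -1, -1]
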
+ Arguments: + - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer + - game_segment_list (:obj:`list`): list of game segments + - pos_in_game_segment_list (:obj:`list`): list of transition index in game_segment + - total_transitions (:obj:`int`): number of collected transitions + Returns: + - reward_value_context (:obj:`list`): value_obs_lst, value_mask, state_index_lst, rewards_lst, game_segment_lens, + td_steps_lst, action_mask_segment, to_play_segment + """ + pass + + @abstractmethod + def _prepare_policy_non_reanalyzed_context( + self, batch_index_list: List[int], game_segment_list: List[Any], pos_in_game_segment_list: List[int] + ) -> List[Any]: + """ + Overview: + prepare the context of policies for calculating policy target in non-reanalyzing part, just return the policy in self-play + Arguments: + - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer + - game_segment_list (:obj:`list`): list of game segments + - pos_in_game_segment_list (:obj:`list`): list transition index in game + Returns: + - policy_non_re_context (:obj:`list`): state_index_lst, child_visits, game_segment_lens, action_mask_segment, to_play_segment + """ + pass + + @abstractmethod + def _prepare_policy_reanalyzed_context( + self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str] + ) -> List[Any]: + """ + Overview: + prepare the context of policies for calculating policy target in reanalyzing part. + Arguments: + - batch_index_list (:obj:'list'): start transition index in the replay buffer + - game_segment_list (:obj:'list'): list of game segments + - pos_in_game_segment_list (:obj:'list'): position of transition index in one game history + Returns: + - policy_re_context (:obj:`list`): policy_obs_lst, policy_mask, state_index_lst, indices, + child_visits, game_segment_lens, action_mask_segment, to_play_segment + """ + pass + + @abstractmethod + def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> List[np.ndarray]: + """ + Overview: + prepare reward and value targets from the context of rewards and values. + Arguments: + - reward_value_context (:obj:'list'): the reward value context + - model (:obj:'torch.tensor'):model of the target model + Returns: + - batch_value_prefixs (:obj:'np.ndarray): batch of value prefix + - batch_target_values (:obj:'np.ndarray): batch of value estimation + """ + pass + + @abstractmethod + def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray: + """ + Overview: + prepare policy targets from the reanalyzed context of policies + Arguments: + - policy_re_context (:obj:`List`): List of policy context to reanalyzed + Returns: + - batch_target_policies_re + """ + pass + + @abstractmethod + def _compute_target_policy_non_reanalyzed( + self, policy_non_re_context: List[Any], policy_shape: Optional[int] + ) -> np.ndarray: + """ + Overview: + prepare policy targets from the non-reanalyzed context of policies + Arguments: + - policy_non_re_context (:obj:`List`): List containing: + - pos_in_game_segment_list + - child_visits + - game_segment_lens + - action_mask_segment + - to_play_segment + Returns: + - batch_target_policies_non_re + """ + pass + + @abstractmethod + def update_priority( + self, train_data: Optional[List[Optional[np.ndarray]]], batch_priorities: Optional[Any] + ) -> None: + """ + Overview: + Update the priority of training data. 
+ Arguments: + - train_data (:obj:`Optional[List[Optional[np.ndarray]]]`): training data to be updated priority. + - batch_priorities (:obj:`batch_priorities`): priorities to update to. + """ + pass + + def push_game_segments(self, data_and_meta: Any) -> None: + """ + Overview: + Push game_segments data and it's meta information into buffer. + Save a game segment + Arguments: + - data_and_meta + - data (:obj:`Any`): The data (game segments) which will be pushed into buffer. + - meta (:obj:`dict`): Meta information, e.g. priority, count, staleness. + """ + data, meta = data_and_meta + for (data_game, meta_game) in zip(data, meta): + self._push_game_segment(data_game, meta_game) + + def _push_game_segment(self, data: Any, meta: Optional[dict] = None) -> None: + """ + Overview: + Push data and it's meta information in buffer. + Save a game segment. + Arguments: + - data (:obj:`Any`): The data (a game segment) which will be pushed into buffer. + - meta (:obj:`dict`): Meta information, e.g. priority, count, staleness. + - done (:obj:`bool`): whether the game is finished. + - unroll_plus_td_steps (:obj:`int`): if the game is not finished, we only save the transitions that can be computed + - priorities (:obj:`list`): the priorities corresponding to the transitions in the game history + Returns: + - buffered_data (:obj:`BufferedData`): The pushed data. + """ + try: + data_length = len(data.action_segment) if len(data.action_segment) < self._cfg.game_segment_length else self._cfg.game_segment_length + except Exception as e: + # to be compatible with unittest + print(e) + data_length = len(data) + + if meta['done']: + self.num_of_collected_episodes += 1 + valid_len = data_length + else: + valid_len = data_length - meta['unroll_plus_td_steps'] + # print(f'valid_len is {valid_len}') + + if meta['priorities'] is None: + if self.game_segment_buffer: + max_prio = self.game_pos_priorities.max() if len(self.game_pos_priorities) > 0 else 1 + else: + max_prio = 1 + + # if no 'priorities' provided, set the valid part of the new-added game history the max_prio + self.game_pos_priorities = np.concatenate((self.game_pos_priorities, [max_prio for _ in range(valid_len)] + [0. for _ in range(valid_len, data_length)])) + else: + assert data_length == len(meta['priorities']), " priorities should be of same length as the game steps" + priorities = meta['priorities'].copy().reshape(-1) + priorities[valid_len:data_length] = 0. + self.game_pos_priorities = np.concatenate((self.game_pos_priorities, priorities)) + + self.game_segment_buffer.append(data) + self.game_segment_game_pos_look_up += [ + (self.base_idx + len(self.game_segment_buffer) - 1, step_pos) for step_pos in range(data_length) + ] + # print(f'potioritys is {self.game_pos_priorities}') + # print(f'num of transitions is {len(self.game_segment_game_pos_look_up)}') + + def remove_oldest_data_to_fit(self) -> None: + """ + Overview: + remove some oldest data if the replay buffer is full. 
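
# Sketch of how _push_game_segment above initializes priorities when none are provided:
# the valid prefix of the new segment inherits the current maximum priority and the
# not-yet-computable tail gets 0. Numbers are illustrative.
import numpy as np

game_pos_priorities = np.array([0.3, 1.5, 0.7])   # priorities of transitions already stored
data_length, valid_len = 6, 4                     # unfinished segment: last 2 steps unusable

max_prio = game_pos_priorities.max() if len(game_pos_priorities) > 0 else 1
new_prios = [max_prio] * valid_len + [0.0] * (data_length - valid_len)
game_pos_priorities = np.concatenate((game_pos_priorities, new_prios))
print(game_pos_priorities)   # [0.3 1.5 0.7 1.5 1.5 1.5 1.5 0.  0. ]
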
+ """ + assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size" + nums_of_game_segments = self.get_num_of_game_segments() + total_transition = self.get_num_of_transitions() + if total_transition > self.replay_buffer_size: + index = 0 + for i in range(nums_of_game_segments): + length_data = len(self.game_segment_buffer[i].action_segment) if len(self.game_segment_buffer[i].action_segment)= self._cfg.batch_size: + self._remove(index + 1) + + def _remove(self, excess_game_segment_index: List[int]) -> None: + """ + Overview: + delete game segments in index [0: excess_game_segment_index] + Arguments: + - excess_game_segment_index (:obj:`List[str]`): Index of data. + """ + excess_game_positions = sum( + [len(game_segment) for game_segment in self.game_segment_buffer[:excess_game_segment_index]] + ) + del self.game_segment_buffer[:excess_game_segment_index] + self.game_pos_priorities = self.game_pos_priorities[excess_game_positions:] + del self.game_segment_game_pos_look_up[:excess_game_positions] + self.base_idx += excess_game_segment_index + print(f"self.base_idx: {self.base_idx} ") + self.clear_time = time.time() + + def get_num_of_episodes(self) -> int: + # number of collected episodes + return self.num_of_collected_episodes + + def get_num_of_game_segments(self) -> int: + # num of game segments + return len(self.game_segment_buffer) + + def get_num_of_transitions(self) -> int: + # total number of transitions + return len(self.game_segment_game_pos_look_up) + + def __repr__(self): + return f'current buffer statistics is: num_of_all_collected_episodes: {self.num_of_collected_episodes}, num of game segments: {len(self.game_segment_buffer)}, number of transitions: {len(self.game_segment_game_pos_look_up)}' diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index faf0155a0..f02f2332a 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -765,7 +765,60 @@ def _compute_target_policy_non_reanalyzed( policy_index += 1 batch_target_policies_non_re.append(target_policies) - batch_target_policies_non_re = np.asarray(batch_target_policies_non_re) + + # batch_target_policies_non_re = np.asarray(batch_target_policies_non_re) + + # =================== TODO =================== + # ============================================================================== + # 最终修复:应用“黄金模式”来处理异构的目标策略列表 + # ============================================================================== + # 1. 逐元素统一化 (Element-wise Normalization) + # `batch_target_policies_non_re` 是一个列表的列表,例如 [[p1, p2], [p3, p4], ...]。 + # 我们需要确保每个内部的策略 pN 都是一个 Tensor。 + normalized_policy_list = [] + for policy_sequence in batch_target_policies_non_re: + # `policy_sequence` 是一个包含11个策略的列表 + sequence_as_tensors = [] + for policy_step in policy_sequence: + if isinstance(policy_step, torch.Tensor): + sequence_as_tensors.append(policy_step) + elif isinstance(policy_step, np.ndarray): + sequence_as_tensors.append(torch.from_numpy(policy_step)) + else: + # 如果策略是其他格式(例如 list),也尝试转换 + try: + sequence_as_tensors.append(torch.from_numpy(np.array(policy_step))) + except (TypeError, ValueError) as e: + raise TypeError(f"Unsupported policy step type '{type(policy_step)}' in sequence.") from e + + # 2. 
Stack each normalized sequence into a single Tensor
+        #    e.g., stack 11 Tensors of shape (18,) into one (11, 18) Tensor
+        try:
+            stacked_sequence = torch.stack(sequence_as_tensors, dim=0)
+            normalized_policy_list.append(stacked_sequence)
+        except Exception as e:
+            # If stacking fails, the policies within this sequence have inconsistent shapes
+            print("FATAL: torch.stack failed for a policy sequence. This indicates shape mismatch within a sequence.")
+            from collections import Counter
+            shape_counts = Counter(p.shape for p in sequence_as_tensors)
+            print(f"Shape distribution in the problematic sequence: {shape_counts}")
+            raise e
+
+        # 3. Finally stack the per-sequence Tensors into one batch Tensor
+        #    e.g., stack 256 Tensors of shape (11, 18) into (256, 11, 18)
+        try:
+            final_stacked_tensor = torch.stack(normalized_policy_list, dim=0)
+        except Exception as e:
+            print("FATAL: Final torch.stack failed for the batch. This indicates sequence-level shape mismatch.")
+            from collections import Counter
+            shape_counts = Counter(p.shape for p in normalized_policy_list)
+            print(f"Shape distribution of stacked sequences: {shape_counts}")
+            raise e
+
+        # 4. Convert to a NumPy array
+        batch_target_policies_non_re = final_stacked_tensor.numpy()
+        # ==============================================================================
+
         return batch_target_policies_non_re
 
     def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None:
@@ -786,3 +839,4 @@ def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -
         if metas['make_time'][i] > self.clear_time:
             idx, prio = indices[i], metas['batch_priorities'][i]
             self.game_pos_priorities[idx] = prio
+
diff --git a/lzero/mcts/buffer/game_buffer_unizero.py b/lzero/mcts/buffer/game_buffer_unizero.py
index b8998acb9..b8fea157c 100644
--- a/lzero/mcts/buffer/game_buffer_unizero.py
+++ b/lzero/mcts/buffer/game_buffer_unizero.py
@@ -438,15 +438,15 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
             m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], start_pos=batch_timestep[:self.reanalyze_num])  # NOTE: :self.reanalyze_num
             # =======================================================================
 
-            if not model.training:
-                # if not in training, obtain the scalars of the value/reward
-                [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
-                    [
-                        m_output.latent_state,
-                        inverse_scalar_transform(m_output.value, self.value_support),
-                        m_output.policy_logits
-                    ]
-                )
+            # if not model.training:
+            # if not in training, obtain the scalars of the value/reward
+            [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+                [
+                    m_output.latent_state,
+                    inverse_scalar_transform(m_output.value, self.value_support),
+                    m_output.policy_logits
+                ]
+            )
 
             network_output.append(m_output)
 
@@ -630,3 +630,34 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
         batch_target_values = np.asarray(batch_target_values)
 
         return batch_rewards, batch_target_values
+
+    def update_priority(self, train_data: List[np.ndarray], batch_priorities: np.ndarray) -> None:
+        """
+        Overview:
+            Update the priority of training data.
+        Arguments:
+            - train_data (:obj:`List[np.ndarray]`): training data to be updated priority.
+            - batch_priorities (:obj:`np.ndarray`): priorities to update to.
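
# Standalone sketch of the stacking pattern applied above to the heterogeneous target-policy
# list: normalize every element to a torch.Tensor, stack each sequence, then stack the batch
# and convert once to numpy. The toy batch mixes lists, numpy arrays and tensors on purpose.
import numpy as np
import torch

batch = [
    [np.ones(4, dtype=np.float32), [0.25, 0.25, 0.25, 0.25]],     # ndarray + plain list
    [torch.zeros(4), np.full(4, 0.5, dtype=np.float32)],          # tensor + ndarray
]

def to_tensor(step):
    if isinstance(step, torch.Tensor):
        return step
    return torch.from_numpy(np.asarray(step, dtype=np.float32))

stacked = torch.stack(
    [torch.stack([to_tensor(step) for step in seq], dim=0) for seq in batch], dim=0
)
target_policies = stacked.numpy()
print(target_policies.shape)   # (2, 2, 4)
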
+ NOTE: + train_data = [current_batch, target_batch] + current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list] + """ + # TODO: NOTE: -4 is batch_index_list + indices = train_data[0][-4] + metas = {'make_time': train_data[0][-1], 'batch_priorities': batch_priorities} + # only update the priorities for data still in replay buffer + for i in range(len(indices)): + # ==================== START OF FINAL FIX ==================== + + # FIX 1: Handle ValueError by using the first timestamp of the segment for comparison. + first_transition_time = metas['make_time'][i][0] + + if first_transition_time > self.clear_time: + # FIX 2: Handle IndexError by converting the float index to an integer before use. + idx = int(indices[i]) + prio = metas['batch_priorities'][i] + + # Now, idx is a valid integer index. + self.game_pos_priorities[idx] = prio + + # ===================== END OF FINAL FIX ===================== \ No newline at end of file diff --git a/lzero/mcts/buffer/game_segment.py b/lzero/mcts/buffer/game_segment.py index ad216d196..2aca4820a 100644 --- a/lzero/mcts/buffer/game_segment.py +++ b/lzero/mcts/buffer/game_segment.py @@ -5,7 +5,7 @@ from easydict import EasyDict from ding.utils.compression_helper import jpeg_data_decompressor - +import torch class GameSegment: """ @@ -83,6 +83,8 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea self.reanalyze_time = 0 + # lzero/mcts/game_segment.py + def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray: """ Overview: @@ -93,14 +95,64 @@ def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool - padding (bool): If True, pad frames if (t + stack frames) is outside of the trajectory. """ stacked_obs = self.obs_segment[timestep:timestep + self.frame_stack_num + num_unroll_steps] + + # 检查 stacked_obs 是否为 PyTorch 张量,以决定使用 torch 还是 numpy 操作 + is_tensor = isinstance(stacked_obs, torch.Tensor) + if padding: pad_len = self.frame_stack_num + num_unroll_steps - len(stacked_obs) if pad_len > 0: - pad_frames = np.array([stacked_obs[-1] for _ in range(pad_len)]) - stacked_obs = np.concatenate((stacked_obs, pad_frames)) + last_frame = stacked_obs[-1] + if is_tensor: + # 使用 PyTorch 的方式进行 padding + # 1. 在第0维增加一个维度 (e.g., [C, H, W] -> [1, C, H, W]) + # 2. 使用 repeat 复制 pad_len 次 + pad_frames = last_frame.unsqueeze(0).repeat(pad_len, 1, 1, 1) + # 3. 
使用 torch.cat 进行拼接 + stacked_obs = torch.cat((stacked_obs, pad_frames), dim=0) + else: + # 保持原有的 NumPy 逻辑 + pad_frames = np.array([last_frame for _ in range(pad_len)]) + stacked_obs = np.concatenate((stacked_obs, pad_frames)) + if self.transform2string: - stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs] - return stacked_obs + # 如果是张量,需要先转为numpy + if is_tensor: + # .cpu() 是为了确保数据在CPU上,.numpy() 需要CPU数据 + stacked_obs_np = stacked_obs.cpu().numpy() + else: + stacked_obs_np = stacked_obs + stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs_np] + # 注意:经过 decompressor 后,stacked_obs 变成了 list of numpy arrays + # 为了与函数签名 -> np.ndarray 保持一致,最好再堆叠起来 + return np.stack(stacked_obs, axis=0) + + # 如果函数签名要求返回 np.ndarray,在最后进行转换 + # 这样可以确保内部操作高效,同时不破坏外部接口的约定 + if is_tensor: + return stacked_obs.cpu().numpy() + else: + # 如果原本就是numpy,直接返回 + return stacked_obs + + # def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray: + # """ + # Overview: + # Get an observation of the correct format: o[t, t + stack frames + num_unroll_steps]. + # Arguments: + # - timestep (int): The time step. + # - num_unroll_steps (int): The extra length of the observation frames. + # - padding (bool): If True, pad frames if (t + stack frames) is outside of the trajectory. + # """ + # stacked_obs = self.obs_segment[timestep:timestep + self.frame_stack_num + num_unroll_steps] + # if padding: + # pad_len = self.frame_stack_num + num_unroll_steps - len(stacked_obs) + # if pad_len > 0: + # pad_frames = np.array([stacked_obs[-1] for _ in range(pad_len)]) + # stacked_obs = np.concatenate((stacked_obs, pad_frames)) + # if self.transform2string: + # stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs] + # return stacked_obs def zero_obs(self) -> List: """ diff --git a/lzero/mcts/utils.py b/lzero/mcts/utils.py index 407f5e2ba..7d16dcf81 100644 --- a/lzero/mcts/utils.py +++ b/lzero/mcts/utils.py @@ -4,7 +4,7 @@ import numpy as np from graphviz import Digraph - +import torch def generate_random_actions_discrete(num_actions: int, action_space_size: int, num_of_sampled_actions: int, reshape=False): @@ -98,7 +98,61 @@ def prepare_observation(observation_list, model_type='conv'): - np.ndarray: Reshaped array of observations. """ assert model_type in ['conv', 'mlp', 'conv_context', 'mlp_context'], "model_type must be either 'conv' or 'mlp'" - observation_array = np.array(observation_list) + + # import ipdb;ipdb.set_trace() + # observation_array = np.array(observation_list) + + + # ============================================================================== + # 阶段 1: 逐元素统一化 (Element-wise Normalization) + # 目标:遍历异构列表,根据每个元素自身的类型将其转换为 Tensor。 + # 这是修复此问题的核心逻辑。 + # ============================================================================== + + normalized_list = [] + for i, obs in enumerate(observation_list): + if isinstance(obs, torch.Tensor): + # 元素已经是 Tensor,直接添加 + normalized_list.append(obs) + elif isinstance(obs, list): + # 元素是 list (通常是 list of np.ndarray from collector) + # 使用 np.array 将其转换为单个 ndarray,然后转为 Tensor + try: + normalized_list.append(torch.from_numpy(np.array(obs))) + except ValueError as e: + # 添加更详细的错误信息,以防 list 内部也不均匀 + print(f"FATAL: Failed to convert a list element at index {i} to a numpy array. 
The list might be inhomogeneous itself.") + print(f"Content of the problematic list: {obs}") + raise e + elif isinstance(obs, np.ndarray): + # 元素是 ndarray,直接转换为 Tensor + normalized_list.append(torch.from_numpy(obs)) + else: + # 捕获任何未预料到的格式 + raise TypeError(f"Unsupported data type '{type(obs)}' found in observation_list at index {i}") + + # ============================================================================== + # 阶段 2: 堆叠与转换 (Stack & Convert) + # 此时,`normalized_list` 保证是一个纯粹的 "list of torch.Tensors"。 + # ============================================================================== + try: + # 使用 PyTorch 自家的、最可靠的 stack 函数 + stacked_tensor = torch.stack(normalized_list, dim=0) + except Exception as e: + # 如果这里仍然失败,几乎可以肯定是形状不一致问题 + print("FATAL: torch.stack failed after element-wise normalization. This indicates a definite shape mismatch.") + from collections import Counter + # 我们现在可以安全地检查形状了 + shape_counts = Counter(t.shape for t in normalized_list) + print(f"Shape distribution in normalized list: {shape_counts}") + raise e + + # 将最终的单一 Tensor 转换为 NumPy 数组 + + # 将最终的单一 Tensor 转换为 NumPy 数组。 + observation_array = stacked_tensor.numpy() + + batch_size = observation_array.shape[0] if model_type in ['conv', 'conv_context']: diff --git a/lzero/model/common.py b/lzero/model/common.py index 795eb72a3..f5b572240 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -22,6 +22,35 @@ from ding.utils import set_pkg_seed, get_rank, get_world_size import torch +import torch +import torch.nn as nn + +torch.hub._validate_not_a_forked_repo=lambda a,b,c: True + + +# 1. 将 L2-Norm 封装成一个 nn.Module 类 +class L2Norm(nn.Module): + """ + 对输入的最后一个维度进行 L2 归一化。 + """ + def __init__(self, eps=1e-6): + """ + 初始化 L2Norm 模块。 + + :param eps: 一个小的浮点数,用于防止在归一化时除以零。 + """ + super().__init__() + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + 前向传播函数。 + + :param x: 输入张量。 + :return: 经过 L2 归一化后的张量。 + """ + return F.normalize(x, p=2, dim=-1, eps=self.eps) + def MLP_V2( in_channels: int, hidden_channels: List[int], @@ -457,6 +486,113 @@ def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor: return cls_embedding +# ============================================================================= +# 新的、可无缝替换的 DINOv2 表征网络 +# ============================================================================= +class DinoV2RepresentationNetwork(nn.Module): + + def __init__( + self, + observation_shape: SequenceType = (3, 64, 64), + embedding_dim: int = 256, + final_norm_option_in_encoder: str = 'LayerNorm', + group_size: int = 8, + # 下面的参数是为了接口兼容性,DINOv2 模型本身不会使用它们 + num_res_blocks: int = 1, + num_channels: int = 64, + downsample: bool = True, + activation: nn.Module = nn.GELU(approximate='tanh'), + norm_type: str = 'BN', + dinov2_model_name: str = "dinov2_vits14", + dinov2_feature_key: str = "x_norm_clstoken", + ) -> None: + """ + Overview: + 一个使用 DINOv2 作为骨干的表征网络,旨在无缝替换 RepresentationNetworkUniZero。 + + Arguments: + - observation_shape: 输入观测的形状。 + - embedding_dim: 最终输出的潜在状态维度。 + - final_norm_option_in_encoder: 编码器最后一层的归一化选项。 + - group_size: SimNorm 使用的组大小。 + - dinov2_model_name: 要加载的 DINOv2 模型名称。 + - dinov2_feature_key: 从 DINOv2 中提取哪种特征。 + - 其他参数: 为了与 UniZeroModel 的调用签名保持一致。 + """ + super().__init__() + self.observation_shape = observation_shape + self.embedding_dim = embedding_dim + + # 1. 
加载 DINOv2 基础模型 + print(f"Loading DINOv2 model: {dinov2_model_name}") + self.pretrained_model = torch.hub.load("facebookresearch/dinov2", dinov2_model_name) + self.feature_key = dinov2_feature_key + dinov2_output_dim = self.pretrained_model.num_features + + # DINOv2 模型期望的输入尺寸 (patch_size=14, 官方推荐 518x518) + # self.dinov2_input_size = (518, 518) + # self.dinov2_input_size = (224, 224) + # self.dinov2_input_size = (64, 64) + self.dinov2_input_size = (70, 70) + + # 2. 添加线性投影层以匹配所需的 embedding_dim + # 如果 DINOv2 输出维度和期望的 embedding_dim 不一致,则需要投影 + if dinov2_output_dim != self.embedding_dim: + self.projection = nn.Linear(dinov2_output_dim, self.embedding_dim, bias=False) + print(f"Added projection layer: {dinov2_output_dim} -> {self.embedding_dim}") + else: + self.projection = nn.Identity() + + # 3. 复制 RepresentationNetworkUniZero 中的 final_norm 逻辑 + self.final_norm_option_in_encoder = final_norm_option_in_encoder + if self.final_norm_option_in_encoder in ['LayerNorm', 'LayerNorm_Tanh']: + self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5) + elif self.final_norm_option_in_encoder == 'LayerNormNoAffine': + self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5, elementwise_affine=False) + elif self.final_norm_option_in_encoder == 'SimNorm': + self.final_norm = SimNorm(simnorm_dim=group_size) + elif self.final_norm_option_in_encoder == 'L2Norm': + self.final_norm = L2Norm(eps=1e-6) + elif self.final_norm_option_in_encoder is None: + self.final_norm = nn.Identity() + else: + raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}") + + print(f"Using final normalization: {self.final_norm_option_in_encoder}") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - x: (B, C_in, H, W) + - output: (B, embedding_dim) + """ + # 4. 在 forward 中动态调整输入尺寸 + # 如果输入尺寸与 DINOv2 期望的不符,进行 resize + if x.shape[-2:] != self.dinov2_input_size: + x = F.interpolate(x, size=self.dinov2_input_size, mode='bicubic', align_corners=False) + + # 提取 DINOv2 特征 + # 使用 no_grad 可以冻结 DINOv2 的权重,只训练后续的层 + # 如果你想微调 DINOv2,请移除 with torch.no_grad() + with torch.no_grad(): + emb = self.pretrained_model.forward_features(x)[self.feature_key] + + # 应用投影层 + x = self.projection(emb) + + # 确保输出形状为 (B, embedding_dim) + if x.dim() != 2: + x = x.view(-1, self.embedding_dim) + + # 5. 应用最终的归一化和激活函数,与原网络保持一致 + if self.final_norm is not None: + x = self.final_norm(x) + + if self.final_norm_option_in_encoder == 'LayerNorm_Tanh': + x = torch.tanh(x) + + return x + class RepresentationNetworkUniZero(nn.Module): def __init__( @@ -529,20 +665,33 @@ def __init__( self.embedding_dim = embedding_dim if self.observation_shape[1] == 64: - self.last_linear = nn.Linear(64 * 8 * 8, self.embedding_dim, bias=False) + self.last_linear = nn.Linear(num_channels * 8 * 8, self.embedding_dim, bias=False) elif self.observation_shape[1] in [84, 96]: - self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False) + self.last_linear = nn.Linear(num_channels * 6 * 6, self.embedding_dim, bias=False) - self.final_norm_option_in_encoder = final_norm_option_in_encoder - if self.final_norm_option_in_encoder == 'LayerNorm': + self.final_norm_option_in_encoder=final_norm_option_in_encoder + # 2. 
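
# Reduced, standalone sketch of the final-norm selection pattern used by the encoders above.
# SimNorm is omitted because it is LightZero's own module; the remaining options are plain
# PyTorch. The factory function and option strings mirror the diff but are only illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F

class L2Norm(nn.Module):
    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.normalize(x, p=2, dim=-1, eps=self.eps)

def build_final_norm(option, embedding_dim: int) -> nn.Module:
    if option in ('LayerNorm', 'LayerNorm_Tanh'):
        return nn.LayerNorm(embedding_dim, eps=1e-5)
    if option == 'LayerNormNoAffine':
        return nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False)
    if option == 'L2Norm':
        return L2Norm()
    if option is None:
        return nn.Identity()
    raise ValueError(f"Unsupported final_norm_option_in_encoder: {option}")

x = torch.randn(8, 256)
print(build_final_norm('L2Norm', 256)(x).norm(dim=-1))   # every row has unit L2 norm
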
在 __init__ 中统一初始化 final_norm + if self.final_norm_option_in_encoder in ['LayerNorm', 'LayerNorm_Tanh']: self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5) + elif self.final_norm_option_in_encoder == 'LayerNormNoAffine': + self.final_norm = nn.LayerNorm( + self.embedding_dim, eps=1e-5, elementwise_affine=False + ) elif self.final_norm_option_in_encoder == 'SimNorm': + # 确保 SimNorm 已被定义 self.final_norm = SimNorm(simnorm_dim=group_size) + elif self.final_norm_option_in_encoder == 'L2Norm': + # 直接实例化我们自定义的 L2Norm 模块 + self.final_norm = L2Norm(eps=1e-6) + elif self.final_norm_option_in_encoder is None: + # 如果不需要归一化,可以设置为 nn.Identity() 或 None + self.final_norm = nn.Identity() else: raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}") + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Shapes: @@ -568,7 +717,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.view(-1, self.embedding_dim) # NOTE: very important for training stability. - x = self.final_norm(x) + # x = self.final_norm(x) + + # 3. 在 forward 中统一调用 self.final_norm + # 这种结构更加清晰和可扩展 + if self.final_norm is not None: + x = self.final_norm(x) + + # 针对 LayerNorm_Tanh 的特殊处理 + if self.final_norm_option_in_encoder == 'LayerNorm_Tanh': + x = torch.tanh(x) return x diff --git a/lzero/model/entropy_explorer.py b/lzero/model/entropy_explorer.py new file mode 100644 index 000000000..a93efb890 --- /dev/null +++ b/lzero/model/entropy_explorer.py @@ -0,0 +1,186 @@ +import torch +import torch.nn.functional as F +from torch.distributions import Categorical +import torch.optim as optim +import matplotlib.pyplot as plt +import numpy as np +import os + +class EntropyExplorer: + """ + 一个用于探查、计算和生成具有特定熵的离散概率分布的工具。 + 现在支持将可视化结果保存为PNG文件,并解决了中文字体显示问题。 + """ + def __init__(self, action_space_size: int): + """ + 初始化探查器。 + + 参数: + - action_space_size (int): 动作空间的维度 (例如,对于 torch.Size([9]),此值为 9)。 + """ + if not isinstance(action_space_size, int) or action_space_size <= 1: + raise ValueError("action_space_size 必须是大于1的整数。") + + self.action_space_size = action_space_size + self.min_entropy = 0.0 + self.max_entropy = torch.log(torch.tensor(self.action_space_size, dtype=torch.float32)).item() + + print(f"初始化 EntropyExplorer,动作空间大小为: {self.action_space_size}") + print(f"该空间的最小熵为: {self.min_entropy:.4f}") + print(f"该空间的最大熵 (均匀分布) 为: {self.max_entropy:.4f}\n") + + def calculate_entropy(self, policy_tensor: torch.Tensor, is_logits: bool = True) -> float: + """ + 计算给定策略张量的熵。 + """ + if policy_tensor.dim() != 1 or policy_tensor.shape[0] != self.action_space_size: + raise ValueError(f"输入张量的形状必须是 torch.Size([{self.action_space_size}]), 但得到的是 {policy_tensor.shape}") + + if is_logits: + distribution = Categorical(logits=policy_tensor) + else: + if not torch.allclose(policy_tensor.sum(), torch.tensor(1.0), atol=1e-6): + print(f"警告: 输入的概率总和不为1 (总和为: {policy_tensor.sum().item()}),结果可能不准确。") + distribution = Categorical(probs=policy_tensor) + + return distribution.entropy().item() + + def find_distribution_for_entropy(self, target_entropy: float, learning_rate: float = 0.01, num_steps: int = 1500, verbose: bool = False) -> np.ndarray: + """ + 通过优化找到一个具有特定目标熵的概率分布。 + """ + if not (self.min_entropy <= target_entropy <= self.max_entropy): + raise ValueError(f"目标熵 {target_entropy:.4f} 超出有效范围 [{self.min_entropy:.4f}, {self.max_entropy:.4f}]。") + + logits = torch.randn(self.action_space_size, requires_grad=True) + optimizer = optim.Adam([logits], lr=learning_rate) + + print(f"\n正在寻找熵为 {target_entropy:.4f} 的分布...") + 
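        # Editorial note (not part of the original diff): the loop below performs plain
        # gradient descent on the unconstrained logits, minimising the squared error
        #     loss = (H(softmax(logits)) - target_entropy) ** 2
        # where H(p) = -sum_i p_i * log p_i is the Shannon entropy in nats, so H ranges
        # from 0 (a one-hot distribution) to log(self.action_space_size) (uniform).
        # The target-entropy level set contains many distributions, so different random
        # initialisations of the logits converge to different, equally valid, solutions.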
for step in range(num_steps): + optimizer.zero_grad() + current_entropy = Categorical(logits=logits).entropy() + loss = (current_entropy - target_entropy).pow(2) + loss.backward() + optimizer.step() + + if verbose and (step % 200 == 0 or step == num_steps - 1): + print(f"步骤 {step+1}/{num_steps} | 当前熵: {current_entropy.item():.4f} | 损失: {loss.item():.6f}") + + final_probs = F.softmax(logits, dim=-1) + final_entropy = self.calculate_entropy(final_probs, is_logits=False) + print(f"优化完成。最终分布的熵为: {final_entropy:.4f}") + + return final_probs.detach().numpy() + + def _set_chinese_font(self): + """ + 尝试设置一个支持中文的字体。 + """ + # 常见的支持中文的字体列表 + font_names = ['SimHei', 'Microsoft YaHei', 'Heiti TC', 'Arial Unicode MS', 'sans-serif'] + try: + plt.rcParams['font.sans-serif'] = font_names + plt.rcParams['axes.unicode_minus'] = False # 解决负号显示为方块的问题 + print("中文字体已设置为 'SimHei' 或其他可用字体。") + except Exception as e: + print(f"设置中文字体失败: {e}。图表中的中文可能无法正常显示。") + + + def visualize_distribution(self, probs: np.ndarray, title: str, save_dir: str = 'entropy_plots', filename: str = None): + """ + 使用条形图可视化概率分布,并将其保存为PNG文件。 + + 参数: + - probs (np.ndarray): 要可视化的概率分布。 + - title (str): 图表标题。 + - save_dir (str): 保存图片的相对路径目录。 + - filename (str): 图片的文件名。如果为None,将不会保存。 + """ + self._set_chinese_font() # 设置中文字体 + + plt.style.use('seaborn-v0_8-whitegrid') + fig, ax = plt.subplots(figsize=(12, 6)) + + actions = np.arange(self.action_space_size) + bars = ax.bar(actions, probs, color='deepskyblue', edgecolor='black', alpha=0.8) + + ax.set_title(title, fontsize=16, weight='bold') + ax.set_xlabel("动作 (Action)", fontsize=12) + ax.set_ylabel("概率 (Probability)", fontsize=12) + ax.set_xticks(actions) + ax.set_xticklabels([f'Action {i}' for i in actions]) + ax.set_ylim(0, max(np.max(probs) * 1.2, 0.2)) # 动态调整Y轴上限 + + for bar in bars: + yval = bar.get_height() + ax.text(bar.get_x() + bar.get_width()/2.0, yval + 0.005, f'{yval:.3f}', ha='center', va='bottom', fontsize=9) + + plt.tight_layout() + + if filename: + # 确保保存目录存在 + if not os.path.exists(save_dir): + os.makedirs(save_dir) + print(f"已创建目录: '{save_dir}/'") + + full_path = os.path.join(save_dir, filename) + + # 保存图片 + plt.savefig(full_path, dpi=300, bbox_inches='tight') + print(f"图表已成功保存到: '{full_path}'") + else: + plt.show() # 如果未提供文件名,则显示图片 + + plt.close(fig) # 关闭图形,释放内存 + + +if __name__ == '__main__': + # --- 使用工具 --- + + # 1. 设置你的动作空间大小 + ACTION_SPACE = 9 + explorer = EntropyExplorer(action_space_size=ACTION_SPACE) + + # 定义保存图片的相对路径 + SAVE_DIRECTORY = "entropy_visualizations" + + # 2. 核心功能:为您的问题复现,生成熵为 2.15 的分布并保存 + target_entropy_high = 2.15 + generated_probs_high = explorer.find_distribution_for_entropy(target_entropy_high) + + # 详细恰当的命名 + filename_high = f"entropy_{target_entropy_high:.4f}_actions_{ACTION_SPACE}.png" + title_high = f"熵约为 {target_entropy_high} 的一个可能分布 (高熵)" + + explorer.visualize_distribution( + generated_probs_high, + title=title_high, + save_dir=SAVE_DIRECTORY, + filename=filename_high + ) + + # 3. 
更多示例,以建立直观感受 + + # 示例 A: 中等熵 + target_entropy_mid = 1.6 + generated_probs_mid = explorer.find_distribution_for_entropy(target_entropy_mid) + filename_mid = f"entropy_{target_entropy_mid:.4f}_actions_{ACTION_SPACE}.png" + title_mid = f"熵约为 {target_entropy_mid} 的一个可能分布 (中熵)" + explorer.visualize_distribution( + generated_probs_mid, + title=title_mid, + save_dir=SAVE_DIRECTORY, + filename=filename_mid + ) + + # 示例 B: 非常低的熵 (接近确定性) + target_entropy_low = 0.2 + generated_probs_low = explorer.find_distribution_for_entropy(target_entropy_low) + filename_low = f"entropy_{target_entropy_low:.4f}_actions_{ACTION_SPACE}.png" + title_low = f"熵约为 {target_entropy_low} 的一个可能分布 (低熵)" + explorer.visualize_distribution( + generated_probs_low, + title=title_low, + save_dir=SAVE_DIRECTORY, + filename=filename_low + ) \ No newline at end of file diff --git a/lzero/model/eval_dinov2.py b/lzero/model/eval_dinov2.py new file mode 100644 index 000000000..8096fef3f --- /dev/null +++ b/lzero/model/eval_dinov2.py @@ -0,0 +1,169 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms +from PIL import Image +import requests +from io import BytesIO +import os # 导入 os 模块以检查文件是否存在 + +# ----------------------------------------------------------------------------- +# 1. DinoV2Encoder 定义 (保持不变) +# ----------------------------------------------------------------------------- +class DinoV2Encoder(nn.Module): + def __init__(self, name="dinov2_vits14", feature_key="x_norm_clstoken"): + super().__init__() + self.name = name + # 使用 torch.hub.set_dir() 来指定一个缓存目录,避免每次都下载 + # torch.hub.set_dir('/path/to/your/hub/cache') + self.base_model = torch.hub.load("facebookresearch/dinov2", name) + self.feature_key = feature_key + self.emb_dim = self.base_model.num_features + if feature_key == "x_norm_patchtokens": + self.latent_ndim = 2 + elif feature_key == "x_norm_clstoken": + self.latent_ndim = 1 + else: + raise ValueError(f"Invalid feature key: {feature_key}") + + self.patch_size = self.base_model.patch_size + + def forward(self, x): + """ + 输入 x 应该是已经预处理好的张量,shape: (B, 3, H, W) + """ + emb = self.base_model.forward_features(x)[self.feature_key] + if self.latent_ndim == 1: + emb = emb.unsqueeze(1) # dummy patch dim + return emb + +# ----------------------------------------------------------------------------- +# 2. 
准备图像预处理函数 (已修复和优化) +# ----------------------------------------------------------------------------- +def prepare_image(image_source: str): + """ + 从给定的来源 (URL或本地路径) 加载图片并进行预处理。 + 支持 .jpg, .png, .webp 等 Pillow 支持的格式。 + """ + img = None + try: + # 判断输入是 URL 还是本地文件路径 + if image_source.startswith('http://') or image_source.startswith('https://'): + # --- 处理 URL --- + response = requests.get(image_source) + response.raise_for_status() # 如果下载失败则抛出异常 + img = Image.open(BytesIO(response.content)).convert("RGB") + else: + # --- 处理本地文件路径 --- + if not os.path.exists(image_source): + raise FileNotFoundError(f"本地文件未找到: {image_source}") + img = Image.open(image_source).convert("RGB") + + # DINOv2 推荐的预处理 + # Patch size 是 14,所以输入尺寸最好是 14 的倍数 + # 官方推荐尺寸是 518x518 + transform = transforms.Compose([ + transforms.Resize(518, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(518), + transforms.ToTensor(), + transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ]) + return transform(img).unsqueeze(0) # 增加 batch 维度 + + except Exception as e: + print(f"无法加载或处理图片 '{image_source}': {e}") + return None + +# ----------------------------------------------------------------------------- +# 3. 编写测试主函数 (已更新变量名和结论) +# ----------------------------------------------------------------------------- +def test_dinov2_pretrained_features(): + """ + 测试 DinoV2Encoder 是否加载了预训练权重并能提取有意义的语义特征。 + """ + print("正在初始化 DinoV2Encoder...") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"使用设备: {device}") + + try: + encoder = DinoV2Encoder(name="dinov2_vits14") + encoder.to(device) + encoder.eval() # 设置为评估模式 + except Exception as e: + print(f"初始化 DinoV2Encoder 失败: {e}") + print("请检查您的网络连接或 PyTorch Hub 配置。") + return + + print("模型加载成功。准备加载并预处理图片...") + + # --- 图像来源 --- + # 您现在可以混合使用本地路径和 URL + cat_image_path_A = "/mnt/nfs/zhangjinouwen/puyuan/LightZero/lzero/model/cat_1.jpg" + cat_image_path_B = "/mnt/nfs/zhangjinouwen/puyuan/LightZero/lzero/model/cat_2.webp" + dog_image_path = "/mnt/nfs/zhangjinouwen/puyuan/LightZero/lzero/model/dog.webp" + # 为了演示,我们再加一个来自网络的汽车图片 + car_image_url = "https://images.unsplash.com/photo-1503376780353-7e6692767b70" + + + # 加载和预处理 + img_cat_A = prepare_image(cat_image_path_A) + img_cat_B = prepare_image(cat_image_path_B) + img_dog = prepare_image(dog_image_path) + img_car = prepare_image(car_image_url) + + # 检查所有图片是否加载成功 + if any(img is None for img in [img_cat_A, img_cat_B, img_dog, img_car]): + print("部分或全部图片加载失败,测试中止。") + return + + img_cat_A, img_cat_B, img_dog, img_car = [img.to(device) for img in [img_cat_A, img_cat_B, img_dog, img_car]] + + print("图片处理完成,开始提取特征向量...") + + with torch.no_grad(): + vec_cat_A = encoder(img_cat_A).squeeze(1) + vec_cat_B = encoder(img_cat_B).squeeze(1) + vec_dog = encoder(img_dog).squeeze(1) + vec_car = encoder(img_car).squeeze(1) + + print("特征提取完成,开始计算余弦相似度...") + + cos_sim = nn.CosineSimilarity(dim=1, eps=1e-6) + + sim_cats = cos_sim(vec_cat_A, vec_cat_B) + sim_cat_dog = cos_sim(vec_cat_A, vec_dog) + sim_cat_car = cos_sim(vec_cat_A, vec_car) + + print("\n" + "="*50) + print("测试结果:") + print(f" - 两张'猫'图片的特征相似度: {sim_cats.item():.4f}") + print(f" - '猫'和'狗'图片的特征相似度: {sim_cat_dog.item():.4f}") + print(f" - '猫'和'汽车'图片的特征相似度: {sim_cat_car.item():.4f}") + print("="*50 + "\n") + + # 结论分析 + print("结论分析:") + # 预期:猫-猫 相似度 > 猫-狗 相似度 > 猫-车 相似度 + all_passed = True + if sim_cats.item() > sim_cat_dog.item(): + print("✅ 符合预期: '猫-猫' 相似度高于 '猫-狗'。") + else: + print("❌ 不符预期: '猫-猫' 相似度未高于 '猫-狗'。") + all_passed = False 
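    # Editorial note (not part of the original diff): CosineSimilarity returns values in
    # [-1, 1] and is invariant to the magnitude of the CLS embeddings, so the checks above
    # and below compare only the direction of the feature vectors. The test therefore
    # asserts a semantic ordering (cat-cat > cat-dog > cat-car) rather than any absolute
    # similarity threshold, which is what a correctly loaded pretrained DINOv2 backbone
    # should produce.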
+ + if sim_cat_dog.item() > sim_cat_car.item(): + print("✅ 符合预期: '猫-狗' 相似度高于 '猫-汽车'。") + else: + print("❌ 不符预期: '猫-狗' 相似度未高于 '猫-汽车'。") + all_passed = False + + if all_passed: + print("\n🎉 测试通过!结果表明模型能够区分不同语义层级的相似度。") + print(" 模型理解了'猫'和'狗'同属'动物'类别,比'汽车'更相似,但不如两只'猫'之间相似。") + else: + print("\n⚠️ 测试部分失败!请检查模型或图片质量。") + +if __name__ == '__main__': + # 为了运行此脚本,请确保已安装 requests 和 Pillow + # pip install requests Pillow + test_dinov2_pretrained_features() \ No newline at end of file diff --git a/lzero/model/muzero_model.py b/lzero/model/muzero_model.py index 75680ac06..edb82e85c 100644 --- a/lzero/model/muzero_model.py +++ b/lzero/model/muzero_model.py @@ -269,6 +269,22 @@ def recurrent_inference(self, latent_state: torch.Tensor, action: torch.Tensor) """ next_latent_state, reward = self._dynamics(latent_state, action) policy_logits, value = self._prediction(next_latent_state) + + # print(f"next_latent_state.mean():{next_latent_state.mean()}") + # print(f"logits_rewards.mean():{reward.mean()}") + # print(f"logits_policy.mean():{policy_logits.mean()}") + # print(f"logits_value.mean():{value.mean()}") + + # with torch.no_grad(): + # l2_norm = torch.norm(next_latent_state, p=2, dim=1).mean() + # mean = next_latent_state.mean() + # std = next_latent_state.std() + # abs_max = next_latent_state.abs().max() + # # 假设您有logger + # # logger.add_scalar('debug/latent_l2_norm', l2_norm.item(), step_counter) + # # ... + # print(f"next Latent Stats | L2 Norm: {l2_norm:.4f}, Mean: {mean:.4f}, Std: {std:.4f}, Max Abs: {abs_max:.4f}") + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) def _representation(self, observation: torch.Tensor) -> torch.Tensor: diff --git a/lzero/model/unizero_model.py b/lzero/model/unizero_model.py index 4ea6500f3..296d49650 100644 --- a/lzero/model/unizero_model.py +++ b/lzero/model/unizero_model.py @@ -4,11 +4,11 @@ import torch.nn as nn from ding.utils import MODEL_REGISTRY, SequenceType from easydict import EasyDict -from transformers import T5ForConditionalGeneration, T5Tokenizer +# from transformers import T5ForConditionalGeneration, T5Tokenizer from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \ VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \ - HFLanguageRepresentationNetwork + HFLanguageRepresentationNetwork, DinoV2RepresentationNetwork from .unizero_world_models.tokenizer import Tokenizer from .unizero_world_models.world_model import WorldModel from ding.utils import ENV_REGISTRY, set_pkg_seed, get_rank, get_world_size @@ -73,7 +73,7 @@ def __init__( self.action_space_size = action_space_size self.activation = activation self.downsample = downsample - world_model_cfg.norm_type = norm_type + world_model_cfg.norm_type = norm_type # NOTE===== assert world_model_cfg.max_tokens == 2 * world_model_cfg.max_blocks, 'max_tokens should be 2 * max_blocks, because each timestep has 2 tokens: obs and action' if world_model_cfg.obs_type == 'vector': @@ -118,17 +118,39 @@ def __init__( print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') print('==' * 20) elif world_model_cfg.obs_type == 'image': - self.representation_network = RepresentationNetworkUniZero( - observation_shape, - num_res_blocks, - num_channels, - self.downsample, - activation=self.activation, - norm_type=norm_type, - embedding_dim=world_model_cfg.embed_dim, - group_size=world_model_cfg.group_size, - 
final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder - ) + if world_model_cfg.encoder_type=="resnet": + # ======================= + # 修改前的代码 + # ======================= + self.representation_network = RepresentationNetworkUniZero( + observation_shape, + num_res_blocks, + num_channels, + self.downsample, + activation=self.activation, + norm_type=norm_type, + embedding_dim=world_model_cfg.embed_dim, + group_size=world_model_cfg.group_size, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder + ) + elif world_model_cfg.encoder_type=="dinov2": + + # ======================= + # 修改后的代码 + # ======================= + print("Using DinoV2RepresentationNetwork as the encoder.") + self.representation_network = DinoV2RepresentationNetwork( + observation_shape=observation_shape, + embedding_dim=world_model_cfg.embed_dim, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + group_size=world_model_cfg.group_size, + # 传递其他兼容性参数 + num_res_blocks=num_res_blocks, + num_channels=num_channels, + downsample=self.downsample, + activation=self.activation, + norm_type=norm_type, + ) # ====== for analysis ====== if world_model_cfg.analysis_sim_norm: @@ -217,6 +239,11 @@ def initial_inference(self, obs_batch: torch.Tensor, action_batch: Optional[torc policy_logits = logits_policy.squeeze(1) value = logits_value.squeeze(1) + # print(f"latent_state.mean():{latent_state.mean()}") + # print(f"logits_rewards.mean():{logits_rewards.mean()}") + # print(f"logits_policy.mean():{logits_policy.mean()}") + # print(f"logits_value.mean():{logits_value.mean()}") + return MZNetworkOutput( value=value, reward=[0. for _ in range(batch_size)], # Initialize reward to zero vector diff --git a/lzero/model/unizero_world_models/test_world_model_cache.py b/lzero/model/unizero_world_models/test_world_model_cache.py new file mode 100644 index 000000000..9eb53b272 --- /dev/null +++ b/lzero/model/unizero_world_models/test_world_model_cache.py @@ -0,0 +1,191 @@ +# test_world_model_cache.py +import torch +import torch.nn as nn +import numpy as np +from easydict import EasyDict +import csv +import os + +# 确保lzero和toy_env在Python路径中 +from lzero.model.unizero_world_models.world_model import WorldModel +from toy_env import ToyEnv +from lzero.model.unizero_world_models.utils import hash_state + +# ============================================================================== +# Helper classes and functions for the test +# ============================================================================== + +class DummyTokenizer: + """一个用于向量观测的简化分词器。""" + def __init__(self, obs_shape, embed_dim, device): + self.encoder = nn.Linear(obs_shape[0], embed_dim).to(device) + self.device = device + + def encode_to_obs_embeddings(self, obs): + obs_tensor = torch.from_numpy(obs).float().to(self.device) + if len(obs_tensor.shape) == 1: + obs_tensor = obs_tensor.unsqueeze(0) + if len(obs_tensor.shape) == 2: + return self.encoder(obs_tensor).unsqueeze(1) + elif len(obs_tensor.shape) == 3: + return self.encoder(obs_tensor).unsqueeze(2) + else: + raise ValueError(f"Unsupported observation tensor shape: {obs_tensor.shape}") + +def print_cache_summary(name: str, kv_cache, context_length: int): + """打印 KeysValues 缓存对象的摘要,并高亮显示截断行为。""" + if kv_cache is None: + print(f" {name}: None") + return 0, "None" + + size = kv_cache.size + shape = kv_cache._keys_values[0]._k_cache._cache.shape + status_msg = "" + # 模型在截断时会为未来的(act, obs)等留出空间,所以我们检查是否接近限制 + if size >= context_length - 3: + status_msg = f" (!! 
Approaching/Exceeded Context Limit of {context_length}. Truncation will occur.)" + + print(f" {name}: Size = {size}, Shape = {shape}{status_msg}") + return size, f"Size={size}" + +# ============================================================================== +# Main Test Function +# ============================================================================== +def test_cache_logic(): + # 1. 设置环境和模型配置 + env_cfg = ToyEnv.default_config() + env = ToyEnv(env_cfg) + + world_model_cfg = EasyDict( + dict( + continuous_action_space=False, num_layers=2, num_heads=4, embed_dim=64, + context_length=8, max_tokens=100, tokens_per_block=2, + action_space_size=env.cfg.action_space_size, env_num=1, obs_type='vector', + device='cuda' if torch.cuda.is_available() else 'cpu', rotary_emb=False, + policy_entropy_weight=0, predict_latent_loss_type='mse', group_size=8, + gamma=0.99, dormant_threshold=0.0, analysis_dormant_ratio=False, + latent_recon_loss_weight=0, perceptual_loss_weight=0, support_size=11, + max_cache_size=1000, final_norm_option_in_obs_head='SimNorm', norm_type='LN', + embed_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1, max_blocks=10, gru_gating=False, + ) + ) + + # 2. 实例化世界模型 + tokenizer = DummyTokenizer(env.cfg.observation_shape, world_model_cfg.embed_dim, world_model_cfg.device) + world_model = WorldModel(world_model_cfg, tokenizer).to(world_model_cfg.device) + world_model.eval() + + # 3. 设置日志文件 + log_filename = "cache_log.csv" + log_filepath = os.path.join(os.getcwd(), log_filename) + print(f"\nLogging statistics to: {log_filepath}") + + with open(log_filepath, 'w', newline='') as log_file: + csv_writer = csv.writer(log_file) + header = [ + 'Timestep', 'Action_Taken', 'Current_State', + 'Root_Cache_Hit', 'Root_Cache_Size', + 'Recurrent_Cache_Hit', 'Recurrent_Cache_Size', + 'Comment' + ] + csv_writer.writerow(header) + + # 4. 运行一个 episode 并检查缓存 + obs_dict = env.reset() + last_action = -1 + last_obs_for_infer = np.zeros_like(obs_dict['observation']) + + for t in range(env.cfg.collect_max_episode_steps): + print(f"\n{'='*25} Timestep {t} {'='*25}") + print(f"Environment State: Obs = {obs_dict['observation']}, Timestep from Env = {obs_dict['timestep']}") + + log_row = {'Timestep': t, 'Current_State': str(obs_dict['observation'])} + + # --- 模拟 MCTS 搜索开始 --- + obs_act_dict = { + 'obs': last_obs_for_infer, 'action': np.array([last_action]), + 'current_obs': obs_dict['observation'] + } + print("\n[1. Initial Inference] -> Simulating root node creation for MCTS.") + print(f" Inputs: last_obs={obs_act_dict['obs']}, last_action={obs_act_dict['action']}, current_obs={obs_act_dict['current_obs']}") + + with torch.no_grad(): + # 注意:start_pos 应该是一个列表或数组,以适应模型的批处理逻辑 + _, latent_state, _, _, _ = world_model.forward_initial_inference( + obs_act_dict, start_pos=[obs_dict['timestep']] + ) + + # --- 检查根节点缓存 --- + print("\n[2. Inspecting Root Node Cache]") + cache_key = hash_state(latent_state.cpu().numpy().flatten()) + cache_index = world_model.past_kv_cache_init_infer_envs[0].get(cache_key) + + if cache_index is not None: + root_kv_cache = world_model.shared_pool_init_infer[0][cache_index] + log_row['Root_Cache_Hit'] = 'Stored' + size, _ = print_cache_summary("Stored Root KV Cache", root_kv_cache, world_model_cfg.context_length) + log_row['Root_Cache_Size'] = size + else: + log_row['Root_Cache_Hit'] = 'Not_Found' + log_row['Root_Cache_Size'] = 0 + print(" Status: Cache Not Found! 
(This is unexpected after the first step).") + + # --- 模拟一步 MCTS 循环推断 --- + action_to_take = env.action_space.sample() + log_row['Action_Taken'] = action_to_take + print(f"\n[3. Recurrent Inference] -> Simulating one search step from the root.") + print(f" Action to explore: {action_to_take}") + + state_action_history = [(latent_state.cpu().numpy(), np.array([action_to_take]))] + + print(" Checking if root cache is available for recurrent step...") + root_cache_key_for_recur = hash_state(state_action_history[0][0].flatten()) + root_cache_index = world_model.past_kv_cache_init_infer_envs[0].get(root_cache_key_for_recur) + if root_cache_index is not None: + log_row['Comment'] = 'Recurrent step found root cache.' + print(" -> Cache Hit! The recurrent step will build upon the existing root cache.") + else: + log_row['Comment'] = 'Recurrent step MISSES root cache!' + print(" -> Cache Miss! The recurrent step will have to regenerate context. (This indicates a problem)") + + with torch.no_grad(): + # 注意:start_pos 应该是一个列表或数组 + _, next_latent_state, _, _, _ = world_model.forward_recurrent_inference( + state_action_history, + start_pos=[obs_dict['timestep']] + ) + + # --- 检查循环推断节点的缓存 --- + print("\n[4. Inspecting Recurrent Node Cache]") + cache_key_recur = hash_state(next_latent_state.cpu().numpy().flatten()) + cache_index_recur = world_model.past_kv_cache_recurrent_infer.get(cache_key_recur) + if cache_index_recur is not None: + recurrent_kv_cache = world_model.shared_pool_recur_infer[cache_index_recur] + log_row['Recurrent_Cache_Hit'] = 'Stored' + size, _ = print_cache_summary("Stored Recurrent KV Cache", recurrent_kv_cache, world_model_cfg.context_length) + log_row['Recurrent_Cache_Size'] = size + else: + log_row['Recurrent_Cache_Hit'] = 'Not_Found' + log_row['Recurrent_Cache_Size'] = 0 + print(" Status: Recurrent Cache Not Found! (This is unexpected).") + + # --- 环境步进 --- + print("\n[5. Stepping Environment]") + timestep_obj = env.step(action_to_take) + + last_action = action_to_take + last_obs_for_infer = obs_dict['observation'] + obs_dict = timestep_obj.obs + + # 写入日志行 + csv_writer.writerow([log_row.get(h, '') for h in header]) + + if timestep_obj.done: + print("\n" + "="*20 + " Episode Finished " + "="*20) + break + + world_model.clear_caches() + print(f"\nTest finished. Log saved to {log_filepath}") + +if __name__ == "__main__": + test_cache_logic() \ No newline at end of file diff --git a/lzero/model/unizero_world_models/toy_env.py b/lzero/model/unizero_world_models/toy_env.py new file mode 100644 index 000000000..66f220642 --- /dev/null +++ b/lzero/model/unizero_world_models/toy_env.py @@ -0,0 +1,113 @@ +# toy_env.py +import copy +from typing import List +import gym +import numpy as np +from ding.envs import BaseEnv, BaseEnvTimestep +from ding.utils import ENV_REGISTRY +from easydict import EasyDict + +@ENV_REGISTRY.register('toy_lightzero') +class ToyEnv(BaseEnv): + """ + Overview: + A simple, deterministic toy environment for debugging KV cache and long-sequence processing in UniZero. + - State: 4-dim vector. + - Actions: 3 discrete actions (stay, increment, decrement). + - Episode Length: Fixed at 15 steps. + - Returns 'timestep' in observation. 
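        Example (illustrative sketch, not from the original diff):
            >>> env = ToyEnv(ToyEnv.default_config())
            >>> obs = env.reset()          # obs['observation'] == array([0., 0., 0., 0.])
            >>> ts = env.step(1)           # action 1 increments every state dim, reward == 1.0
            >>> int(ts.obs['timestep'])    # running step counter returned with each observation
            1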
+ """ + config = dict( + env_id='toy-v0', + env_type='Toy', + observation_shape=(4,), + action_space_size=3, + collect_max_episode_steps=15, + eval_max_episode_steps=15, + manager=dict(shared_memory=False), + stop_value=100, + ) + + @classmethod + def default_config(cls: type) -> EasyDict: + cfg = EasyDict(copy.deepcopy(cls.config)) + cfg.cfg_type = cls.__name__ + 'Dict' + return cfg + + def __init__(self, cfg: EasyDict) -> None: + self.cfg = cfg + self._init_flag = False + self._observation_space = gym.spaces.Dict({ + 'observation': gym.spaces.Box(low=-np.inf, high=np.inf, shape=self.cfg.observation_shape, dtype=np.float32), + 'action_mask': gym.spaces.Box(low=0, high=1, shape=(self.cfg.action_space_size,), dtype=np.int8), + 'to_play': gym.spaces.Box(low=-1, high=2, shape=(), dtype=np.int8), + 'timestep': gym.spaces.Box(low=0, high=self.cfg.collect_max_episode_steps, shape=(), dtype=np.int32), + }) + self._action_space = gym.spaces.Discrete(self.cfg.action_space_size) + self._reward_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32) + + def reset(self) -> dict: + if not self._init_flag: + self._init_flag = True + self._state = np.zeros(self.cfg.observation_shape, dtype=np.float32) + self._episode_steps = 0 + self._eval_episode_return = 0.0 + return self.observe() + + def step(self, action: int) -> BaseEnvTimestep: + if action == 1: + self._state += 1 + elif action == 2: + self._state -= 1 + + self._episode_steps += 1 + reward = np.array([1.0], dtype=np.float32) + self._eval_episode_return += reward + + done = self._episode_steps >= self.cfg.collect_max_episode_steps + info = {} + if done: + info['eval_episode_return'] = self._eval_episode_return + + return BaseEnvTimestep(self.observe(), reward, done, info) + + def observe(self) -> dict: + return { + 'observation': self._state.copy(), + 'action_mask': np.ones(self.cfg.action_space_size, dtype=np.int8), + 'to_play': np.array(-1, dtype=np.int8), + 'timestep': np.array(self._episode_steps, dtype=np.int32) + } + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + self._seed = seed + self._dynamic_seed = dynamic_seed + np.random.seed(self._seed) + + def close(self) -> None: + self._init_flag = False + + @property + def observation_space(self) -> gym.spaces.Space: + return self._observation_space + + @property + def action_space(self) -> gym.spaces.Space: + return self._action_space + + @property + def reward_space(self) -> gym.spaces.Space: + return self._reward_space + + def __repr__(self) -> str: + return "LightZero Toy Env" + + @staticmethod + def create_collector_env_cfg(cfg: dict) -> List[dict]: + collector_env_num = cfg.pop('collector_env_num') + return [cfg for _ in range(collector_env_num)] + + @staticmethod + def create_evaluator_env_cfg(cfg: dict) -> List[dict]: + evaluator_env_num = cfg.pop('evaluator_env_num') + return [cfg for _ in range(evaluator_env_num)] \ No newline at end of file diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index c2feb8497..7473875a7 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -207,6 +207,7 @@ def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues x = block(x, None if past_keys_values is None else past_keys_values[i], valid_context_lengths, freqs_cis) # Apply final layer normalization x = self.ln_f(x) + return x @@ -246,6 +247,12 @@ def __init__(self, config: TransformerConfig) -> None: nn.Linear(4 * 
config.embed_dim, config.embed_dim), nn.Dropout(config.resid_pdrop), ) + + self.config = config + if self.config.res_alha: + # 为每个残差连接路径引入一个可学习的缩放因子,初始化为0 + self.alpha_attn = nn.Parameter(torch.zeros(1)) + self.alpha_mlp = nn.Parameter(torch.zeros(1)) def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None, valid_context_lengths: Optional[torch.Tensor] = None, freqs_cis: torch.Tensor = None) -> torch.Tensor: @@ -266,8 +273,12 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None x = self.gate1(x, x_attn) x = self.gate2(x, self.mlp(self.ln2(x))) else: - x = x + x_attn - x = x + self.mlp(self.ln2(x)) + if not self.config.res_alha: + x = x + x_attn + x = x + self.mlp(self.ln2(x)) + else: + x = x + self.alpha_attn * x_attn + x = x + self.alpha_mlp * self.mlp(self.ln2(x)) return x diff --git a/lzero/model/unizero_world_models/utils.py b/lzero/model/unizero_world_models/utils.py index 99c841cbe..87907fc91 100644 --- a/lzero/model/unizero_world_models/utils.py +++ b/lzero/model/unizero_world_models/utils.py @@ -201,7 +201,7 @@ class WorldModelOutput: logits_value: torch.FloatTensor -def init_weights(module, norm_type='BN'): +def init_weights(module, norm_type='BN',liner_weight_zero=False): """ Initialize the weights of the module based on the specified normalization type. @@ -209,14 +209,36 @@ def init_weights(module, norm_type='BN'): module (nn.Module): The module to initialize. norm_type (str): The type of normalization to use ('BN' for BatchNorm, 'LN' for LayerNorm). """ - if isinstance(module, (nn.Linear, nn.Embedding)): + if isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=0.02) - if isinstance(module, nn.Linear) and module.bias is not None: + elif isinstance(module, nn.Linear): + # 现在这个分支可以被正确执行了 + if norm_type == 'BN': + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + print("Init Linear using kaiming normal for BN") + elif norm_type == 'LN': + # 对于Transformer结构,Xavier/Glorot更常见 + nn.init.xavier_uniform_(module.weight) + print("Init Linear using xavier uniform for LN") + + if module.bias is not None: module.bias.data.zero_() + # if isinstance(module, (nn.Linear, nn.Embedding)): + # module.weight.data.normal_(mean=0.0, std=0.02) + + # # if liner_weight_zero and isinstance(module, nn.Linear): # TODO======== + # # nn.init.zeros_(module.weight) + + # if isinstance(module, nn.Linear) and module.bias is not None: + # module.bias.data.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): print(f"Init {module} using zero bias, 1 weight") - module.bias.data.zero_() - module.weight.data.fill_(1.0) + try: + module.weight.data.fill_(1.0) + module.bias.data.zero_() + except Exception as e: + print(e) + elif isinstance(module, nn.BatchNorm2d): print(f"Init nn.BatchNorm2d using zero bias, 1 weight") module.weight.data.fill_(1.0) @@ -228,13 +250,13 @@ def init_weights(module, norm_type='BN'): elif norm_type == 'LN': nn.init.xavier_uniform_(module.weight) print(f"Init nn.Conv2d using xavier uniform for LN") - elif isinstance(module, nn.Linear): - if norm_type == 'BN': - nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') - print("Init Linear using kaiming normal for BN") - elif norm_type == 'LN': - nn.init.xavier_uniform_(module.weight) - print("Init Linear using xavier uniform for LN") + # elif isinstance(module, nn.Linear): + # if norm_type == 'BN': + # nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + # print("Init Linear using kaiming 
normal for BN") + # elif norm_type == 'LN': + # nn.init.xavier_uniform_(module.weight) + # print("Init Linear using xavier uniform for LN") class LossWithIntermediateLosses: @@ -253,17 +275,41 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, continu if not kwargs: raise ValueError("At least one loss must be provided") + # Get a reference device from one of the provided losses device = next(iter(kwargs.values())).device # NOTE: Define the weights for each loss type if not continuous_action_space: - # like EZV2, for atari and memory + # orig, for atari and memory + self.obs_loss_weight = 10 self.value_loss_weight = 0.5 self.reward_loss_weight = 1. self.policy_loss_weight = 1. self.ends_loss_weight = 0. + + # loss weight v1 ======TODO================== + # self.obs_loss_weight = 2.0 + # self.value_loss_weight = 1 + # self.reward_loss_weight = 1 + # self.policy_loss_weight = 1 + # self.ends_loss_weight = 0. + + # muzero loss weight + # self.obs_loss_weight = 2 + # self.value_loss_weight = 0.25 + # self.reward_loss_weight = 1 + # self.policy_loss_weight = 1 + # self.ends_loss_weight = 0. + + # EZV2, for atari and memory + # self.obs_loss_weight = 5 + # self.value_loss_weight = 0.5 + # self.reward_loss_weight = 1. + # self.policy_loss_weight = 1. + # self.ends_loss_weight = 0. + else: # like TD-MPC2 for DMC self.obs_loss_weight = 10 @@ -272,6 +318,21 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, continu self.policy_loss_weight = 0.1 self.ends_loss_weight = 0. + # TD-MPC2 for DMC, only for reference + # self.obs_loss_weight = 20 + # self.value_loss_weight = 0.1 + # self.reward_loss_weight = 0.1 + # self.ends_loss_weight = 0. + + # TODO(pu) + # self.latent_norm_loss_weight = 0.1 + # self.latent_norm_loss_weight = 0.01 + # self.latent_norm_loss_weight = 0.001 + # self.latent_norm_loss_weight = 0.0001 + self.latent_norm_loss_weight = 0.0 + + + self.latent_recon_loss_weight = latent_recon_loss_weight self.perceptual_loss_weight = perceptual_loss_weight @@ -292,6 +353,8 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, continu self.loss_total += self.latent_recon_loss_weight * v elif k == 'perceptual_loss': self.loss_total += self.perceptual_loss_weight * v + elif k == 'latent_norm_loss': + self.loss_total += self.latent_norm_loss_weight * v self.intermediate_losses = { k: v if isinstance(v, dict) or isinstance(v, torch.Tensor) else (v if isinstance(v, float) else v.item()) diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py index e8df2a6e0..8ae120251 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -8,16 +8,101 @@ from einops import rearrange from torch.distributions import Categorical, Independent, Normal, TransformedDistribution, TanhTransform -from lzero.model.common import SimNorm +from lzero.model.common import SimNorm, L2Norm from lzero.model.utils import cal_dormant_ratio from .kv_caching import KeysValues from .slicer import Head, PolicyHeadCont from .tokenizer import Tokenizer from .transformer import Transformer, TransformerConfig from .utils import LossWithIntermediateLosses, init_weights, WorldModelOutput, hash_state - +from collections import OrderedDict logging.getLogger().setLevel(logging.DEBUG) +from collections import OrderedDict, defaultdict +import matplotlib.pyplot as plt +from matplotlib.offsetbox import OffsetImage, AnnotationBbox +from sklearn.manifold import TSNE +# In 
unizero_world_model.py +import torch +import numpy as np +import matplotlib.pyplot as plt +from sklearn.manifold import TSNE +from matplotlib.offsetbox import OffsetImage, AnnotationBbox +import os +import datetime +import torch +import torch.nn as nn +from lzero.model.common import MZNetworkOutput, RepresentationNetwork, PredictionNetwork, FeatureAndGradientHook, MLP_V2 + +def inspect_model_parameters(model: nn.Module): + """ + 遍历模型的所有参数并打印其统计信息(均值、标准差、最大值、最小值)。 + """ + print("--- Inspecting Initial Model Parameters ---") + total_params = 0 + with torch.no_grad(): + for name, param in model.named_parameters(): + if param.requires_grad: + total_params += param.numel() + mean = param.mean().item() + std = param.std().item() + abs_max = param.abs().max().item() + + # 打印关键统计数据,帮助判断初始化是否合理 + print(f"{name:<50} | Shape: {str(param.shape):<25} | Mean: {mean:+.4f} | Std: {std:.4f} | MaxAbs: {abs_max:.4f}") + + print(f"--- Total Trainable Parameters: {total_params/1e6:.2f}M ---") + + +# --- HOOK FUNCTION FOR DEBUGGING --- +def print_intermediate_activation_hook(module, input, output): + """ + A PyTorch hook that prints the mean and std of a module's output. + This function will be registered to a specific layer (e.g., the first Linear layer in a Head). + + Args: + module: The module the hook is registered on. + input: The input to the module's forward pass. + output: The output from the module's forward pass. + """ + # output is the tensor we want to inspect + mean = output.mean().item() + std = output.std().item() + # We add the module name for clarity, to know which layer's output we are seeing. + print(f" [HOOK DEBUG] Layer '{module.__class__.__name__}' Output -> mean: {mean:.6f}, std: {std:.6f}") + + + + +class LRUCache(OrderedDict): + """ + 一个固定容量的、遵循LRU(最近最少使用)原则的有序字典。 + 非常适合用于管理与环形缓冲区同步的缓存映射。 + """ + def __init__(self, capacity: int=2): + """ + 初始化LRU缓存。 + 参数: + - capacity (int): 缓存的最大容量。 + """ + self.capacity = capacity + super().__init__() + + def __setitem__(self, key: Any, value: Any) -> None: + """ + 重写设置条目的方法,以实现LRU逻辑。 + """ + # 如果键已存在,先删除旧条目,以确保后续添加时它会成为最新项。 + if key in self: + self.move_to_end(key) + + # 调用父类的方法来实际设置键值对。 + super().__setitem__(key, value) + + # 检查是否超出容量。如果超出,则删除最旧的条目。 + # popitem(last=False) 会移除并返回字典中第一个(最旧的)条目。 + if len(self) > self.capacity: + self.popitem(last=False) class WorldModel(nn.Module): """ @@ -61,7 +146,10 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: # Position embedding if not self.config.rotary_emb: + # self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device) + # TODO(pu) self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device) + # self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device, max_norm=1.0) self.precompute_pos_emb_diff_kv() print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}") @@ -75,13 +163,17 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: SimNorm(simnorm_dim=self.group_size)) else: # for discrete action space + # self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device) + # TODO(pu) self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device) + # self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device, max_norm=1.0) + logging.info(f"self.act_embedding_table.weight.device: {self.act_embedding_table.weight.device}") - self.final_norm_option_in_obs_head = 
getattr(config, 'final_norm_option_in_obs_head', 'LayerNorm') + self.final_norm_option_in_obs_head = getattr(config, 'final_norm_option_in_obs_head', 'SimNorm') # Head modules - self.head_rewards = self._create_head(self.act_tokens_pattern, self.support_size) + self.head_rewards = self._create_head(self.act_tokens_pattern, self.support_size, use_norm_in_head=True) # TODO self.head_observations = self._create_head(self.all_but_last_latent_state_pattern, self.obs_per_embdding_dim, \ self._get_final_norm(self.final_norm_option_in_obs_head) # NOTE: using the specified normalization method for observations head ) @@ -91,7 +183,28 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.head_policy = self._create_head_cont(self.value_policy_tokens_pattern, self.action_space_size) else: self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) - self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size) + self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size, use_norm_in_head=True) + + # # ==================== NEW DEBUGGING CODE VIA HOOKS ==================== + # # We will attach our hook to the first Linear layer inside the head_value and head_rewards modules. + # # The head_module is an nn.Sequential, so its layers can be accessed by index. + # # Index 0: First nn.Linear + # # Index 1: nn.GELU + # # Index 2: Second nn.Linear + + # # Get the first linear layer from the sequential module + # first_linear_layer_value = self.head_value.head_module[0] + # first_linear_layer_rewards = self.head_rewards.head_module[0] + + # # Register the forward hook + # print("--- Attaching DEBUG hooks to head_value and head_rewards ---") + # self.value_hook_handle = first_linear_layer_value.register_forward_hook(print_intermediate_activation_hook) + # self.rewards_hook_handle = first_linear_layer_rewards.register_forward_hook(print_intermediate_activation_hook) + + # # NOTE: It's good practice to store the hook handle so you can remove it later if needed, e.g., during evaluation or after debugging. + # # To remove the hook: self.value_hook_handle.remove() + # # ==================================================================== + # Build the set of modules to skip during re-initialization. # This is compatible with cases where self.tokenizer.encoder does not have 'pretrained_model', @@ -101,7 +214,8 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: if hasattr(self.tokenizer.encoder, 'pretrained_model'): skip_modules.update(self.tokenizer.encoder.pretrained_model.modules()) if hasattr(self.tokenizer, 'decoder_network'): - skip_modules.update(self.tokenizer.decoder_network.modules()) + if self.tokenizer.decoder_network is not None: + skip_modules.update(self.tokenizer.decoder_network.modules()) def custom_init(module): # If the current module is part of the skip list, return without reinitializing @@ -113,8 +227,28 @@ def custom_init(module): # Recursively apply `custom_init` to all submodules of the model self.apply(custom_init) + # self.apply(init_weights) + self._initialize_last_layer() + # for self.kv_cache_init_infer + # In contrast, init_infer only needs to retain the results of the most recent step. + # self.shared_pool_size_init = int(2) # NOTE: Will having too many cause incorrect retrieval of the kv cache? 
+ # Set this to game_segment_length for now so that every entry in self.shared_pool_init_infer holds a valid kv cache + # TODO: very important, this should be made identical to segment_length + self.shared_pool_size_init = int(self.config.game_segment_length) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + + # self.shared_pool_size_init = int(20) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + # self.shared_pool_size_init = int(200) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + + self.num_simulations = getattr(self.config, 'num_simulations', 50) + + # TODO: should the recurrent kv pool be split into a separate pool per environment? + self.shared_pool_size_recur = int(self.num_simulations*self.env_num) + + # self.shared_pool_size_init = int(50) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + self.stale_pointer_detections = 0 + self.stale_pointer_detections_recur = 0 # Cache structures self._initialize_cache_structures() @@ -133,15 +267,19 @@ def custom_init(module): # TODO: check the size of the shared pool # for self.kv_cache_recurrent_infer # If needed, recurrent_infer should store the results of the one MCTS search. - self.num_simulations = getattr(self.config, 'num_simulations', 50) - self.shared_pool_size = int(self.num_simulations*self.env_num) - self.shared_pool_recur_infer = [None] * self.shared_pool_size + + # self.shared_pool_size_recur = int(2) + + self.shared_pool_recur_infer = [None] * self.shared_pool_size_recur self.shared_pool_index = 0 # for self.kv_cache_init_infer # In contrast, init_infer only needs to retain the results of the most recent step. - # self.shared_pool_size_init = int(2*self.env_num) - self.shared_pool_size_init = int(2) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + # self.shared_pool_size_init = int(2) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + # TODO + # self.shared_pool_size_init = int(50) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + + # TODO: analyse the self.env_num > 1 case: can different envs share the kv_cache entries keyed by the same latent-state hash? self.shared_pool_init_infer = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] self.shared_pool_index_init_envs = [0 for _ in range(self.env_num)] @@ -152,6 +290,305 @@ def custom_init(module): self.reanalyze_phase = False + # Counter used for t-SNE visualisation + self.tsne_visualization_step = 0 + + # Handles of the registered gradient hooks + self._grad_hooks = [] + + if self.config.entry_norm: + self.entry_norm = nn.LayerNorm(config.embed_dim) # <-- newly added + + # ======================= Register gradient hooks ======================= + # self.register_gradient_hooks(self.tokenizer.representation_network) + # ============================================================= + + # ==================== START: LEARNABLE TEMPERATURE SCALING ==================== + # Option to enable or disable temperature scaling + self.use_temperature_scaling = getattr(self.config, 'use_temperature_scaling', True) # enabled by default + + if self.use_temperature_scaling: + # Initialise one learnable log-temperature parameter each for value, reward and policy + # Initialised to 0, so the initial temperature T = exp(0) = 1 and the start of training is unaffected + self.log_temp_value = nn.Parameter(torch.zeros([])) + self.log_temp_reward = nn.Parameter(torch.zeros([])) + self.log_temp_policy = nn.Parameter(torch.zeros([])) + logging.info("Learnable temperature scaling for prediction heads is ENABLED.") + # ===================== END: LEARNABLE TEMPERATURE SCALING ===================== + + # 2. 
初始化损失函数 (通常在模型 __init__ 中完成) + # 使用 reduction='none' 来获取每个token的损失,以便后续应用掩码 + self.ce_loss_fn_pt = torch.nn.CrossEntropyLoss(reduction='none') + + + def register_gradient_hooks(self, model_to_hook: nn.Module): + """ + 递归地为模型中的可学习参数注册梯度hook。 + """ + + def hook_fn(grad): + # 这个hook会在该参数的梯度被计算出来后立即执行 + if grad is not None: + grad_norm = grad.norm().item() + grad_mean = grad.mean().item() + grad_std = grad.std().item() + # 为了避免信息过载,我们可以只打印非零梯度的统计信息 + if grad_norm > 1e-9: + print(f" [GRAD HOOK] Param: {name}, Shape: {grad.shape} | Norm: {grad_norm:.6f}, Mean: {grad_mean:.6f}, Std: {grad_std:.6f}") + + # 遍历模型的所有命名参数 + for name, param in model_to_hook.named_parameters(): + if param.requires_grad: + # 使用 .register_hook() 为张量注册hook + handle = param.register_hook(hook_fn) + self._grad_hooks.append(handle) + print(f" [INFO] Registered gradient hook for: {name}") + + def remove_gradient_hooks(self): + """ + 移除所有已注册的梯度hook,在评估或部署时调用。 + """ + for handle in self._grad_hooks: + handle.remove() + self._grad_hooks.clear() + print("[INFO] All gradient hooks removed.") + + # ==================== 新增辅助方法 ==================== + def _inspect_and_log_head_params(self, head_name: str, head_module: nn.Module, status: str): + """ + 检查并记录指定Head模块的参数统计信息。 + + Args: + head_name (str): 要检查的Head的名称 (例如, "Value Head")。 + head_module (nn.Module): Head的实际nn.Sequential模块。 + status (str): 描述当前状态的字符串 (例如, "Before Re-init")。 + """ + logging.info(f"--- 检查 {head_name} 参数 ({status}) ---") + with torch.no_grad(): + for param_name, param in head_module.named_parameters(): + if param.numel() > 0: + stats = { + "mean": param.mean().item(), + "std": param.std().item(), + "abs_mean": param.abs().mean().item(), + "max": param.max().item(), + "min": param.min().item(), + } + logging.info( + f" -> {param_name:<20} | " + f"Mean: {stats['mean']:.4f}, Std: {stats['std']:.4f}, " + f"AbsMean: {stats['abs_mean']:.4f}, " + f"Max: {stats['max']:.4f}, Min: {stats['min']:.4f}" + ) + logging.info("-" * (23 + len(head_name) + len(status))) + +# ==================== 修改后的方法 ==================== + def reinit_prediction_heads(self, heads_to_reinit: List[str] = ['value', 'reward']) -> None: + """ + 重新初始化指定的预测头(例如Value Head和Reward Head)的参数。 + 在重新初始化前后,会记录参数的统计信息以供分析。 + + Args: + heads_to_reinit (List[str]): 一个包含要重新初始化的头的名称的列表。 + 默认为 ['value', 'reward']。 + """ + logging.info(f"开始重新初始化预测头: {heads_to_reinit}") + + head_map = { + 'value': self.head_value, + 'reward': self.head_rewards, + 'policy': self.head_policy, + } + + def _init_weights_for_head(module): + # TODO + init_weights(module, norm_type=self.config.norm_type, liner_weight_zero=True) + + for head_name in heads_to_reinit: + if head_name in head_map and hasattr(head_map[head_name], 'head_module'): + head_instance = head_map[head_name] + capitalized_name = head_name.capitalize() + " Head" + + # 1. 重新初始化前检查参数 + self._inspect_and_log_head_params(capitalized_name, head_instance.head_module, "Before Re-init") + + # 2. 应用重新初始化 + logging.info(f"正在重新初始化 {capitalized_name}...") + head_instance.head_module.apply(_init_weights_for_head) + + # 3. 
重新初始化后再次检查参数 + self._inspect_and_log_head_params(capitalized_name, head_instance.head_module, "After Re-init") + + logging.info(f"{capitalized_name} 参数已成功重新初始化。") + else: + logging.warning(f"未能找到名为 '{head_name}' 的预测头或其 'head_module'。跳过。") + + logging.info("所有指定的预测头重新初始化完成。") + # ========================================================== + + def _analyze_latent_representation( + self, + latent_states: torch.Tensor, + timesteps: torch.Tensor, + game_states: torch.Tensor, + predicted_values: torch.Tensor, + predicted_rewards: torch.Tensor, + step_counter: int + ): + """ + 分析并记录 latent states 的统计信息和t-SNE可视化。 + 【新功能】:在t-SNE图上显示对应的游戏图像,并标注预测的Value和Reward。 + 【已修改】:如果保存路径已存在同名文件,则在文件名后附加时间戳。 + + Args: + latent_states (torch.Tensor): Encoder的输出, shape (B*L, 1, E) + timesteps (torch.Tensor): 对应的时间步, shape (B, L) + game_states (torch.Tensor): 原始的游戏观测, shape (B, L, C, H, W) + predicted_values (torch.Tensor): 预测的标量Value, shape (B*L,) + predicted_rewards (torch.Tensor): 预测的标量Reward, shape (B*L,) + step_counter (int): 全局训练步数 + """ + # ... (统计分析部分保持不变) ... + # (确保 latent_states 和 game_states 的形状为 (N, ...)) + if latent_states.dim() > 2: + latent_states = latent_states.reshape(-1, latent_states.shape[-1]) + num_c, num_h, num_w = game_states.shape[-3:] + game_states = game_states.reshape(-1, num_c, num_h, num_w) + + with torch.no_grad(): + l2_norm = torch.norm(latent_states, p=2, dim=1).mean() + mean = latent_states.mean() + std = latent_states.std() + print(f"[Step {step_counter}] Latent Stats | L2 Norm: {l2_norm:.4f}, Mean: {mean:.4f}, Std: {std:.4f}") + + # 带图像和V/R值的 t-SNE 可视化 + if step_counter >= 0: + # if step_counter > 0 and step_counter % 200 == 0: + + print(f"[Step {step_counter}] Performing t-SNE analysis with images, values, and rewards...") + + # 将数据转换到CPU + latents_np = latent_states.detach().cpu().numpy() + images_np = game_states.detach().cpu().numpy() + values_np = predicted_values.detach().cpu().numpy() + rewards_np = predicted_rewards.detach().cpu().numpy() + + tsne = TSNE(n_components=2, perplexity=30, n_iter=300, random_state=42) + tsne_results = tsne.fit_transform(latents_np) + + # --- 绘制带图像和标注的散点图 --- + + # 减少图像数量以保持清晰 + num_points_to_plot = min(len(latents_np), 70) # 减少到70个点 + indices = np.random.choice(len(latents_np), num_points_to_plot, replace=False) + + fig, ax = plt.subplots(figsize=(20, 18)) # 增大画布尺寸 + + # 先画出所有点的散点图作为背景 + ax.scatter(tsne_results[:, 0], tsne_results[:, 1], c=values_np, cmap='viridis', alpha=0.3, s=10) + + for i in indices: + x, y = tsne_results[i] + img = images_np[i].transpose(1, 2, 0) + img = np.clip(img, 0, 1) + + # 放置图像 + im = OffsetImage(img, zoom=0.7) # 稍微放大图像 + ab = AnnotationBbox(im, (x, y), frameon=True, pad=0.0, bboxprops=dict(edgecolor='none')) + ax.add_artist(ab) + + # 在图像下方添加文字标注 + text_label = f"V:{values_np[i]:.1f} R:{rewards_np[i]:.1f}" + ax.text(x, y - 1.0, text_label, ha='center', va='top', fontsize=8, color='red', + bbox=dict(boxstyle='round,pad=0.2', fc='yellow', alpha=0.5)) + + ax.update_datalim(tsne_results) + ax.autoscale() + + ax.set_title(f't-SNE of Latent States (Value as Color) at Step {step_counter}', fontsize=16) + ax.set_xlabel('t-SNE dimension 1', fontsize=12) + ax.set_ylabel('t-SNE dimension 2', fontsize=12) + + # 添加colorbar来解释背景点的颜色 + norm = plt.Normalize(values_np.min(), values_np.max()) + sm = plt.cm.ScalarMappable(cmap='viridis', norm=norm) + sm.set_array([]) + fig.colorbar(sm, ax=ax, label='Predicted Value') + + # --- 修改部分:检查文件是否存在,如果存在则添加时间戳 --- + # 1. 
构建基础路径 + # base_save_path = ( + # f'/mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/unizero_mspacman_analyze/' + # f'tsne_with_vr_{self.config.optim_type}_lr{self.config.learning_rate}_step_{step_counter}.png' + # ) + base_save_path = ( + f'/mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/unizero_mspacman_analyze/' + f'tsne_with_vr_{self.config.optim_type}_step_{step_counter}.png' + ) + + # 2. 检查文件是否存在,并确定最终保存路径 + if os.path.exists(base_save_path): + # 如果文件已存在,则生成时间戳并附加到文件名 + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + path_root, path_ext = os.path.splitext(base_save_path) + save_path = f"{path_root}_{timestamp}{path_ext}" + print(f"File '{base_save_path}' already exists. Saving to new path with timestamp.") + else: + # 如果文件不存在,则使用原始路径 + save_path = base_save_path + + # 3. 保存图像 + plt.savefig(save_path) + plt.close(fig) # 明确关闭图形对象 + print(f"t-SNE plot with V/R annotations saved to {save_path}") + + def _debug_check_for_stale_pointers(self, env_id: int, current_key: Any, index_to_be_written: int): + """ + 调试函数:检查即将被写入的索引是否存在过时的指针。 + """ + # 获取对应环境的指针映射表 + cache_map = self.past_kv_cache_init_infer_envs[env_id] + + # 遍历映射表中的所有条目 (旧哈希 -> 旧索引) + for old_key, old_index in cache_map.items(): + # 检查条件: + # 1. 旧索引 == 即将被覆盖的索引 + # 2. 旧哈希 != 当前要写入的新哈希 + if old_index == index_to_be_written and old_key != current_key: + # 如果条件满足,说明我们找到了一个过时指针 + self.stale_pointer_detections += 1 + + # 打印详细的调试信息 + print("="*60) + print(f"!!! INIT BUG CONDITION DETECTED (Detection #{self.stale_pointer_detections}) !!!") + print(f" Environment ID: {env_id}") + print(f" Pool Index to be overwritten: {index_to_be_written}") + print(f" New state hash being written: '{current_key}'") + print(f" Stale pointer found in cache_map: '{old_key}' also points to index {old_index}.") + print(f" This means the data for '{old_key}' is about to be lost, but its pointer remains.") + print(f" Current cache_map size: {len(cache_map)}") + print("="*60) + + # 找到一个就足够了,可以提前退出循环以提高效率 + return + + def _debug_check_for_stale_pointers_recur(self, current_key: Any, index_to_be_written: int): + """ + 调试函数:检查 recurrent cache 中是否存在过时的指针。 + """ + cache_map = self.past_kv_cache_recurrent_infer + + for old_key, old_index in cache_map.items(): + if old_index == index_to_be_written and old_key != current_key: + self.stale_pointer_detections_recur += 1 + print("="*60) + print(f"!!! RECURRENT BUG DETECTED (Detection #{self.stale_pointer_detections_recur}) !!!") + print(f" Pool Index to be overwritten: {index_to_be_written}") + print(f" New state hash being written: '{current_key}'") + print(f" Stale pointer found: '{old_key}' also points to index {old_index}.") + print("="*60) + return + def _get_final_norm(self, norm_option: str) -> nn.Module: """ Return the corresponding normalization module based on the specified normalization option. 
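# Editorial example (not part of the original diff) of the stale-pointer condition that the
# two debug helpers above detect. Assume a ring buffer of size 3 and writes in the order
# 'hashA' -> index 0, 'hashB' -> index 1, 'hashC' -> index 2; the next write wraps around and
# puts 'hashD' at index 0. Unless the mapping entry 'hashA' -> 0 is evicted at that moment, a
# later lookup of 'hashA' still returns index 0 and silently reads the KV cache that now
# belongs to 'hashD', which is exactly the cross-contamination counted by
# self.stale_pointer_detections / self.stale_pointer_detections_recur.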
@@ -160,6 +597,8 @@ def _get_final_norm(self, norm_option: str) -> nn.Module: return nn.LayerNorm(self.config.embed_dim, eps=1e-5) elif norm_option == 'SimNorm': return SimNorm(simnorm_dim=self.config.group_size) + elif norm_option == 'L2Norm': + return L2Norm(eps=1e-6) else: raise ValueError(f"Unsupported final_norm_option_in_obs_head: {norm_option}") @@ -264,7 +703,7 @@ def custom_copy_kv_cache_to_shared_recur(self, src_kv: KeysValues) -> int: dst_layer._v_cache._size = src_layer._v_cache._size index = self.shared_pool_index - self.shared_pool_index = (self.shared_pool_index + 1) % self.shared_pool_size + self.shared_pool_index = (self.shared_pool_index + 1) % self.shared_pool_size_recur return index @@ -301,15 +740,83 @@ def _initialize_patterns(self) -> None: self.value_policy_tokens_pattern = torch.zeros(self.config.tokens_per_block) self.value_policy_tokens_pattern[-2] = 1 - def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head: + # def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head: + # """Create head modules for the transformer.""" + # modules = [ + # nn.Linear(self.config.embed_dim, self.config.embed_dim), + # nn.GELU(approximate='tanh'), + # nn.Linear(self.config.embed_dim, output_dim) + # ] + # if norm_layer: + # modules.append(norm_layer) + # return Head( + # max_blocks=self.config.max_blocks, + # block_mask=block_mask, + # head_module=nn.Sequential(*modules) + # ) + + #_create_head_muzero + # def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None, use_norm_in_head: bool = False) -> Head: + # """Create head modules for the transformer.""" + + + # modules = MLP_V2( + # in_channels=self.config.embed_dim, + # hidden_channels=[32], + # out_channels=output_dim, + # # activation=nn.GELU(approximate='tanh'), + # activation=nn.ReLU(inplace=True), + # # norm_type='BN', + # norm_type='LN', + # output_activation=False, + # output_norm=False, + # last_linear_layer_init_zero=True + # ) + + # return Head( + # max_blocks=self.config.max_blocks, + # block_mask=block_mask, + # head_module=nn.Sequential(*modules) + # ) + + def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None, use_norm_in_head: bool = False) -> Head: """Create head modules for the transformer.""" + # modules = [ + # nn.Linear(self.config.embed_dim, self.config.embed_dim), + # ] + + # ==================== 头部优化:防御性设计 ==================== + # 在头部入口处增加一个LayerNorm,以防止输入饱和。 modules = [ + nn.LayerNorm(self.config.embed_dim), # <-- 核心优化! # TODO nn.Linear(self.config.embed_dim, self.config.embed_dim), - nn.GELU(approximate='tanh'), - nn.Linear(self.config.embed_dim, output_dim) + # ==================== 核心修复点 ==================== + # 在激活函数GELU之前,对第一个线性层的输出进行归一化 + # 这可以防止激活值爆炸或饱和 + nn.LayerNorm(self.config.embed_dim), # 2. <-- 新增!稳定内部激活 + # ====================================================== ] + # ============================================================= + + + # ==================== PROPOSED FIX ==================== + # Add a LayerNorm after the first linear layer and before the activation. + # This stabilizes the activations within the head, preventing drift. 
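        # Editorial summary (not part of the original diff): with the two LayerNorms enabled
        # above, the head built below is
        #     LayerNorm(E) -> Linear(E, E) -> LayerNorm(E) -> GELU -> Linear(E, output_dim) [-> norm_layer]
        # i.e. a pre-norm two-layer MLP whose hidden activations are re-normalised before the
        # non-linearity, bounding the scale of what GELU (and the output projection) sees.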
+ # if use_norm_in_head: # TODO + # modules.append(nn.LayerNorm(self.config.embed_dim)) + # ====================================================== + + modules.extend([ + nn.GELU(approximate='tanh'), + # nn.ReLU(inplace=True), + nn.Linear(self.config.embed_dim, output_dim), + # 最后的LayerNorm可以保留,也可以视情况移除,因为它主要影响输出的尺度 + # nn.LayerNorm(output_dim) + ]) + if norm_layer: modules.append(norm_layer) + return Head( max_blocks=self.config.max_blocks, block_mask=block_mask, @@ -338,6 +845,8 @@ def _create_head_cont(self, block_mask: torch.Tensor, output_dim: int, norm_laye def _initialize_last_layer(self) -> None: """Initialize the last linear layer.""" last_linear_layer_init_zero = True # TODO + # last_linear_layer_init_zero = False # TODO===== + if last_linear_layer_init_zero: if self.continuous_action_space: module_to_initialize = [self.head_value, self.head_rewards, self.head_observations] @@ -354,8 +863,61 @@ def _initialize_last_layer(self) -> None: def _initialize_cache_structures(self) -> None: """Initialize cache structures for past keys and values.""" from collections import defaultdict - self.past_kv_cache_recurrent_infer = defaultdict(dict) - self.past_kv_cache_init_infer_envs = [defaultdict(dict) for _ in range(self.env_num)] + # self.past_kv_cache_recurrent_infer = defaultdict(dict) + # 使用 LRUCache 替换 defaultdict,并同步容量 + + # ========================= 核心修复与注释 (Recurrent Infer) ========================= + # 问题: recurrent_infer 缓存同样存在 LRUCache 与环形缓冲区逻辑不匹配的问题。 + # + # 修复方案: + # 1. 将 past_kv_cache_recurrent_infer 从 LRUCache 改为标准字典。 + # 2. 引入辅助列表 pool_idx_to_key_map_recur_infer 来维护反向映射。 + # 这确保了在覆写 recurrent 数据池中的条目时,可以同步删除旧的指针。 + + self.past_kv_cache_recurrent_infer = {} + self.pool_idx_to_key_map_recur_infer = [None] * self.shared_pool_size_recur + # ========================== 修复结束 ========================== + + # self.past_kv_cache_init_infer_envs = [defaultdict(dict) for _ in range(self.env_num)] + + # TODO(pu): 非常重要 self.past_kv_cache_init_infer_envs应该改成和(shared_pool_size_init)完全一致, + # 目前是将shared_pool_size_init设置为segment_length以在一次collect后 清空self.past_kv_cache_init_infer_envs + # 来避免self.past_kv_cache_init_infer_envs里面存有kv索引过期的问题 + + # ========================= 核心修复与注释 ========================= + # 原来的实现: + # self.past_kv_cache_init_infer_envs = [defaultdict(dict) for _ in range(self.env_num)] + # + # 问题: defaultdict 会无限增长,并且不会自动删除与环形缓冲区中 + # 被覆盖数据相关的旧“指针”,导致Episode内部的缓存污染。 + # + # 修复方案: + # 使用我们定义的LRUCache,其容量与环形缓冲区的大小(shared_pool_size_init)完全一致。 + # + # 效果: + # 1. 自动淘汰: 当添加第 N+1 个新条目时,LRUCache会自动删除最旧的那个条目。 + # 2. 生命周期同步: 这确保了“指针字典”中的映射关系,与“数据池”中实际存储的数据 + # 完全同步。当数据池的索引0被新数据覆盖时,指向旧索引0的指针也已被自动清除。 + # 3. 杜绝污染: 从根本上解决了Episode内部的状态哈希碰撞问题。 + + # self.past_kv_cache_init_infer_envs = [LRUCache(self.shared_pool_size_init-1) for _ in range(self.env_num)] + # ========================== 修复结束 ========================== + + # ========================= 核心修复与注释 ========================= + # 问题: LRUCache 的淘汰逻辑(基于访问顺序)与环形缓冲区的覆写逻辑(基于写入顺序)不匹配,导致指针过时。 + # + # 修复方案: + # 1. 使用一个标准的字典 `past_kv_cache_init_infer_envs` 来存储 {state_hash -> pool_index}。 + # 2. 
引入一个辅助列表 `pool_idx_to_key_map_init_envs` 来维护反向映射 {pool_index -> state_hash}。 + # + # 效果: + # 在向环形缓冲区的某个索引写入新数据之前,我们可以通过辅助列表立即找到即将被覆盖的旧 state_hash, + # 并从主字典中精确地删除这个过时的条目。这确保了字典和数据池的完全同步。 + + self.past_kv_cache_init_infer_envs = [{} for _ in range(self.env_num)] + # 辅助数据结构,用于反向查找:pool_index -> key + self.pool_idx_to_key_map_init_envs = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] + # ========================== 修复结束 ========================== self.keys_values_wm_list = [] self.keys_values_wm_size_list = [] @@ -369,10 +931,12 @@ def _initialize_projection_input_dim(self) -> None: def _initialize_statistics(self) -> None: """Initialize counters for hit count and query count statistics.""" - self.hit_count = 0 - self.total_query_count = 0 + self.recur_hit_count = 0 + self.recur_total_query_count = 0 self.length_largethan_maxminus5_context_cnt = 0 self.length_largethan_maxminus7_context_cnt = 0 + self.length_largethan_contextminus3_cnt = 0 + self.root_hit_cnt = 0 self.root_total_query_cnt = 0 @@ -506,6 +1070,9 @@ def forward( # Process observation embeddings if available. if "obs_embeddings" in obs_embeddings_or_act_tokens: obs_embeddings = obs_embeddings_or_act_tokens["obs_embeddings"] + if self.config.entry_norm: + obs_embeddings = self.entry_norm(obs_embeddings) # <-- 新增 TODO + # If the observation embeddings have 2 dimensions, expand them to include a time dimension. if len(obs_embeddings.shape) == 2: obs_embeddings = obs_embeddings.unsqueeze(1) @@ -568,6 +1135,10 @@ def forward( num_steps = act_tokens.size(1) # Convert action tokens to embeddings using the action embedding table. act_embeddings = self.act_embedding_table(act_tokens) + if self.config.entry_norm: + act_embeddings = self.entry_norm(act_embeddings) # <-- 新增 TODO + + if not self.config.rotary_emb: sequences = self._add_position_embeddings( act_embeddings, prev_steps, num_steps, kvcache_independent, @@ -617,17 +1188,42 @@ def forward( else: raise ValueError("Input dictionary must contain one of 'obs_embeddings', 'act_tokens', or 'obs_embeddings_and_act_tokens'.") + # # ==================== 核心修复:应用入口归一化 ==================== + # # 在添加位置编码之前,对拼接后的序列进行LayerNorm + # sequences = self.entry_norm(sequences) + # # ================================================================= + # Pass the sequence through the transformer. x = self._transformer_pass( sequences, past_keys_values, kvcache_independent, valid_context_lengths, start_pos=start_pos_adjusted ) + # print(f"x.mean(): {x.mean().item():.6f}, x.std(): {x.std().item():.6f}") + # Generate logits for various components. logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) + if self.use_temperature_scaling and not self.training: + # ==================== 关键修改点 ==================== + # collect/eval/计算atrget时,应用可学习的温度. 
在训练时是在compute loss中执行temperature_scaling + with torch.no_grad(): + T_policy = 1.0 + F.softplus(self.log_temp_policy) + T_value = 1.0 + F.softplus(self.log_temp_value) + T_reward = 1.0 + F.softplus(self.log_temp_reward) + + logits_policy /= (T_policy + 1e-8) + logits_value /= (T_value + 1e-8) + logits_rewards /= (T_reward + 1e-8) + # ==================================================== + + # print(f"logits_observations.mean(): {logits_observations.mean().item():.6f}") + # print(f"logits_rewards.mean(): {logits_rewards.mean().item():.6f}") + # print(f"logits_policy.mean(): {logits_policy.mean().item():.6f}") + # print(f"logits_value.mean(): {logits_value.mean().item():.6f}") + # The 'logits_ends' is intentionally set to None. return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value) @@ -693,6 +1289,9 @@ def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_step obs_act = torch.cat([obs, act], dim=1) obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act + if self.config.entry_norm: + obs_act_embeddings = self.entry_norm(obs_act_embeddings) # <-- 新增 TODO + return_result = obs_act_embeddings if not self.config.rotary_emb: return_result += self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)) @@ -724,7 +1323,12 @@ def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps): act = act_embeddings[:, i, 0, :].unsqueeze(1) obs_act = torch.cat([obs, act], dim=1) obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act - + + + if self.config.entry_norm: + obs_act_embeddings = self.entry_norm(obs_act_embeddings) # <-- 新增 TODO + + return_result = obs_act_embeddings if not self.config.rotary_emb: return_result += self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)) @@ -742,6 +1346,7 @@ def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, va Returns: - torch.Tensor: Transformer output. """ + if kvcache_independent: x = [self.transformer(sequences[k].unsqueeze(0), past_kv, valid_context_lengths=valid_context_lengths[k].unsqueeze(0), start_pos=start_pos) for k, past_kv in @@ -843,6 +1448,11 @@ def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTens if matched_value is not None: # If a matching value is found, add it to the list self.root_hit_cnt += 1 + # if self.root_total_query_cnt > 0 and self.root_total_query_cnt % 50 == 0: + # self.root_hit_freq = self.root_hit_cnt / self.root_total_query_cnt + # print('root total_query_count:', self.root_total_query_cnt) + # print('root root_hit_freq:', self.root_hit_freq) + # NOTE: deepcopy is needed because forward modifies matched_value in place self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) self.keys_values_wm_size_list.append(matched_value.size) @@ -934,7 +1544,18 @@ def forward_initial_inference(self, obs_act_dict, start_pos: int = 0): """ # UniZero has context in the root node outputs_wm, latent_state = self.reset_for_initial_inference(obs_act_dict, start_pos) + # TODO(pu): 由于预测误差的存在,不clear,也很可能不能检索到上次mcts 树搜索中的节点 + # 所有collect env公用应该也是合理的,不同环境很难遇到完全一致的预测的latent state? 
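# (English gist of the TODO above: even without clearing, prediction error makes it unlikely
#  that a later MCTS search would retrieve nodes cached by the previous search; sharing the
#  recurrent cache across collect envs should also be acceptable, since different environments
#  are unlikely to produce exactly the same predicted latent state.)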
+ # self.past_kv_cache_recurrent_infer.clear() + + # ==================== 正确的修复位置 ==================== + # 在每次新的MCTS搜索(即调用initial_inference)开始时, + # 清除上一次搜索遗留的 recurrent (MCTS) 缓存。 self.past_kv_cache_recurrent_infer.clear() + if hasattr(self, 'pool_idx_to_key_map_recur_infer'): + # 同时也要清理辅助映射表 + self.pool_idx_to_key_map_recur_infer = [None] * self.shared_pool_size_recur + # ========================================================= return (outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value) @@ -966,20 +1587,26 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0, token = action.reshape(-1, self.action_space_size) # ======= Print statistics for debugging ============= - # min_size = min(self.keys_values_wm_size_list) - # if min_size >= self.config.max_tokens - 5: - # self.length_largethan_maxminus5_context_cnt += len(self.keys_values_wm_size_list) - # if min_size >= self.config.max_tokens - 7: - # self.length_largethan_maxminus7_context_cnt += len(self.keys_values_wm_size_list) - # if self.total_query_count > 0 and self.total_query_count % 10000 == 0: - # self.hit_freq = self.hit_count / self.total_query_count - # print('total_query_count:', self.total_query_count) - # length_largethan_maxminus5_context_cnt_ratio = self.length_largethan_maxminus5_context_cnt / self.total_query_count - # print('recurrent largethan_maxminus5_context:', self.length_largethan_maxminus5_context_cnt) - # print('recurrent largethan_maxminus5_context_ratio:', length_largethan_maxminus5_context_cnt_ratio) - # length_largethan_maxminus7_context_cnt_ratio = self.length_largethan_maxminus7_context_cnt / self.total_query_count - # print('recurrent largethan_maxminus7_context_ratio:', length_largethan_maxminus7_context_cnt_ratio) - # print('recurrent largethan_maxminus7_context:', self.length_largethan_maxminus7_context_cnt) + min_size = min(self.keys_values_wm_size_list) + # # if min_size >= self.config.max_tokens - 5: + # # self.length_largethan_maxminus5_context_cnt += len(self.keys_values_wm_size_list) + # # if min_size >= self.config.max_tokens - 7: + # # self.length_largethan_maxminus7_context_cnt += len(self.keys_values_wm_size_list) + # if min_size >= self.config.context_length - 3: + # self.length_largethan_contextminus3_cnt += len(self.keys_values_wm_size_list) + # # if self.recur_total_query_count > 0 and self.recur_total_query_count % 10000 == 0: + # if self.recur_total_query_count > 0 and self.recur_total_query_count % 1000 == 0: + # self.hit_freq = self.recur_hit_count / self.recur_total_query_count + # print('recur total_query_count:', self.recur_total_query_count) + # # length_largethan_maxminus5_context_cnt_ratio = self.length_largethan_maxminus5_context_cnt / self.recur_total_query_count + # # print('recurrent largethan_maxminus5_context:', self.length_largethan_maxminus5_context_cnt) + # # print('recurrent largethan_maxminus5_context_ratio:', length_largethan_maxminus5_context_cnt_ratio) + # # length_largethan_maxminus7_context_cnt_ratio = self.length_largethan_maxminus7_context_cnt / self.recur_total_query_count + # # print('recurrent largethan_maxminus7_context_ratio:', length_largethan_maxminus7_context_cnt_ratio) + # # print('recurrent largethan_maxminus7_context:', self.length_largethan_maxminus7_context_cnt) + # length_largethan_contextminus3_cnt_ratio = self.length_largethan_contextminus3_cnt / self.recur_total_query_count + # print('recurrent length_largethan_contextminus3_cnt_ratio:', 
length_largethan_contextminus3_cnt_ratio) + # print('recurrent length_largethan_contextminus3_cnt:', self.length_largethan_contextminus3_cnt) # Trim and pad kv_cache: modify self.keys_values_wm in-place self.keys_values_wm_size_list = self.trim_and_pad_kv_cache(is_init_infer=False) @@ -1024,7 +1651,7 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0, return (outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value) - + # TODO: precompute_pos_emb_diff_kv 与 update_cache_context 的硬编码不匹配,collect_env_num=1应该没有问题 def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: """ Adjusts the key-value cache for each environment to ensure they all have the same size. @@ -1098,6 +1725,8 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde context_length = self.context_length if not is_init_infer: + + # ============ Internal Node ============ # Retrieve KV from global KV cache self.keys_values_wm to single environment KV cache self.keys_values_wm_single_env, ensuring correct positional encoding current_max_context_length = max(self.keys_values_wm_size_list_current) @@ -1211,13 +1840,65 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 if is_init_infer: - # Store the latest key-value cache for initial inference + # TODO + # ==================== 主动淘汰修复逻辑 ==================== + # 1. 获取即将被覆写的物理索引 + index_to_write = self.shared_pool_index_init_envs[i] + # 2. 使用辅助列表查找该索引上存储的旧的 key + old_key_to_evict = self.pool_idx_to_key_map_init_envs[i][index_to_write] + # 3. 如果存在旧 key,就从主 cache map 中删除它 + if old_key_to_evict is not None: + # 确保要删除的键确实存在,避免意外错误 + if old_key_to_evict in self.past_kv_cache_init_infer_envs[i]: + del self.past_kv_cache_init_infer_envs[i][old_key_to_evict] + + # 现在可以安全地写入新数据了 cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + + # 4. 在主 cache map 和辅助列表中同时更新新的映射关系 self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + self.pool_idx_to_key_map_init_envs[i][index_to_write] = cache_key + + # 调用调试函数进行检查 + self._debug_check_for_stale_pointers(env_id=i, current_key=cache_key, index_to_be_written=index_to_write) + # ============================================================ + + # Store the latest key-value cache for initial inference + # cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + # self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index else: - # Store the latest key-value cache for recurrent inference + # TODO 获取要存入的cache的某个唯一标识,例如tensor的和 + # cache_to_store = self.keys_values_wm_single_env._keys_values[0]._k_cache._cache + # cache_sum = torch.sum(cache_to_store).item() + # cache_shape = cache_to_store.shape + # print(f"[CACHE WRITE] Storing for key={cache_key}, cache_shape={cache_shape}, cache_sum={cache_sum:.4f}") + + # ==================== RECURRENT INFER FIX ==================== + # 1. 获取即将被覆写的物理索引 + index_to_write = self.shared_pool_index + # 2. 使用辅助列表查找该索引上存储的旧的 key + old_key_to_evict = self.pool_idx_to_key_map_recur_infer[index_to_write] + # 3. 如果存在旧 key,就从主 cache map 中删除它 + if old_key_to_evict is not None: + if old_key_to_evict in self.past_kv_cache_recurrent_infer: + del self.past_kv_cache_recurrent_infer[old_key_to_evict] + + # 4. 现在可以安全地写入新数据了 cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + + # 5. 
Update the new mapping in both the main cache map and the auxiliary list at the same time.
+                self.past_kv_cache_recurrent_infer[cache_key] = cache_index
+                self.pool_idx_to_key_map_recur_infer[index_to_write] = cache_key
+                # ============================================================
+
+                # ==================== DEBUG CODE INSERTION ====================
+                # Call the debug helper to check for stale pointers.
+                self._debug_check_for_stale_pointers_recur(current_key=cache_key, index_to_be_written=index_to_write)
+                # ============================================================
+
+                # Store the latest key-value cache for recurrent inference
+                # cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env)
+                # self.past_kv_cache_recurrent_infer[cache_key] = cache_index

    def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int,
@@ -1237,7 +1918,7 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int,
        - list: Sizes of the key-value caches for each environment.
        """
        for index in range(ready_env_num):
-            self.total_query_count += 1
+            self.recur_total_query_count += 1
            state_single_env = latent_state[index]  # latent_state[i] is np.array
            cache_key = hash_state(state_single_env)
@@ -1249,20 +1930,50 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int,
            cache_index = self.past_kv_cache_init_infer_envs[index].get(cache_key)
            if cache_index is not None:
                matched_value = self.shared_pool_init_infer[index][cache_index]
+
+                # TODO
+                # retrieved_cache = matched_value._keys_values[0]._k_cache._cache
+                # retrieved_sum = torch.sum(retrieved_cache).item()
+                # retrieved_shape = retrieved_cache.shape
+                # print(f"[CACHE HIT] Found for key={cache_key}, retrieved_shape={retrieved_shape}, retrieved_sum={retrieved_sum:.4f}")
+
            else:
                matched_value = None

            # If not found, try to retrieve from past_kv_cache_recurrent_infer
+            # if matched_value is None:
+            #     matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)]
+
+            # ==================== CORE FIX ====================
+            # Step 2: only if nothing was found in init_infer, look it up in the recurrent_infer cache.
            if matched_value is None:
-                matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)]
+                # 2.1 Safely fetch the index from the dict; it may return None.
+                recur_cache_index = self.past_kv_cache_recurrent_infer.get(cache_key)
+                # 2.2 Only if the index is valid (not None), use it to retrieve the value from the physical pool.
+                if recur_cache_index is not None:
+                    matched_value = self.shared_pool_recur_infer[recur_cache_index]
+
+                if recur_cache_index is None:
+                    print(f"[CACHE MISS] Not found for key={cache_key} in recurrent infer. Generating new cache.")
+
+            # =================================================
+            # # TODO
+            # retrieved_cache = matched_value._keys_values[0]._k_cache._cache
+            # retrieved_sum = torch.sum(retrieved_cache).item()
+            # retrieved_shape = retrieved_cache.shape
+            # print(f"[CACHE HIT] Found for key={cache_key}, retrieved_shape={retrieved_shape}, retrieved_sum={retrieved_sum:.4f}")
+
            if matched_value is not None:
                # If a matching cache is found, add it to the lists
-                self.hit_count += 1
+                self.recur_hit_count += 1
                # Perform a deep copy because the transformer's forward pass might modify matched_value in-place
                self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value))
                self.keys_values_wm_size_list.append(matched_value.size)
            else:
+                # print(f"[CACHE MISS] Not found for key={cache_key}.
Generating new cache.") + # If no matching cache is found, generate a new one using zero reset self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values( n=1, max_tokens=self.context_length @@ -1295,6 +2006,23 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # Encode observations into latent state representations obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations']) + # # ======================= 在这里插入分析代码 ======================= + # # 从kwargs获取全局step,假设您在训练循环中传入了它 + global_step = kwargs.get('global_step', 0) + # current_policy_label_eps = kwargs.get('current_policy_label_eps', 0) + current_policy_label_eps = kwargs["current_policy_label_eps"] + + + # # 为了避免影响训练,可以控制调用频率 + # if global_step % 10 == 0: # 每100个training step分析一次 + # self._analyze_latent_representation( + # latent_states=obs_embeddings, + # timesteps=batch['timestep'], + # game_states=batch['observations'], # 传入原始图像 + # step_counter=global_step + # ) + # # ================================================================= + # ========= for visual analysis ========= # Uncomment the lines below for visual analysis in Pong # self.plot_latent_tsne_each_and_all_for_pong(obs_embeddings, suffix='pong_H10_H4_tsne') @@ -1317,18 +2045,111 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar else: dormant_ratio_encoder = torch.tensor(0.) - # Calculate the L2 norm of the latent state roots - latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean() - # Action tokens if self.continuous_action_space: act_tokens = batch['actions'] else: act_tokens = rearrange(batch['actions'], 'b l -> b l 1') + with torch.no_grad(): + # Calculate the L2 norm of the latent state roots + latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean() + # Calculate the L2 norm of the latent action + latent_action_l2_norms = torch.norm(self.act_embedding_table(act_tokens), p=2, dim=2).mean() + + if self.config.latent_norm_loss: + # ==================== L2惩罚损失计算(最终修复版 v2) ==================== + # 1. 计算每个 latent_state 向量的L2范数的平方。 + # 根据调试信息,obs_embeddings shape: (B*L, 1, E) + # 所以 latent_norm_sq shape: (B*L, 1) + latent_norm_sq = torch.norm(obs_embeddings, p=2, dim=-1).pow(2) + # 2. 获取源掩码。 + # 根据调试信息,mask_source shape: (B, L) + mask_source = batch['mask_padding'] + # 3. 将源掩码从 (B, L) reshape 为 (B*L, 1),以匹配 latent_norm_sq 的形状。 + # 这是解决维度不匹配错误的关键。 + # 我们使用 view(-1, 1) 来实现这个变形。 + correct_mask = mask_source.contiguous().view(-1, 1) + # 4. 检查变形后的形状是否匹配。 + # 这是一个防御性编程,确保两个张量的第一个维度是相同的。 + if latent_norm_sq.shape[0] != correct_mask.shape[0]: + # 如果形状不匹配,打印错误信息并抛出异常,这能帮助我们更快地定位未来可能出现的新问题。 + raise RuntimeError( + f"Shape mismatch for L2 norm loss calculation! " + f"latent_norm_sq shape: {latent_norm_sq.shape}, " + f"but correct_mask shape after reshape is: {correct_mask.shape}. " + f"Original mask_source shape was: {mask_source.shape}" + ) + # 5. 直接进行逐元素乘法。因为现在它们的形状都是 (B*L, 1),所以可以安全相乘。 + masked_latent_norm_sq = latent_norm_sq * correct_mask + # 6. 计算平均损失。分母是掩码中所有“1”的总和,代表有效的元素数量。 + # 增加一个极小值 epsilon (1e-8) 防止分母为零。 + latent_norm_loss = masked_latent_norm_sq.sum() / (correct_mask.sum() + 1e-8) + # ================================================================= + else: + latent_norm_loss = torch.tensor(0.) 
+ + # Forward pass to obtain predictions for observations, rewards, and policies outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, start_pos=start_pos) + + + # [新增] 从模型输出中获取中间张量 x,并分离计算图 + intermediate_tensor_x = outputs.output_sequence.detach() + + # TODO============ + # ======================= 在这里插入分析代码 ======================= + # if global_step > 0 and global_step % 1000 == 0: + # if global_step > 0 and global_step % 5000 == 0: + # if global_step >= 0 and global_step % 5000 == 0: # 5k + + if global_step >= 0 and global_step % 10000 == 0: # 20k + with torch.no_grad(): + # 将logits转换为标量值 + # 注意:outputs的形状是(B, L, E),我们需要reshape + batch_size, seq_len = batch['actions'].shape[0], batch['actions'].shape[1] + + pred_val_logits = outputs.logits_value.view(batch_size * seq_len, -1) + pred_rew_logits = outputs.logits_rewards.view(batch_size * seq_len, -1) + + scalar_values = inverse_scalar_transform_handle(pred_val_logits).squeeze(-1) + scalar_rewards = inverse_scalar_transform_handle(pred_rew_logits).squeeze(-1) + + self._analyze_latent_representation( + latent_states=obs_embeddings, + timesteps=batch['timestep'], + game_states=batch['observations'], + predicted_values=scalar_values, # 传入预测的Value + predicted_rewards=scalar_rewards, # 传入预测的Reward + step_counter=global_step + ) + + # ================================================================= + + if self.config.use_priority: + # ==================== START MODIFICATION 5 ==================== + # Calculate value_priority, similar to MuZero. + with torch.no_grad(): + # 1. Get the predicted value logits for the first step of the sequence (t=0). + # The shape is (B, support_size). + predicted_value_logits_step0 = outputs.logits_value[:, 0, :] + + # 2. Convert the categorical prediction to a scalar value. + # The shape becomes (B, 1). + predicted_scalar_value_step0 = inverse_scalar_transform_handle(predicted_value_logits_step0) + + # 3. Get the target scalar value for the first step from the batch. + # The shape is (B, num_unroll_steps), so we take the first column. + target_scalar_value_step0 = batch['scalar_target_value'][:, 0] + + # 4. Calculate the L1 loss (absolute difference) between prediction and target. + # This is the priority. We use reduction='none' to get per-sample priorities. + value_priority = F.l1_loss(predicted_scalar_value_step0.squeeze(-1), target_scalar_value_step0, reduction='none') + # ===================== END MODIFICATION 5 ===================== + else: + value_priority = torch.tensor(0.) 
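# The per-sample L1 priorities computed above are the usual input to prioritized replay.
# A hedged sketch of the standard PER conversion to sampling probabilities (generic PER math
# with an assumed exponent, not a specific LightZero buffer API):
#
#     priorities = value_priority.detach() + 1e-6      # keep every sample reachable
#     probs = priorities.pow(0.6)                      # alpha = 0.6 is a common choice (assumption)
#     probs = probs / probs.sum()                      # sampling distribution over the batch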
+ if self.obs_type == 'image': # Reconstruct observations from latent state representations # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) @@ -1434,9 +2255,10 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_fail_episode') # ========== for visualization ========== - # For training stability, use target_tokenizer to compute the true next latent state representations with torch.no_grad(): + # For training stability, use target_tokenizer to compute the true next latent state representations target_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations']) + # target_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations']) # Compute labels for observations, rewards, and ends labels_observations, labels_rewards, _ = self.compute_labels_world_model(target_obs_embeddings, @@ -1474,10 +2296,15 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar loss_obs = (loss_obs * mask_padding_expanded) # Compute labels for policy and value - labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], + labels_value, labels_policy = self.compute_labels_world_model_value_policy(batch['target_value'], batch['target_policy'], batch['mask_padding']) + # --- NEW: Apply label smoothing to policy target --- + if current_policy_label_eps > 0: + # Assumes target_policy is a probability distribution (sums to 1) + labels_policy = (1.0 - current_policy_label_eps) * labels_policy + current_policy_label_eps / self.action_space_size + # Compute losses for rewards, policy, and value loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') @@ -1556,6 +2383,21 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar discounted_orig_policy_loss = (orig_policy_loss.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() discounted_policy_entropy = (policy_entropy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + # 为了让外部的训练循环能够获取encoder的输出,我们将其加入返回字典 + # 使用 .detach() 是因为这个张量仅用于后续的clip操作,不应影响梯度计算 + detached_obs_embeddings = obs_embeddings.detach() + + # 在 return 语句之前,获取当前的温度值 + temp_value, temp_reward, temp_policy = 1.0, 1.0, 1.0 + if self.use_temperature_scaling: + # temp_value = torch.clamp(self.log_temp_value.exp(), min=0.1).item() + # temp_reward = torch.clamp(self.log_temp_reward.exp(), min=0.1).item() + # temp_policy = torch.clamp(self.log_temp_policy.exp(), min=0.1).item() + + temp_value = 1.0 + F.softplus(self.log_temp_value).item() + temp_reward = 1.0 + F.softplus(self.log_temp_reward).item() + temp_policy = 1.0 + F.softplus(self.log_temp_policy).item() + if self.continuous_action_space: return LossWithIntermediateLosses( latent_recon_loss_weight=self.latent_recon_loss_weight, @@ -1575,9 +2417,24 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar dormant_ratio_encoder=dormant_ratio_encoder, dormant_ratio_world_model=dormant_ratio_world_model, latent_state_l2_norms=latent_state_l2_norms, + latent_action_l2_norms=latent_action_l2_norms, policy_mu=mu, policy_sigma=sigma, target_sampled_actions=target_sampled_actions, + 
latent_norm_loss=latent_norm_loss, # 新增 + value_priority=value_priority, + obs_embeddings=detached_obs_embeddings, # <-- 新增 + + # ==================== 新增日志项 ==================== + temperature_value=temp_value, + temperature_reward=temp_reward, + temperature_policy=temp_policy, + # =================================================== + + # ==================== [修改] 新增监控张量 ==================== + intermediate_tensor_x=intermediate_tensor_x, + # ========================================================== + ) else: return LossWithIntermediateLosses( @@ -1598,6 +2455,31 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar dormant_ratio_encoder=dormant_ratio_encoder, dormant_ratio_world_model=dormant_ratio_world_model, latent_state_l2_norms=latent_state_l2_norms, + latent_action_l2_norms=latent_action_l2_norms, + latent_norm_loss=latent_norm_loss, # 新增 + value_priority=value_priority, + obs_embeddings=detached_obs_embeddings, # <-- 新增 + + logits_value_mean=outputs.logits_value.mean(), + logits_value_max=outputs.logits_value.max(), + logits_value_min=outputs.logits_value.min(), + + logits_policy_mean=outputs.logits_policy.mean(), + logits_policy_max=outputs.logits_policy.max(), + logits_policy_min=outputs.logits_policy.min(), + logits_value=outputs.logits_value.detach(), # 使用detach(),因为它仅用于分析和裁剪,不参与梯度计算 + logits_reward=outputs.logits_rewards.detach(), + logits_policy=outputs.logits_policy.detach(), + + # ==================== 新增日志项 ==================== + temperature_value=temp_value, + temperature_reward=temp_reward, + temperature_policy=temp_policy, + # =================================================== + + # ==================== [修改] 新增监控张量 ==================== + intermediate_tensor_x=intermediate_tensor_x, + # ========================================================== ) @@ -1742,12 +2624,56 @@ def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): logits = getattr(outputs, f'logits_{element}') + # 10.0 是一个经验值,可以调整。它足以产生非常尖锐的分布,但不会到-18那么夸张。 + # logits = torch.clamp(logits, min=-10.0, max=10.0) # TODO + if torch.isnan(logits).any(): raise ValueError(f"NaN detected in outputs for batch {batch} and element '{element}'") if torch.isnan(labels).any(): raise ValueError(f"NaN detected in labels_value for batch {batch} and element '{element}'") + # TODO + # # ==================== 核心修复:温度缩放 ==================== + # # 仅对 value 和 reward 应用,因为 policy 的目标已经是软的 + # if element in ['value', 'reward']: + # temperature = 2.0 # 这是一个可以调整的超参数,可以从1.5或2.0开始 + # logits = logits / temperature + # # ============================================================= + + # ==================== START: APPLY LEARNABLE TEMPERATURE ==================== + # if self.use_temperature_scaling: + # # 根据 element 类型,选择对应的温度参数 + # if element == 'value': + # # T = self.log_temp_value.exp() + # # 为了防止T过小导致数值不稳定,可以加一个clamp + # T = torch.clamp(self.log_temp_value.exp(), min=0.1) + # elif element == 'rewards': + # # T = self.log_temp_reward.exp() + # T = torch.clamp(self.log_temp_reward.exp(), min=0.1) + # elif element == 'policy': + # # T = self.log_temp_policy.exp() + # T = torch.clamp(self.log_temp_policy.exp(), min=0.1) + # else: + # T = 1.0 # 对于其他未知类型,不使用温度 + + # # 应用温度缩放:用温度 T 来除以 logits + # # 增加一个极小值防止除以零(尽管exp()保证了T>0) + # logits = logits / (T + 1e-8) + if self.use_temperature_scaling: + if element == 'value': + T = 1.0 + F.softplus(self.log_temp_value) + elif element == 'rewards': + T = 1.0 + F.softplus(self.log_temp_reward) + elif element == 'policy': + T = 1.0 + 
F.softplus(self.log_temp_policy) + else: + T = 1.0 + + logits = logits / (T + 1e-8) + # ===================== END: APPLY LEARNABLE TEMPERATURE ===================== + + # Reshape your tensors logits = rearrange(logits, 'b t e -> (b t) e') labels = labels.reshape(-1, labels.shape[-1]) # Assume labels initially have shape [batch, time, dim] @@ -1758,6 +2684,25 @@ def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): # Compute cross-entropy loss loss = -(torch.log_softmax(logits, dim=1) * labels).sum(1) loss = (loss * mask_padding) + + # TODO: check + # 3. 计算每个token的损失 + # 当 labels 是浮点数(概率)时,CrossEntropyLoss 会自动计算交叉熵 + # loss_per_token = self.ce_loss_fn_pt(logits, labels) + # # 4. 应用掩码并计算最终的平均损失 + # # loss.sum() / mask_padding.sum() 是计算有效token的平均损失,这是最标准的做法 + # loss = (loss_per_token * mask_padding) + + + # TODO===== + # # --- Calculate policy loss using the smoothed target --- + # # Use KL-Divergence for probability targets, which is equivalent to CrossEntropy for one-hot + # # Log-softmax on logits and KLDiv is more numerically stable than Softmax + CrossEntropy + # log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + # # The target for KLDiv should be probabilities, not log-probabilities + # loss = torch.nn.functional.kl_div(log_probs labels, reduction='batchmean') + + if torch.isnan(loss).any(): raise ValueError(f"NaN detected in outputs for batch {batch} and element '{element}'") @@ -1766,10 +2711,12 @@ def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): # Compute policy entropy loss policy_entropy = self.compute_policy_entropy_loss(logits, mask_padding) # Combine losses with specified weight + # print(f"self.policy_entropy_weight:{self.policy_entropy_weight}") combined_loss = loss - self.policy_entropy_weight * policy_entropy return combined_loss, loss, policy_entropy return loss + def compute_policy_entropy_loss(self, logits, mask): # Compute entropy of the policy @@ -1813,9 +2760,9 @@ def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, ta labels_value = target_value.masked_fill(mask_fill_value, -100) if self.continuous_action_space: - return None, labels_value.reshape(-1, self.support_size) + return labels_value.reshape(-1, self.support_size), None else: - return labels_policy.reshape(-1, self.action_space_size), labels_value.reshape(-1, self.support_size) + return labels_value.reshape(-1, self.support_size), labels_policy.reshape(-1, self.action_space_size) def clear_caches(self): """ diff --git a/lzero/model/unizero_world_models/world_model_bkp20250819_v3.py b/lzero/model/unizero_world_models/world_model_bkp20250819_v3.py new file mode 100644 index 000000000..a780aa890 --- /dev/null +++ b/lzero/model/unizero_world_models/world_model_bkp20250819_v3.py @@ -0,0 +1,1907 @@ +import logging +from typing import Dict, Union, Optional, List, Tuple, Any + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.distributions import Categorical, Independent, Normal, TransformedDistribution, TanhTransform + +from lzero.model.common import SimNorm, L2Norm +from lzero.model.utils import cal_dormant_ratio +from .kv_caching import KeysValues +from .slicer import Head, PolicyHeadCont +from .tokenizer import Tokenizer +from .transformer import Transformer, TransformerConfig +from .utils import LossWithIntermediateLosses, init_weights, WorldModelOutput, hash_state + +logging.getLogger().setLevel(logging.DEBUG) + + +class 
WorldModel(nn.Module): + """ + Overview: + The WorldModel class is responsible for the scalable latent world model of UniZero (https://arxiv.org/abs/2406.10667), + which is used to predict the next latent state, rewards, policy, and value based on the current latent state and action. + The world model consists of three main components: + - a tokenizer, which encodes observations into embeddings, + - a transformer, which processes the input sequences, + - and heads, which generate the logits for observations, rewards, policy, and value. + """ + + def __init__(self, config: TransformerConfig, tokenizer) -> None: + """ + Overview: + Initialize the WorldModel class. + Arguments: + - config (:obj:`TransformerConfig`): The configuration for the transformer. + - tokenizer (:obj:`Tokenizer`): The tokenizer. + """ + super().__init__() + self.tokenizer = tokenizer + self.config = config + self.transformer = Transformer(self.config) + + if self.config.device == 'cpu': + self.device = torch.device('cpu') + else: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + # Move all modules to the specified device + logging.info(f"self.device: {self.device}") + self.to(self.device) + + # Initialize configuration parameters + self._initialize_config_parameters() + + # Initialize patterns for block masks + self._initialize_patterns() + + self.hidden_size = config.embed_dim // config.num_heads + + # Position embedding + if not self.config.rotary_emb: + self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device) + # TODO(pu) + # self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device, max_norm=1.0) + self.precompute_pos_emb_diff_kv() + print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}") + + self.continuous_action_space = self.config.continuous_action_space + + # Initialize action embedding table + if self.continuous_action_space: + # TODO: check the effect of SimNorm + self.act_embedding_table = nn.Sequential( + nn.Linear(config.action_space_size, config.embed_dim, device=self.device, bias=False), + SimNorm(simnorm_dim=self.group_size)) + else: + # for discrete action space + self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device) + # TODO(pu) + # self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device, max_norm=1.0) + + logging.info(f"self.act_embedding_table.weight.device: {self.act_embedding_table.weight.device}") + + self.final_norm_option_in_obs_head = getattr(config, 'final_norm_option_in_obs_head', 'SimNorm') + + # Head modules + self.head_rewards = self._create_head(self.act_tokens_pattern, self.support_size) + self.head_observations = self._create_head(self.all_but_last_latent_state_pattern, self.obs_per_embdding_dim, \ + self._get_final_norm(self.final_norm_option_in_obs_head) # NOTE: using the specified normalization method for observations head + ) + if self.continuous_action_space: + self.sigma_type = self.config.sigma_type + self.bound_type = self.config.bound_type + self.head_policy = self._create_head_cont(self.value_policy_tokens_pattern, self.action_space_size) + else: + self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) + self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size) + + # Build the set of modules to skip during re-initialization. 
+ # This is compatible with cases where self.tokenizer.encoder does not have 'pretrained_model', + # or self.tokenizer does not have 'decoder_network'. + # NOTE: This step is crucial — without skipping, pretrained modules (e.g., encoder/decoder) would be unintentionally re-initialized + skip_modules = set() + if hasattr(self.tokenizer.encoder, 'pretrained_model'): + skip_modules.update(self.tokenizer.encoder.pretrained_model.modules()) + if hasattr(self.tokenizer, 'decoder_network'): + if self.tokenizer.decoder_network is not None: + skip_modules.update(self.tokenizer.decoder_network.modules()) + + def custom_init(module): + # If the current module is part of the skip list, return without reinitializing + if module in skip_modules: + return + # Otherwise, apply the specified initialization method + init_weights(module, norm_type=self.config.norm_type) + + # Recursively apply `custom_init` to all submodules of the model + self.apply(custom_init) + + # self.apply(init_weights) + + + self._initialize_last_layer() + + # Cache structures + self._initialize_cache_structures() + + # Projection input dimension + self._initialize_projection_input_dim() + + # Hit count and query count statistics + self._initialize_statistics() + + # Initialize keys and values for transformer + self._initialize_transformer_keys_values() + + self.latent_recon_loss = torch.tensor(0., device=self.device) + self.perceptual_loss = torch.tensor(0., device=self.device) + + # TODO: check the size of the shared pool + # for self.kv_cache_recurrent_infer + # If needed, recurrent_infer should store the results of the one MCTS search. + self.num_simulations = getattr(self.config, 'num_simulations', 50) # TODO + self.shared_pool_size = int(self.num_simulations*self.env_num) + self.shared_pool_recur_infer = [None] * self.shared_pool_size + self.shared_pool_index = 0 + + # for self.kv_cache_init_infer + # In contrast, init_infer only needs to retain the results of the most recent step. + # self.shared_pool_size_init = int(2*self.env_num) + self.shared_pool_size_init = int(2) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + self.shared_pool_init_infer = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] + self.shared_pool_index_init_envs = [0 for _ in range(self.env_num)] + + # for self.kv_cache_wm + self.shared_pool_size_wm = int(self.env_num) + self.shared_pool_wm = [None] * self.shared_pool_size_wm + self.shared_pool_index_wm = 0 + + self.reanalyze_phase = False + + def _get_final_norm(self, norm_option: str) -> nn.Module: + """ + Return the corresponding normalization module based on the specified normalization option. + """ + if norm_option == 'LayerNorm': + return nn.LayerNorm(self.config.embed_dim, eps=1e-5) + elif norm_option == 'SimNorm': + return SimNorm(simnorm_dim=self.config.group_size) + elif norm_option == 'L2Norm': + # L2Norm 是一个函数式操作,不需要在 init 中定义模块 + return L2Norm(eps=1e-6) + else: + raise ValueError(f"Unsupported final_norm_option_in_obs_head: {norm_option}") + + def custom_copy_kv_cache_to_shared_init_envs(self, src_kv: KeysValues, env_id) -> int: + """ + Overview: + Efficiently copies the contents of a KeysValues object to the shared pool for a specific environment in the init_infer stage. + Arguments: + - src_kv (:obj:`KeysValues`): The source KeysValues object from which data is copied. + - env_id (:obj:`int`): The identifier of the environment for which the cache is being copied. 
+ Returns: + - index (:obj:`int`): The index in the shared pool where the KeysValues object is stored. + """ + src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape + + if self.shared_pool_init_infer[env_id][self.shared_pool_index_init_envs[env_id]] is None: + self.shared_pool_init_infer[env_id][self.shared_pool_index_init_envs[env_id]] = KeysValues( + src_kv_shape[0], # Number of elements (n) + src_kv_shape[1], # Number of attention heads (num_heads) + src_kv_shape[2], # Maximum number of tokens (max_tokens) + src_kv_shape[3] * src_kv_shape[1], # Embedding dimension (embed_dim) + len(src_kv), # Number of layers (num_layers) + src_kv._keys_values[0]._k_cache._cache.device, # Device where the cache is stored + ) + + dst_kv = self.shared_pool_init_infer[env_id][self.shared_pool_index_init_envs[env_id]] + + for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values): + # Copy the key and value caches using torch.copy_() for efficient data transfer + dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) + dst_layer._v_cache._cache.copy_(src_layer._v_cache._cache) + dst_layer._k_cache._size = src_layer._k_cache._size + dst_layer._v_cache._size = src_layer._v_cache._size + + index = self.shared_pool_index_init_envs[env_id] + self.shared_pool_index_init_envs[env_id] = (self.shared_pool_index_init_envs[env_id] + 1) % self.shared_pool_size_init + + return index + + def custom_copy_kv_cache_to_shared_wm(self, src_kv: KeysValues) -> int: + """ + Overview: + Efficiently copies the contents of a KeysValues object to the shared pool for world model usage. + Arguments: + - src_kv (:obj:`KeysValues`): The source KeysValues object from which data is copied. + Returns: + - index (:obj:`int`): The index in the shared pool where the KeysValues object is stored. + """ + src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape + + if self.shared_pool_wm[self.shared_pool_index_wm] is None: + self.shared_pool_wm[self.shared_pool_index_wm] = KeysValues( + src_kv_shape[0], # Number of elements (n) + src_kv_shape[1], # Number of attention heads (num_heads) + src_kv_shape[2], # Maximum number of tokens (max_tokens) + src_kv_shape[3] * src_kv_shape[1], # Embedding dimension (embed_dim) + len(src_kv), # Number of layers (num_layers) + src_kv._keys_values[0]._k_cache._cache.device, # Device where the cache is stored + ) + + dst_kv = self.shared_pool_wm[self.shared_pool_index_wm] + + for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values): + # Copy the key and value caches using torch.copy_() for efficient data transfer + dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) + dst_layer._v_cache._cache.copy_(src_layer._v_cache._cache) + dst_layer._k_cache._size = src_layer._k_cache._size + dst_layer._v_cache._size = src_layer._v_cache._size + + self.shared_pool_index_wm = (self.shared_pool_index_wm + 1) % self.shared_pool_size_wm + + return dst_kv + + def custom_copy_kv_cache_to_shared_recur(self, src_kv: KeysValues) -> int: + """ + Overview: + Efficiently copies the contents of a KeysValues object to the shared pool for recurrent inference. + Arguments: + - src_kv (:obj:`KeysValues`): The source KeysValues object from which data is copied. + Returns: + - index (:obj:`int`): The index in the shared pool where the KeysValues object is stored. 
+ """ + src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape + + if self.shared_pool_recur_infer[self.shared_pool_index] is None: + self.shared_pool_recur_infer[self.shared_pool_index] = KeysValues( + src_kv_shape[0], # Number of elements (n) + src_kv_shape[1], # Number of attention heads (num_heads) + src_kv_shape[2], # Maximum number of tokens (max_tokens) + src_kv_shape[3] * src_kv_shape[1], # Embedding dimension (embed_dim) + len(src_kv), # Number of layers (num_layers) + src_kv._keys_values[0]._k_cache._cache.device, # Device where the cache is stored + ) + + dst_kv = self.shared_pool_recur_infer[self.shared_pool_index] + + for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values): + # Copy the key and value caches using torch.copy_() for efficient data transfer + dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) + dst_layer._v_cache._cache.copy_(src_layer._v_cache._cache) + dst_layer._k_cache._size = src_layer._k_cache._size + dst_layer._v_cache._size = src_layer._v_cache._size + + index = self.shared_pool_index + self.shared_pool_index = (self.shared_pool_index + 1) % self.shared_pool_size + + return index + + def _initialize_config_parameters(self) -> None: + """Initialize configuration parameters.""" + self.policy_entropy_weight = self.config.policy_entropy_weight + self.predict_latent_loss_type = self.config.predict_latent_loss_type + self.group_size = self.config.group_size + self.num_groups = self.config.embed_dim // self.group_size + self.obs_type = self.config.obs_type + self.embed_dim = self.config.embed_dim + self.num_heads = self.config.num_heads + self.gamma = self.config.gamma + self.context_length = self.config.context_length + self.dormant_threshold = self.config.dormant_threshold + self.analysis_dormant_ratio = self.config.analysis_dormant_ratio + self.num_observations_tokens = self.config.tokens_per_block - 1 + self.latent_recon_loss_weight = self.config.latent_recon_loss_weight + self.perceptual_loss_weight = self.config.perceptual_loss_weight + self.support_size = self.config.support_size + self.action_space_size = self.config.action_space_size + self.max_cache_size = self.config.max_cache_size + self.env_num = self.config.env_num + self.num_layers = self.config.num_layers + self.obs_per_embdding_dim = self.config.embed_dim + self.sim_norm = SimNorm(simnorm_dim=self.group_size) + + def _initialize_patterns(self) -> None: + """Initialize patterns for block masks.""" + self.all_but_last_latent_state_pattern = torch.ones(self.config.tokens_per_block) + self.all_but_last_latent_state_pattern[-2] = 0 + self.act_tokens_pattern = torch.zeros(self.config.tokens_per_block) + self.act_tokens_pattern[-1] = 1 + self.value_policy_tokens_pattern = torch.zeros(self.config.tokens_per_block) + self.value_policy_tokens_pattern[-2] = 1 + + def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head: + """Create head modules for the transformer.""" + modules = [ + nn.Linear(self.config.embed_dim, self.config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(self.config.embed_dim, output_dim) + ] + if norm_layer: + modules.append(norm_layer) + return Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + ) + + def _create_head_cont(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head: + """Create head modules for the transformer.""" + from ding.model.common import ReparameterizationHead + self.fc_policy_head = ReparameterizationHead( + 
input_size=self.config.embed_dim, + output_size=output_dim, + layer_num=2, # TODO: check the effect of layer_num + sigma_type=self.sigma_type, + activation=nn.GELU(approximate='tanh'), + fixed_sigma_value=self.config.fixed_sigma_value if self.sigma_type == 'fixed' else 0.5, + norm_type=None, + bound_type=self.bound_type + ) + return PolicyHeadCont( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=self.fc_policy_head + ) + + def _initialize_last_layer(self) -> None: + """Initialize the last linear layer.""" + last_linear_layer_init_zero = True # TODO + if last_linear_layer_init_zero: + if self.continuous_action_space: + module_to_initialize = [self.head_value, self.head_rewards, self.head_observations] + else: + module_to_initialize = [self.head_policy, self.head_value, self.head_rewards, self.head_observations] + for head in module_to_initialize: + for layer in reversed(head.head_module): + if isinstance(layer, nn.Linear): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + break + + def _initialize_cache_structures(self) -> None: + """Initialize cache structures for past keys and values.""" + from collections import defaultdict + self.past_kv_cache_recurrent_infer = defaultdict(dict) + self.past_kv_cache_init_infer_envs = [defaultdict(dict) for _ in range(self.env_num)] + + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + + def _initialize_projection_input_dim(self) -> None: + """Initialize the projection input dimension based on the number of observation tokens.""" + if self.num_observations_tokens == 16: + self.projection_input_dim = 128 + elif self.num_observations_tokens == 1: + self.projection_input_dim = self.obs_per_embdding_dim + + def _initialize_statistics(self) -> None: + """Initialize counters for hit count and query count statistics.""" + self.hit_count = 0 + self.total_query_count = 0 + self.length_largethan_maxminus5_context_cnt = 0 + self.length_largethan_maxminus7_context_cnt = 0 + self.root_hit_cnt = 0 + self.root_total_query_cnt = 0 + + def _initialize_transformer_keys_values(self) -> None: + """Initialize keys and values for the transformer.""" + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, + max_tokens=self.context_length) + self.keys_values_wm_single_env_tmp = self.transformer.generate_empty_keys_values(n=1, + max_tokens=self.context_length) + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, + max_tokens=self.context_length) + + def precompute_pos_emb_diff_kv(self): + """ Precompute positional embedding differences for key and value. 
""" + if self.context_length <= 2: + # If context length is 2 or less, no context is present + return + # Precompute positional embedding matrices for inference in collect/eval stages, not for training + self.positional_embedding_k = [ + self._get_positional_embedding(layer, 'key') + for layer in range(self.config.num_layers) + ] + self.positional_embedding_v = [ + self._get_positional_embedding(layer, 'value') + for layer in range(self.config.num_layers) + ] + + # Precompute all possible positional embedding differences + self.pos_emb_diff_k = [] + self.pos_emb_diff_v = [] + + for layer in range(self.config.num_layers): + layer_pos_emb_diff_k = {} + layer_pos_emb_diff_v = {} + + for start in [2]: + for end in [self.context_length - 1]: + original_pos_emb_k = self.positional_embedding_k[layer][:, :, start:end, :] + new_pos_emb_k = self.positional_embedding_k[layer][:, :, :end - start, :] + layer_pos_emb_diff_k[(start, end)] = new_pos_emb_k - original_pos_emb_k + + original_pos_emb_v = self.positional_embedding_v[layer][:, :, start:end, :] + new_pos_emb_v = self.positional_embedding_v[layer][:, :, :end - start, :] + layer_pos_emb_diff_v[(start, end)] = new_pos_emb_v - original_pos_emb_v + + self.pos_emb_diff_k.append(layer_pos_emb_diff_k) + self.pos_emb_diff_v.append(layer_pos_emb_diff_v) + + def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: + """ + Helper function to get positional embedding for a given layer and attention type. + + Arguments: + - layer (:obj:`int`): Layer index. + - attn_type (:obj:`str`): Attention type, either 'key' or 'value'. + + Returns: + - torch.Tensor: The positional embedding tensor. + """ + attn_func = getattr(self.transformer.blocks[layer].attn, attn_type) + if torch.cuda.is_available(): + return attn_func(self.pos_emb.weight).view( + 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads + ).transpose(1, 2).to(self.device).detach() + else: + return attn_func(self.pos_emb.weight).view( + 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads + ).transpose(1, 2).detach() + + def forward( + self, + obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, Tuple]], + past_keys_values: Optional[torch.Tensor] = None, + kvcache_independent: bool = False, + is_init_infer: bool = True, + valid_context_lengths: Optional[torch.Tensor] = None, + start_pos: Union[int, List[int]] = 0, + search_depth: Optional[List[int]] = None + ) -> "WorldModelOutput": + """ + Overview: + Forward pass for the world model. This method processes observation embeddings and/or action tokens, + optionally adds position encodings (with or without rotary position embeddings), passes the resulting + sequences through the transformer, and finally generates logits for observations, rewards, policy, and value. + + Arguments: + - obs_embeddings_or_act_tokens (dict): Dictionary containing one or more of the following keys: + - 'obs_embeddings': torch.Tensor representing observation embeddings. + - 'act_tokens': torch.Tensor representing action tokens. + - 'obs_embeddings_and_act_tokens': Combined data for both observations and actions. + - past_keys_values (Optional[torch.Tensor]): Cached key-value pairs for the transformer. Defaults to None. + - kvcache_independent (bool): Flag to indicate whether key-value caching is independent. Defaults to False. + - is_init_infer (bool): Flag to indicate if this is the initial inference step. Defaults to True. + - valid_context_lengths (Optional[torch.Tensor]): Valid lengths for the context. 
Defaults to None. + - start_pos (int or List[int]): Starting positional index for the current sequence (or batch). Defaults to 0. + - search_depth (Optional[List[int]]): List representing the search depth for each batch element, used for + position encoding adjustment. Defaults to None. + + Returns: + WorldModelOutput: An output instance containing: + - x: Output features from the transformer. + - logits for observations. + - logits for rewards. + - logits_ends (None). + - logits for policy. + - logits for value. + """ + + # Calculate previous steps based on key-value caching configuration + if kvcache_independent: + # If kv caching is independent, compute previous steps for each past key-value pair. + prev_steps = torch.tensor( + [0 if past_keys_values is None else past_kv.size for past_kv in past_keys_values], + device=self.device + ) + else: + # Otherwise, use a single value for previous steps. + prev_steps = 0 if past_keys_values is None else past_keys_values.size + + # Reset valid context lengths during initial inference phase. + if is_init_infer: + valid_context_lengths = None + + # sequences: torch.Tensor # Output sequence to feed into transformer + # num_steps: int # Number of timesteps in the sequence + # start_pos_adjusted: Union[int, List[int]] # Adjusted starting position index for positional encoding + + if not self.config.rotary_emb: + start_pos_adjusted = None + + # Process observation embeddings if available. + if "obs_embeddings" in obs_embeddings_or_act_tokens: + obs_embeddings = obs_embeddings_or_act_tokens["obs_embeddings"] + # If the observation embeddings have 2 dimensions, expand them to include a time dimension. + if len(obs_embeddings.shape) == 2: + obs_embeddings = obs_embeddings.unsqueeze(1) + num_steps = obs_embeddings.size(1) + + if not self.config.rotary_emb: + # Add traditional position embeddings if not using rotary embeddings. + sequences = self._add_position_embeddings( + obs_embeddings, prev_steps, num_steps, kvcache_independent, + is_init_infer, valid_context_lengths + ) + else: + # Keep the observation embeddings unchanged when using rotary embeddings. + sequences = obs_embeddings + + if is_init_infer: + if self.reanalyze_phase: + # During reanalyze phase in initial inference, adjust start_pos: + # Multiply by 2 because timestep only counts observations, + # but the sequence contains both observations and actions. + start_pos_adjusted = start_pos * 2 + if not isinstance(start_pos_adjusted, (int, float)): + # Pad zero if start_pos_adjusted is not a scalar. + padding = np.zeros((start_pos_adjusted.shape[0], 1), dtype=start_pos_adjusted.dtype) + start_pos_adjusted = np.concatenate([start_pos_adjusted, padding], axis=1).reshape(-1) + else: + # For regular initial inference, adjust start_pos accordingly. + if isinstance(start_pos, (int, float)): + start_pos_adjusted = start_pos * 2 + else: + start_pos_adjusted = [pos * 2 for pos in start_pos] + else: + # For recurrent inference (non-init), calculate the correct positional index. + if self.reanalyze_phase: + # In reanalyze phase, start_pos for batch mode might be an array that needs padding. + if not isinstance(start_pos, (int, float)): + padding = np.zeros((start_pos.shape[0], 1), dtype=start_pos.dtype) + start_pos_adjusted = np.concatenate([start_pos, padding], axis=1).reshape(-1) + # Ensure search_depth length matches adjusted start_pos. 
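# (Each environment timestep contributes two tokens, one observation and one action, so the
#  absolute token position is obtained by doubling the timestep index and offsetting by the
#  current MCTS search depth; the assert below guards that a depth entry exists for every
#  batch element before that offset is applied.)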
+ assert len(search_depth) == len(start_pos_adjusted) + start_pos_adjusted = [ + (search_depth[i] + pos + 1) * 2 + 1 for i, pos in enumerate(start_pos_adjusted) + ] + else: + start_pos_adjusted = [ + (search_depth[i] + pos) * 2 + 2 for i, pos in enumerate(start_pos) + ] + + # Process action tokens if available. + elif "act_tokens" in obs_embeddings_or_act_tokens: + act_tokens = obs_embeddings_or_act_tokens["act_tokens"] + if self.continuous_action_space: + num_steps = 1 + act_tokens = act_tokens.float() + if len(act_tokens.shape) == 2: + act_tokens = act_tokens.unsqueeze(1) + else: + if len(act_tokens.shape) == 3: + act_tokens = act_tokens.squeeze(1) + num_steps = act_tokens.size(1) + # Convert action tokens to embeddings using the action embedding table. + act_embeddings = self.act_embedding_table(act_tokens) + if not self.config.rotary_emb: + sequences = self._add_position_embeddings( + act_embeddings, prev_steps, num_steps, kvcache_independent, + is_init_infer, valid_context_lengths + ) + else: + sequences = act_embeddings + + if is_init_infer: + if self.reanalyze_phase: + # In reanalyze phase during initial inference, the action tokens represent the current timestep. + start_pos_adjusted = start_pos * 2 + 1 + if not isinstance(start_pos_adjusted, (int, float)): + padding = np.zeros((start_pos_adjusted.shape[0], 1), dtype=start_pos_adjusted.dtype) + start_pos_adjusted = np.concatenate([start_pos_adjusted, padding], axis=1).reshape(-1) + else: + # For regular initial inference using action tokens, adjust start_pos by subtracting 1. + if isinstance(start_pos, (int, float)): + start_pos_adjusted = start_pos * 2 - 1 + else: + start_pos_adjusted = [pos * 2 - 1 for pos in start_pos] + else: + # During recurrent inference for action tokens. + if self.reanalyze_phase: + if not isinstance(start_pos, (int, float)): + padding = np.zeros((start_pos.shape[0], 1), dtype=start_pos.dtype) + start_pos_adjusted = np.concatenate([start_pos, padding], axis=1).reshape(-1) + assert len(search_depth) == len(start_pos_adjusted) + start_pos_adjusted = [ + (search_depth[i] + pos + 1) * 2 + 1 for i, pos in enumerate(start_pos_adjusted) + ] + else: + start_pos_adjusted = [ + (search_depth[i] + pos) * 2 + 1 for i, pos in enumerate(start_pos) + ] + + # Process combined observation embeddings and action tokens. + elif "obs_embeddings_and_act_tokens" in obs_embeddings_or_act_tokens: + # Process combined inputs to calculate either the target value (for training) + # or target policy (for reanalyze phase). + if self.continuous_action_space: + sequences, num_steps = self._process_obs_act_combined_cont(obs_embeddings_or_act_tokens, prev_steps) + else: + sequences, num_steps = self._process_obs_act_combined(obs_embeddings_or_act_tokens, prev_steps) + # Adjust start positions: multiply by 2 as the sequence has both obs and act. + start_pos_adjusted = [pos * 2 for pos in start_pos] + else: + raise ValueError("Input dictionary must contain one of 'obs_embeddings', 'act_tokens', or 'obs_embeddings_and_act_tokens'.") + + # Pass the sequence through the transformer. + x = self._transformer_pass( + sequences, past_keys_values, kvcache_independent, valid_context_lengths, start_pos=start_pos_adjusted + ) + + # Generate logits for various components. 
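# Minimal illustrative sketch (not from the patch; helper name is hypothetical, and it assumes a
# single observation token per environment step): start_pos is doubled above because environment
# timesteps count only observations, while the transformer sequence interleaves one observation
# token and one action token per step.
def _interleaved_token_index(t: int, is_action: bool) -> int:
    # obs of env step t -> sequence index 2*t; the action taken at step t -> 2*t + 1
    return 2 * t + (1 if is_action else 0)

assert _interleaved_token_index(3, False) == 6 and _interleaved_token_index(3, True) == 7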
+ logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) + logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) + logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) + logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) + + # The 'logits_ends' is intentionally set to None. + return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value) + + def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, + valid_context_lengths): + """ + Add position embeddings to the input embeddings. + + Arguments: + - embeddings (:obj:`torch.Tensor`): Input embeddings. + - prev_steps (:obj:`torch.Tensor`): Previous steps. + - num_steps (:obj:`int`): Number of steps. + - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. + - is_init_infer (:obj:`bool`): Initialize inference. + - valid_context_lengths (:obj:`torch.Tensor`): Valid context lengths. + Returns: + - torch.Tensor: Embeddings with position information added. + """ + if kvcache_independent: + steps_indices = prev_steps + torch.arange(num_steps, device=embeddings.device) + position_embeddings = self.pos_emb(steps_indices).view(-1, num_steps, embeddings.shape[-1]) + return embeddings + position_embeddings + else: + if is_init_infer: + return embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)) + else: + valid_context_lengths = torch.tensor(self.keys_values_wm_size_list_current, device=self.device) + position_embeddings = self.pos_emb( + valid_context_lengths + torch.arange(num_steps, device=self.device)).unsqueeze(1) + return embeddings + position_embeddings + + def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_steps): + """ + Process combined observation embeddings and action tokens. + + Arguments: + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing combined observation embeddings and action tokens. + - prev_steps (:obj:`torch.Tensor`): Previous steps. + Returns: + - torch.Tensor: Combined observation and action embeddings with position information added. + """ + obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] + if len(obs_embeddings.shape) == 3: + obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, + -1) + + num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) + if self.continuous_action_space: + act_tokens = act_tokens.float() + if len(act_tokens.shape) == 2: # TODO + act_tokens = act_tokens.unsqueeze(-1) + + # B, L, E + act_embeddings = self.act_embedding_table(act_tokens) + + B, L, K, E = obs_embeddings.size() + # B, L*2, E + obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=self.device) + + for i in range(L): + obs = obs_embeddings[:, i, :, :] + act = act_embeddings[:, i, :].unsqueeze(1) + obs_act = torch.cat([obs, act], dim=1) + obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act + + return_result = obs_act_embeddings + if not self.config.rotary_emb: + return_result += self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)) + return return_result, num_steps + + def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps): + """ + Process combined observation embeddings and action tokens. 
+ + Arguments: + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing combined observation embeddings and action tokens. + - prev_steps (:obj:`torch.Tensor`): Previous steps. + Returns: + - torch.Tensor: Combined observation and action embeddings with position information added. + """ + obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] + if len(obs_embeddings.shape) == 3: + obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, + -1) + + num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) + act_embeddings = self.act_embedding_table(act_tokens) + + B, L, K, E = obs_embeddings.size() + obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=self.device) + + for i in range(L): + obs = obs_embeddings[:, i, :, :] + act = act_embeddings[:, i, 0, :].unsqueeze(1) + obs_act = torch.cat([obs, act], dim=1) + obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act + + return_result = obs_act_embeddings + if not self.config.rotary_emb: + return_result += self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)) + return return_result, num_steps + + def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, valid_context_lengths, start_pos: int = 0): + """ + Pass sequences through the transformer. + + Arguments: + - sequences (:obj:`torch.Tensor`): Input sequences. + - past_keys_values (:obj:`Optional[torch.Tensor]`): Previous keys and values for transformer. + - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. + - valid_context_lengths (:obj:`torch.Tensor`): Valid context lengths. + Returns: + - torch.Tensor: Transformer output. + """ + if kvcache_independent: + x = [self.transformer(sequences[k].unsqueeze(0), past_kv, + valid_context_lengths=valid_context_lengths[k].unsqueeze(0), start_pos=start_pos) for k, past_kv in + enumerate(past_keys_values)] + return torch.cat(x, dim=0) + else: + return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths, start_pos=start_pos) + + @torch.no_grad() + def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor, start_pos: int = 0) -> torch.FloatTensor: + """ + Reset the model state based on initial observations and actions. + + Arguments: + - obs_act_dict (:obj:`torch.FloatTensor`): A dictionary containing 'obs', 'action', and 'current_obs'. + Returns: + - torch.FloatTensor: The outputs from the world model and the latent state. + """ + # Extract observations, actions, and current observations from the dictionary. + if isinstance(obs_act_dict, dict): + batch_obs = obs_act_dict['obs'] # obs_act_dict['obs'] is at timestep t + batch_action = obs_act_dict['action'] # obs_act_dict['action'] is at timestep t + batch_current_obs = obs_act_dict['current_obs'] # obs_act_dict['current_obs'] is at timestep t+1 + + # Encode observations to latent embeddings. 
+ obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_obs) + + if batch_current_obs is not None: + # ================ Collect and Evaluation Phase ================ + # Encode current observations to latent embeddings + current_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_current_obs) + # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}") + self.latent_state = current_obs_embeddings + outputs_wm = self.wm_forward_for_initial_infererence(obs_embeddings, batch_action, + current_obs_embeddings, start_pos) + else: + # ================ calculate the target value in Train phase or calculate the target policy in reanalyze phase ================ + self.latent_state = obs_embeddings + outputs_wm = self.wm_forward_for_initial_infererence(obs_embeddings, batch_action, None, start_pos) + + return outputs_wm, self.latent_state + + @torch.no_grad() + def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTensor, + batch_action=None, + current_obs_embeddings=None, start_pos: int = 0) -> torch.FloatTensor: + """ + Refresh key-value pairs with the initial latent state for inference. + + Arguments: + - last_obs_embeddings (:obj:`torch.LongTensor`): The latent state embeddings. + - batch_action (optional): Actions taken. + - current_obs_embeddings (optional): Current observation embeddings. + Returns: + - torch.FloatTensor: The outputs from the world model. + """ + n, num_observations_tokens, _ = last_obs_embeddings.shape + if n <= self.env_num and current_obs_embeddings is not None: + # ================ Collect and Evaluation Phase ================ + if current_obs_embeddings is not None: + # Determine whether it is the first step in an episode. + if self.continuous_action_space: + first_step_flag = not isinstance(batch_action[0], np.ndarray) + else: + first_step_flag = max(batch_action) == -1 + if first_step_flag: + # ------------------------- First Step of an Episode ------------------------- + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=current_obs_embeddings.shape[0], + max_tokens=self.context_length) + # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}") + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, + past_keys_values=self.keys_values_wm, is_init_infer=True, start_pos=start_pos) + + # Copy and store keys_values_wm for a single environment + self.update_cache_context(current_obs_embeddings, is_init_infer=True) + else: + # --------------------- Continuing an Episode (Multi-environment) --------------------- + # current_obs_embeddings is the new latent_state, containing information from ready_env_num environments + ready_env_num = current_obs_embeddings.shape[0] + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + + for i in range(ready_env_num): + # Retrieve latent state for a single environment + # TODO: len(last_obs_embeddings) may smaller than len(current_obs_embeddings), because some environments may have done + + state_single_env = last_obs_embeddings[i] + # Compute hash value using latent state for a single environment + cache_key = hash_state(state_single_env.view(-1).cpu().numpy()) # last_obs_embeddings[i] is torch.Tensor + + # Retrieve cached value + cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[i][cache_index] + else: + matched_value = None + + self.root_total_query_cnt += 1 + if matched_value is not None: + # If a matching value is found, 
add it to the list + self.root_hit_cnt += 1 + # NOTE: deepcopy is needed because forward modifies matched_value in place + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + self.keys_values_wm_size_list.append(matched_value.size) + else: + # Reset using zero values + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.context_length) + # If using RoPE positional encoding, then at reset, the pos_embed should use the absolute position start_pos[i]. + outputs_wm = self.forward({'obs_embeddings': state_single_env.unsqueeze(0)}, + past_keys_values=self.keys_values_wm_single_env, + is_init_infer=True, start_pos=start_pos[i].item()) + self.keys_values_wm_list.append(self.keys_values_wm_single_env) + self.keys_values_wm_size_list.append(1) + + # Input self.keys_values_wm_list, output self.keys_values_wm + self.keys_values_wm_size_list_current = self.trim_and_pad_kv_cache(is_init_infer=True) + + start_pos = start_pos[:ready_env_num] + # TODO: len(last_obs_embeddings) may smaller than len(current_obs_embeddings), because some environments may have done + # TODO: the order may be not correct? len(batch_action) may smaller than len(current_obs_embeddings), because some environments may have done + batch_action = batch_action[:ready_env_num] + + # TODO: only for debug + # if ready_env_num < self.env_num: + # print(f'init inference ready_env_num: {ready_env_num} < env_num: {self.env_num}') + # print(f"ready_env_num: {ready_env_num}") + # print(f"start_pos: {start_pos}") + # print(f"batch_action: {batch_action}") + # print(f"len(last_obs_embeddings): {len(last_obs_embeddings)}") + # print(f"len(batch_action): {len(batch_action)}") + # print(f"len(current_obs_embeddings): {len(current_obs_embeddings)}") + + if self.continuous_action_space: + act_tokens = torch.from_numpy(np.array(batch_action)).to(last_obs_embeddings.device).unsqueeze(1) + else: + act_tokens = torch.from_numpy(np.array(batch_action)).to(last_obs_embeddings.device).unsqueeze(-1) + + outputs_wm = self.forward({'act_tokens': act_tokens}, past_keys_values=self.keys_values_wm, + is_init_infer=True, start_pos=start_pos) + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, + past_keys_values=self.keys_values_wm, is_init_infer=True, start_pos=start_pos) + + # Copy and store keys_values_wm for a single environment + self.update_cache_context(current_obs_embeddings, is_init_infer=True) + + elif batch_action is not None and current_obs_embeddings is None: + # ================ calculate the target value in Train phase or calculate the target policy in reanalyze phase ================ + # [192, 16, 64] -> [32, 6, 16, 64] + last_obs_embeddings = last_obs_embeddings.contiguous().view(batch_action.shape[0], -1, num_observations_tokens, + self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 + + last_obs_embeddings = last_obs_embeddings[:, :-1, :] + batch_action = torch.from_numpy(batch_action).to(last_obs_embeddings.device) + if self.continuous_action_space: + act_tokens = batch_action + else: + act_tokens = rearrange(batch_action, 'b l -> b l 1') + + # select the last timestep for each sample + # This will select the last column while keeping the dimensions unchanged, and the target policy/value in the final step itself is not used. 
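# Minimal sketch (assumed shapes) of the keep-dim slice used below: indexing with -1: preserves
# the time axis, so the duplicated last step can be concatenated back along dim=1.
import torch
_act = torch.zeros(32, 6, 1)              # (B, L, 1)
_last = _act[:, -1:, :]                   # (B, 1, 1) -- time axis kept, unlike _act[:, -1, :]
assert torch.cat((_act, _last), dim=1).shape == (32, 7, 1)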
+ last_steps_act = act_tokens[:, -1:, :] + act_tokens = torch.cat((act_tokens, last_steps_act), dim=1) + + # Each sample in the batch (last_obs_embeddings, act_tokens) corresponds to the same time step, and start_pos also corresponds to each sample's respective t. + outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (last_obs_embeddings, act_tokens)}, start_pos=start_pos) + + # select the last timestep for each sample + last_steps_value = outputs_wm.logits_value[:, -1:, :] + outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) + + last_steps_policy = outputs_wm.logits_policy[:, -1:, :] + outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) + + # Reshape your tensors + # outputs_wm.logits_value.shape (B, H, 101) = (B*H, 101) + outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') + outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') + + return outputs_wm + + @torch.no_grad() + def forward_initial_inference(self, obs_act_dict, start_pos: int = 0): + """ + Perform initial inference based on the given observation-action dictionary. + + Arguments: + - obs_act_dict (:obj:`dict`): Dictionary containing observations and actions. + Returns: + - tuple: A tuple containing output sequence, latent state, logits rewards, logits policy, and logits value. + """ + # UniZero has context in the root node + outputs_wm, latent_state = self.reset_for_initial_inference(obs_act_dict, start_pos) + self.past_kv_cache_recurrent_infer.clear() + + return (outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, + outputs_wm.logits_policy, outputs_wm.logits_value) + + @torch.no_grad() + def forward_recurrent_inference(self, state_action_history, simulation_index=0, + search_depth=[], start_pos: int = 0): + """ + Perform recurrent inference based on the state-action history. + + Arguments: + - state_action_history (:obj:`list`): List containing tuples of state and action history. + - simulation_index (:obj:`int`, optional): Index of the current simulation. Defaults to 0. + - search_depth (:obj:`list`, optional): List containing depth of latent states in the search tree. + Returns: + - tuple: A tuple containing output sequence, updated latent state, reward, logits policy, and logits value. 
+ """ + latest_state, action = state_action_history[-1] + ready_env_num = latest_state.shape[0] + + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + self.keys_values_wm_size_list = self.retrieve_or_generate_kvcache(latest_state, ready_env_num, simulation_index, start_pos) + + latent_state_list = [] + + # if not self.continuous_action_space: + # token = action.reshape(-1, 1) + # else: + # token = action.reshape(-1, self.action_space_size) + + # TODO(pu, 20250819): only for toy env =================== + if not self.continuous_action_space: + # action 是 numpy 数组, nn.Embedding 期望 LongTensor + token = torch.from_numpy(action.reshape(-1, 1)).long().to(self.device) + else: + # action 是 numpy 数组, nn.Linear 期望 FloatTensor + token = torch.from_numpy(action.reshape(-1, self.action_space_size)).float().to(self.device) + + # ======= Print statistics for debugging ============= + # min_size = min(self.keys_values_wm_size_list) + # if min_size >= self.config.max_tokens - 5: + # self.length_largethan_maxminus5_context_cnt += len(self.keys_values_wm_size_list) + # if min_size >= self.config.max_tokens - 7: + # self.length_largethan_maxminus7_context_cnt += len(self.keys_values_wm_size_list) + # if self.total_query_count > 0 and self.total_query_count % 10000 == 0: + # self.hit_freq = self.hit_count / self.total_query_count + # print('total_query_count:', self.total_query_count) + # length_largethan_maxminus5_context_cnt_ratio = self.length_largethan_maxminus5_context_cnt / self.total_query_count + # print('recurrent largethan_maxminus5_context:', self.length_largethan_maxminus5_context_cnt) + # print('recurrent largethan_maxminus5_context_ratio:', length_largethan_maxminus5_context_cnt_ratio) + # length_largethan_maxminus7_context_cnt_ratio = self.length_largethan_maxminus7_context_cnt / self.total_query_count + # print('recurrent largethan_maxminus7_context_ratio:', length_largethan_maxminus7_context_cnt_ratio) + # print('recurrent largethan_maxminus7_context:', self.length_largethan_maxminus7_context_cnt) + + # Trim and pad kv_cache: modify self.keys_values_wm in-place + self.keys_values_wm_size_list = self.trim_and_pad_kv_cache(is_init_infer=False) + self.keys_values_wm_size_list_current = self.keys_values_wm_size_list + + for k in range(2): + # action_token obs_token + if k == 0: + obs_embeddings_or_act_tokens = {'act_tokens': token} + else: + obs_embeddings_or_act_tokens = {'obs_embeddings': token} + + # Perform forward pass + outputs_wm = self.forward( + obs_embeddings_or_act_tokens, + past_keys_values=self.keys_values_wm, + kvcache_independent=False, + is_init_infer=False, + start_pos=start_pos, + search_depth=search_depth # List containing depth of latent states in the search tree. 
+ ) + + self.keys_values_wm_size_list_current = [i + 1 for i in self.keys_values_wm_size_list_current] + + if k == 0: + reward = outputs_wm.logits_rewards # (B,) + + if k < self.num_observations_tokens: + token = outputs_wm.logits_observations + if len(token.shape) != 3: + token = token.unsqueeze(1) # (8,1024) -> (8,1,1024) + latent_state_list.append(token) + + del self.latent_state # Very important to minimize cuda memory usage + self.latent_state = torch.cat(latent_state_list, dim=1) # (B, K) + + self.update_cache_context( + self.latent_state, + is_init_infer=False, + simulation_index=simulation_index, + ) + + return (outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value) + + + def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: + """ + Adjusts the key-value cache for each environment to ensure they all have the same size. + + In a multi-environment setting, the key-value cache (kv_cache) for each environment is stored separately. + During recurrent inference, the kv_cache sizes may vary across environments. This method pads each kv_cache + to match the largest size found among them, facilitating batch processing in the transformer forward pass. + + Arguments: + - is_init_infer (:obj:`bool`): Indicates if this is an initial inference. Default is True. + Returns: + - list: Updated sizes of the key-value caches. + """ + # Find the maximum size among all key-value caches + max_size = max(self.keys_values_wm_size_list) + + # Iterate over each layer of the transformer + for layer in range(self.num_layers): + kv_cache_k_list = [] + kv_cache_v_list = [] + + # Enumerate through each environment's key-value pairs + for idx, keys_values in enumerate(self.keys_values_wm_list): + k_cache = keys_values[layer]._k_cache._cache + v_cache = keys_values[layer]._v_cache._cache + + effective_size = self.keys_values_wm_size_list[idx] + pad_size = max_size - effective_size + + # If padding is required, trim the end and pad the beginning of the cache + if pad_size > 0: + k_cache_trimmed = k_cache[:, :, :-pad_size, :] + v_cache_trimmed = v_cache[:, :, :-pad_size, :] + k_cache_padded = F.pad(k_cache_trimmed, (0, 0, pad_size, 0), "constant", 0) + v_cache_padded = F.pad(v_cache_trimmed, (0, 0, pad_size, 0), "constant", 0) + else: + k_cache_padded = k_cache + v_cache_padded = v_cache + + kv_cache_k_list.append(k_cache_padded) + kv_cache_v_list.append(v_cache_padded) + + # Stack the caches along a new dimension and remove any extra dimensions + self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1) + self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1) + + # Update the cache size to the maximum size + self.keys_values_wm._keys_values[layer]._k_cache._size = max_size + self.keys_values_wm._keys_values[layer]._v_cache._size = max_size + + return self.keys_values_wm_size_list + + # world_model.py -> 替换整个 update_cache_context 函数 + # TODO(pu, 20250819) + def update_cache_context(self, latent_state, is_init_infer=True, simulation_index=0, + search_depth=[], valid_context_lengths=None): + """ + Update the cache context with the given latent state. + This is the FINAL corrected version for a true sliding window. 
+ """ + if self.context_length <= 2: + return + + # 定义目标窗口大小,即我们希望缓存滑动的最终长度 + # 我们保留 context_length - 1 的长度,为下一步的单个token预测留出空间 + # TARGET_WINDOW_SIZE = self.context_length - 1 + TARGET_WINDOW_SIZE = self.context_length - 3 + + for i in range(latent_state.size(0)): + cache_key = hash_state(latent_state[i].view(-1).cpu().numpy()) + + # 步骤 1: 从批量缓存中提取单个环境的缓存 (与上次修复相同) + if not is_init_infer: + current_max_context_length = max(self.keys_values_wm_size_list_current) + trim_size = current_max_context_length - self.keys_values_wm_size_list_current[i] + for layer in range(self.num_layers): + k_cache_current = self.keys_values_wm._keys_values[layer]._k_cache._cache[i] + v_cache_current = self.keys_values_wm._keys_values[layer]._v_cache._cache[i] + if trim_size > 0: + k_cache_trimmed = k_cache_current[:, trim_size:, :] + v_cache_trimmed = v_cache_current[:, trim_size:, :] + k_cache_padded = F.pad(k_cache_trimmed, (0, 0, 0, trim_size), "constant", 0) + v_cache_padded = F.pad(v_cache_trimmed, (0, 0, 0, trim_size), "constant", 0) + else: + k_cache_padded = k_cache_current + v_cache_padded = v_cache_current + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm_size_list_current[i] + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm_size_list_current[i] + else: + for layer in range(self.num_layers): + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size + + # 步骤 2: 应用真正的滑动窗口截断逻辑 + for layer in range(self.num_layers): + current_size = self.keys_values_wm_single_env._keys_values[layer]._k_cache._size + + # 只有当当前大小超过我们的目标窗口大小时,才进行截断 + if current_size > TARGET_WINDOW_SIZE: + k_cache_current = self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache + v_cache_current = self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache + + # 计算需要丢弃的旧token数量 + tokens_to_discard = current_size - TARGET_WINDOW_SIZE + + # 从头部丢弃旧token,保留最新的部分 + k_cache_trimmed = k_cache_current[:, :, tokens_to_discard:current_size, :] + v_cache_trimmed = v_cache_current[:, :, tokens_to_discard:current_size, :] + + new_size = k_cache_trimmed.shape[2] # new_size 现在是 TARGET_WINDOW_SIZE + + # !! 警告: 原始的位置编码校正可能不再精确 !! 
+ # 原始的校正是为“丢弃2个”而设计的。现在我们动态丢弃,这个校正会不匹配。 + # 对于RoPE,这不成问题。对于非RoPE,这会引入一些误差,但通常好于不滑动。 + # 在此,我们选择不应用可能错误的校正,以支持正确的滑动行为。 + if not self.config.rotary_emb: + # 对截断后的部分应用位置编码校正 + # 注意:这里的实现假设了固定的截断大小,对于动态截断可能需要更复杂的处理 + # 但对于当前逻辑是正确的 + pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, 2 + new_size)] + pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, 2 + new_size)] + k_cache_trimmed += pos_emb_diff_k + v_cache_trimmed += pos_emb_diff_v + + # 用0填充到原始缓存张量的长度,以保持形状一致 + padding_size = k_cache_current.shape[2] - new_size + k_cache_padded = F.pad(k_cache_trimmed, (0, 0, 0, padding_size), 'constant', 0) + v_cache_padded = F.pad(v_cache_trimmed, (0, 0, 0, padding_size), 'constant', 0) + + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = new_size + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = new_size + + # 步骤 3: 存储处理好的缓存 (与上次修复相同) + if is_init_infer: + cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + else: + cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + self.past_kv_cache_recurrent_infer[cache_key] = cache_index + + # TODO(pu, 20250819) + # def update_cache_context(self, latent_state, is_init_infer=True, simulation_index=0, + # search_depth=[], valid_context_lengths=None): + # """ + # Update the cache context with the given latent state. + # This version contains the corrected truncation logic. + # """ + # if self.context_length <= 2: + # # No context to update if the context length is less than or equal to 2. + # return + + # context_length = self.context_length + + # for i in range(latent_state.size(0)): + # # ============ Iterate over each environment ============ + # cache_key = hash_state(latent_state[i].view(-1).cpu().numpy()) + + # # 1. 
First, extract the single-environment cache from the global batched cache (self.keys_values_wm) + # # and put it into self.keys_values_wm_single_env, handling any padding introduced by multi-environment inference + # if not is_init_infer: + # # For Internal Nodes, we need to handle potential padding from batching + # current_max_context_length = max(self.keys_values_wm_size_list_current) + # trim_size = current_max_context_length - self.keys_values_wm_size_list_current[i] + # for layer in range(self.num_layers): + # k_cache_current = self.keys_values_wm._keys_values[layer]._k_cache._cache[i] + # v_cache_current = self.keys_values_wm._keys_values[layer]._v_cache._cache[i] + + # if trim_size > 0: + # k_cache_trimmed = k_cache_current[:, trim_size:, :] + # v_cache_trimmed = v_cache_current[:, trim_size:, :] + # k_cache_padded = F.pad(k_cache_trimmed, (0, 0, 0, trim_size), "constant", 0) + # v_cache_padded = F.pad(v_cache_trimmed, (0, 0, 0, trim_size), "constant", 0) + # else: + # k_cache_padded = k_cache_current + # v_cache_padded = v_cache_current + + # self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + # self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + # self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm_size_list_current[i] + # self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm_size_list_current[i] + # else: + # # For Root Nodes, the batch is aligned, so we can just extract the slice + # for layer in range(self.num_layers): + # self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) + # self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) + # self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size + # self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size + + # # 2.
Second, apply the unified truncation logic to the extracted single-environment cache + # for layer in range(self.num_layers): + # # ============ NOTE: This is the corrected and unified truncation logic ============ + # current_size = self.keys_values_wm_single_env._keys_values[layer]._k_cache._size + + # # Only act when the cache size has reached or exceeded the truncation threshold + # # The threshold is context_length - 1 because one (obs, act) pair adds 2 tokens, so we must truncate early to leave room + # if current_size >= context_length - 1: + # k_cache_current = self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache + # v_cache_current = self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache + + # # Starting from the 2nd token, keep the most recent context_length-3 tokens + # # This implements the sliding window: discard the oldest 2 tokens + # k_cache_trimmed = k_cache_current[:, :, 2:current_size, :] + # v_cache_trimmed = v_cache_current[:, :, 2:current_size, :] + + # new_size = k_cache_trimmed.shape[2] + + # if not self.config.rotary_emb: + # # Apply the positional-embedding correction to the truncated part + # # Note: this implementation assumes a fixed truncation size; dynamic truncation may need more elaborate handling, + # # but it is correct for the current logic + # pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, 2 + new_size)] + # pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, 2 + new_size)] + # k_cache_trimmed += pos_emb_diff_k + # v_cache_trimmed += pos_emb_diff_v + + # # Zero-pad back to the original cache tensor length to keep the shape consistent + # padding_size = k_cache_current.shape[2] - new_size + # k_cache_padded = F.pad(k_cache_trimmed, (0, 0, 0, padding_size), 'constant', 0) + # v_cache_padded = F.pad(v_cache_trimmed, (0, 0, 0, padding_size), 'constant', 0) + + # self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded + # self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded + # self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = new_size + # self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = new_size + + # # 3. Finally, store the processed single-environment cache into the corresponding pool + # if is_init_infer: + # cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + # self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + # else: + # cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + # self.past_kv_cache_recurrent_infer[cache_key] = cache_index + + + def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, + simulation_index: int = 0, start_pos: int = 0) -> list: + """ + Retrieves or generates key-value caches for each environment based on the latent state. + + For each environment, this method either retrieves a matching cache from the predefined + caches if available, or generates a new cache if no match is found. The method updates + the internal lists with these caches and their sizes. + + Arguments: + - latent_state (:obj:`list`): List of latent states for each environment. + - ready_env_num (:obj:`int`): Number of environments ready for processing. + - simulation_index (:obj:`int`, optional): Index for simulation tracking. Default is 0. + Returns: + - list: Sizes of the key-value caches for each environment.
+ """ + for index in range(ready_env_num): + self.total_query_count += 1 + state_single_env = latent_state[index] # latent_state[i] is np.array + cache_key = hash_state(state_single_env) + + if self.reanalyze_phase: + # TODO: check if this is correct + matched_value = None + else: + # Try to retrieve the cached value from past_kv_cache_init_infer_envs + cache_index = self.past_kv_cache_init_infer_envs[index].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[index][cache_index] + else: + matched_value = None + + # If not found, try to retrieve from past_kv_cache_recurrent_infer + if matched_value is None: + matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)] + + # ==================== TODO(pu): only for debug ==================== + # matched_value = None # Force cache miss + + if matched_value is not None: + # If a matching cache is found, add it to the lists + self.hit_count += 1 + # Perform a deep copy because the transformer's forward pass might modify matched_value in-place + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + self.keys_values_wm_size_list.append(matched_value.size) + else: + # If no matching cache is found, generate a new one using zero reset + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values( + n=1, max_tokens=self.context_length + ) + + # Determine the absolute start position based on the reanalyze phase flag. + if self.reanalyze_phase: + num_rows, num_cols = start_pos.shape # Original start_pos shape is (batch, num_columns) + total_cols = num_cols + 1 # Each logical row is extended by one column. + row_idx = index // total_cols + col_idx = index % total_cols + # If the column index equals the original number of columns, this indicates the added column; set to 0. 
+ start_pos_adjusted: int = 0 if col_idx == num_cols else int(start_pos[row_idx, col_idx]) + else: + start_pos_adjusted = int(start_pos[index].item()) + + self.forward( + {'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, + past_keys_values=self.keys_values_wm_single_env, is_init_infer=True, start_pos=start_pos_adjusted + ) + self.keys_values_wm_list.append(self.keys_values_wm_single_env) + self.keys_values_wm_size_list.append(1) + + return self.keys_values_wm_size_list + + + def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar_transform_handle=None, + **kwargs: Any) -> LossWithIntermediateLosses: + start_pos = batch['timestep'] + # Encode observations into latent state representations + obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations']) + + # ========= for visual analysis ========= + # Uncomment the lines below for visual analysis in Pong + # self.plot_latent_tsne_each_and_all_for_pong(obs_embeddings, suffix='pong_H10_H4_tsne') + # self.save_as_image_with_timestep(batch['observations'], suffix='pong_H10_H4_tsne') + # Uncomment the lines below for visual analysis in visual match + # self.plot_latent_tsne_each_and_all(obs_embeddings, suffix='visual_match_memlen1-60-15_tsne') + # self.save_as_image_with_timestep(batch['observations'], suffix='visual_match_memlen1-60-15_tsne') + + + # ========= logging for analysis ========= + if self.analysis_dormant_ratio: + # Calculate dormant ratio of the encoder + shape = batch['observations'].shape # (..., C, H, W) + inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) # (32,5,3,64,64) -> (160,3,64,64) + dormant_ratio_encoder = cal_dormant_ratio(self.tokenizer.representation_network, inputs.detach(), + percentage=self.dormant_threshold) + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + torch.cuda.empty_cache() + else: + dormant_ratio_encoder = torch.tensor(0.) 
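# Minimal sketch (assumed sizes) of the flatten used in the dormant-ratio branch above: the batch
# and unroll dimensions are merged while the trailing (C, H, W) image shape is kept intact.
import torch
_obs = torch.zeros(32, 5, 3, 64, 64)
assert _obs.contiguous().view(-1, *_obs.shape[-3:]).shape == (160, 3, 64, 64)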
+ + # Action tokens + if self.continuous_action_space: + act_tokens = batch['actions'] + else: + act_tokens = rearrange(batch['actions'], 'b l -> b l 1') + + with torch.no_grad(): + # Calculate the L2 norm of the latent state roots + latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean() + # Calculate the L2 norm of the latent action + latent_action_l2_norms = torch.norm(self.act_embedding_table(act_tokens), p=2, dim=2).mean() + + # Forward pass to obtain predictions for observations, rewards, and policies + outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, start_pos=start_pos) + + if self.obs_type == 'image': + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + + # ========== for visualization ========== + # Uncomment the lines below for visual analysis + # original_images, reconstructed_images = batch['observations'], reconstructed_images + # target_policy = batch['target_policy'] + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ========== for visualization ========== + + # ========== Calculate reconstruction loss and perceptual loss ============ + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + # perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + + latent_recon_loss = self.latent_recon_loss + perceptual_loss = self.perceptual_loss + + elif self.obs_type == 'vector': + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings.reshape(-1, self.embed_dim)) + + # # Calculate reconstruction loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 25), + # reconstructed_images) + latent_recon_loss = self.latent_recon_loss + + elif self.obs_type == 'text': + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=torch.float32) + decode_loss_mode = self.config.decode_loss_mode + + # Reconstruction loss for predicting the next latent (via backbone) + # input -> encoder -> backbone(unizero) -> decoder -> latent_recon_loss + if decode_loss_mode == "after_backbone": + next_latent_state = outputs.logits_observations[:, :-1, :] + next_target_ids = batch['observations'][:, 1:, :] + + latent_recon_loss = self.tokenizer.decode_to_reconstruction_outputs( + embeddings=next_latent_state, + target_ids=next_target_ids, + ).loss + + #Reconstruction loss for predicting the current latent (without using the backbone) + # input -> encoder -> decoder -> latent_recon_loss + elif decode_loss_mode == "before_backbone": + latent_recon_loss = self.tokenizer.decode_to_reconstruction_outputs( + embeddings=obs_embeddings, + target_ids=batch['observations'], + ).loss + + else: + latent_recon_loss = self.latent_recon_loss + + elif self.obs_type == 'image_memory': + # Reconstruct observations from latent state representations + # reconstructed_images = 
self.tokenizer.decode_to_obs(obs_embeddings) + # original_images, reconstructed_images = batch['observations'], reconstructed_images + + # ========== for visualization ========== + # Uncomment the lines below for visual analysis + # target_policy = batch['target_policy'] + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ========== for visualization ========== + + # Calculate reconstruction loss and perceptual loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 5, 5), + # reconstructed_images) + latent_recon_loss = self.latent_recon_loss + perceptual_loss = self.perceptual_loss + + # ========= logging for analysis ========= + if self.analysis_dormant_ratio: + # Calculate dormant ratio of the world model + dormant_ratio_world_model = cal_dormant_ratio(self, { + 'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())}, + percentage=self.dormant_threshold) + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + torch.cuda.empty_cache() + else: + dormant_ratio_world_model = torch.tensor(0.) + + # ========== for visualization ========== + # Uncomment the lines below for visualization + # predict_policy = outputs.logits_policy + # predict_policy = F.softmax(outputs.logits_policy, dim=-1) + # predict_value = inverse_scalar_transform_handle(outputs.logits_value.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) + # predict_rewards = inverse_scalar_transform_handle(outputs.logits_rewards.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) + # import pdb; pdb.set_trace() + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=[], suffix='pong_H10_H4_0613') + + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_success_episode') + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_fail_episode') + # ========== for visualization ========== + + # For training stability, use target_tokenizer to compute the true next latent state representations + with torch.no_grad(): + target_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations']) + + # Compute labels for observations, rewards, and ends + labels_observations, labels_rewards, _ = self.compute_labels_world_model(target_obs_embeddings, + batch['rewards'], + batch['ends'], + batch['mask_padding']) + + # Reshape the logits and labels for observations + logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') + labels_observations = labels_observations.reshape(-1, self.projection_input_dim) + + # Compute prediction loss for observations. 
Options: MSE and Group KL + if self.predict_latent_loss_type == 'mse': + # MSE loss, directly compare logits and labels + loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations, reduction='none').mean( + -1) + elif self.predict_latent_loss_type == 'group_kl': + # Group KL loss, group features and calculate KL divergence within each group + batch_size, num_features = logits_observations.shape + epsilon = 1e-6 + logits_reshaped = logits_observations.reshape(batch_size, self.num_groups, self.group_size) + epsilon + labels_reshaped = labels_observations.reshape(batch_size, self.num_groups, self.group_size) + epsilon + + loss_obs = F.kl_div(logits_reshaped.log(), labels_reshaped, reduction='none').sum(dim=-1).mean(dim=-1) + + # ========== for debugging ========== + # print('loss_obs:', loss_obs.mean()) + # assert not torch.isnan(loss_obs).any(), "loss_obs contains NaN values" + # assert not torch.isinf(loss_obs).any(), "loss_obs contains Inf values" + # for name, param in self.tokenizer.encoder.named_parameters(): + # print('name, param.mean(), param.std():', name, param.mean(), param.std()) + + # Apply mask to loss_obs + mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) + loss_obs = (loss_obs * mask_padding_expanded) + + # Compute labels for policy and value + labels_value, labels_policy = self.compute_labels_world_model_value_policy(batch['target_value'], + batch['target_policy'], + batch['mask_padding']) + + # Compute losses for rewards, policy, and value + loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') + + if not self.continuous_action_space: + loss_policy, orig_policy_loss, policy_entropy = self.compute_cross_entropy_loss(outputs, labels_policy, + batch, + element='policy') + else: + # NOTE: for continuous action space + if self.config.policy_loss_type == 'simple': + orig_policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma = self._calculate_policy_loss_cont_simple(outputs, batch) + else: + orig_policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma = self._calculate_policy_loss_cont(outputs, batch) + + loss_policy = orig_policy_loss + self.policy_entropy_weight * policy_entropy_loss + policy_entropy = - policy_entropy_loss + + loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') + + # ==== TODO: calculate the new priorities for each transition. 
==== + # value_priority = L1Loss(reduction='none')(labels_value.squeeze(-1), outputs['logits_value'][:, 0]) + # value_priority = value_priority.data.cpu().numpy() + 1e-6 + + # Compute timesteps + timesteps = torch.arange(batch['actions'].shape[1], device=batch['actions'].device) + # Compute discount coefficients for each timestep + discounts = self.gamma ** timesteps + + if batch['mask_padding'].sum() == 0: + assert False, "mask_padding is all zeros" + + # Group losses into first step, middle step, and last step + first_step_losses = {} + middle_step_losses = {} + last_step_losses = {} + # batch['mask_padding'] indicates mask status for future H steps, exclude masked losses to maintain accurate mean statistics + # Group losses for each loss item + for loss_name, loss_tmp in zip( + ['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy', 'orig_policy_loss', 'policy_entropy'], + [loss_obs, loss_rewards, loss_value, loss_policy, orig_policy_loss, policy_entropy] + ): + if loss_name == 'loss_obs': + seq_len = batch['actions'].shape[1] - 1 + # Get the corresponding mask_padding + mask_padding = batch['mask_padding'][:, 1:seq_len] + else: + seq_len = batch['actions'].shape[1] + # Get the corresponding mask_padding + mask_padding = batch['mask_padding'][:, :seq_len] + + # Adjust loss shape to (batch_size, seq_len) + loss_tmp = loss_tmp.view(-1, seq_len) + + # First step loss + first_step_mask = mask_padding[:, 0] + first_step_losses[loss_name] = loss_tmp[:, 0][first_step_mask].mean() + + # Middle step loss + middle_timestep = seq_len // 2 + middle_step_mask = mask_padding[:, middle_timestep] + middle_step_losses[loss_name] = loss_tmp[:, middle_timestep][middle_step_mask].mean() + + # Last step loss + last_step_mask = mask_padding[:, -1] + last_step_losses[loss_name] = loss_tmp[:, -1][last_step_mask].mean() + + # Discount reconstruction loss and perceptual loss + discounted_latent_recon_loss = latent_recon_loss + discounted_perceptual_loss = perceptual_loss + + # Calculate overall discounted loss + discounted_loss_obs = (loss_obs.view(-1, batch['actions'].shape[1] - 1) * discounts[1:]).sum()/ batch['mask_padding'][:,1:].sum() + discounted_loss_rewards = (loss_rewards.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_loss_value = (loss_value.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_loss_policy = (loss_policy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_orig_policy_loss = (orig_policy_loss.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_policy_entropy = (policy_entropy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + + if self.continuous_action_space: + return LossWithIntermediateLosses( + latent_recon_loss_weight=self.latent_recon_loss_weight, + perceptual_loss_weight=self.perceptual_loss_weight, + continuous_action_space=True, + loss_obs=discounted_loss_obs, + loss_rewards=discounted_loss_rewards, + loss_value=discounted_loss_value, + loss_policy=discounted_loss_policy, + latent_recon_loss=discounted_latent_recon_loss, + perceptual_loss=discounted_perceptual_loss, + orig_policy_loss=discounted_orig_policy_loss, + policy_entropy=discounted_policy_entropy, + first_step_losses=first_step_losses, + middle_step_losses=middle_step_losses, + last_step_losses=last_step_losses, + dormant_ratio_encoder=dormant_ratio_encoder, + 
dormant_ratio_world_model=dormant_ratio_world_model, + latent_state_l2_norms=latent_state_l2_norms, + latent_action_l2_norms=latent_action_l2_norms, + policy_mu=mu, + policy_sigma=sigma, + target_sampled_actions=target_sampled_actions, + ) + else: + return LossWithIntermediateLosses( + latent_recon_loss_weight=self.latent_recon_loss_weight, + perceptual_loss_weight=self.perceptual_loss_weight, + continuous_action_space=False, + loss_obs=discounted_loss_obs, + loss_rewards=discounted_loss_rewards, + loss_value=discounted_loss_value, + loss_policy=discounted_loss_policy, + latent_recon_loss=discounted_latent_recon_loss, + perceptual_loss=discounted_perceptual_loss, + orig_policy_loss=discounted_orig_policy_loss, + policy_entropy=discounted_policy_entropy, + first_step_losses=first_step_losses, + middle_step_losses=middle_step_losses, + last_step_losses=last_step_losses, + dormant_ratio_encoder=dormant_ratio_encoder, + dormant_ratio_world_model=dormant_ratio_world_model, + latent_state_l2_norms=latent_state_l2_norms, + latent_action_l2_norms=latent_action_l2_norms, + + ) + + + # TODO: test correctness + def _calculate_policy_loss_cont_simple(self, outputs, batch: dict): + """ + Simplified policy loss calculation for continuous actions. + + Args: + - outputs: Model outputs containing policy logits. + - batch (:obj:`dict`): Batch data containing target policy, mask and sampled actions. + + Returns: + - policy_loss (:obj:`torch.Tensor`): The simplified policy loss. + """ + batch_size, num_unroll_steps, action_space_size = outputs.logits_policy.shape[ + 0], self.config.num_unroll_steps, self.config.action_space_size + + # Get the policy logits and batch data + policy_logits_all = outputs.logits_policy + mask_batch = batch['mask_padding'].contiguous().view(-1) + target_policy = batch['target_policy'].contiguous().view(batch_size * num_unroll_steps, -1) + target_sampled_actions = batch['child_sampled_actions'].contiguous().view(batch_size * num_unroll_steps, -1, action_space_size) + + # Flatten for vectorized computation + policy_logits_all = policy_logits_all.view(batch_size * num_unroll_steps, -1) + + # Extract mean and standard deviation from logits + mu, sigma = policy_logits_all[:, :action_space_size], policy_logits_all[:, action_space_size:] + dist = Independent(Normal(mu, sigma), 1) # Create the normal distribution + + # Find the indices of the maximum values in the target policy + target_best_action_idx = torch.argmax(target_policy, dim=1) + + # Select the best actions based on the indices + target_best_action = target_sampled_actions[torch.arange(target_best_action_idx.size(0)), target_best_action_idx] + + # Clip the target actions to prevent numerical issues during arctanh + # target_best_action_clamped = torch.clamp(target_best_action, -1 + 1e-6, 1 - 1e-6) + target_best_action_clamped = torch.clamp(target_best_action, -0.999, 0.999) + target_best_action_before_tanh = torch.arctanh(target_best_action_clamped) + + # Calculate the log probability of the best action + log_prob_best_action = dist.log_prob(target_best_action_before_tanh) + + # Mask the log probability with the padding mask + log_prob_best_action = log_prob_best_action * mask_batch + + # Return the negative log probability as the policy loss (we want to maximize log_prob) + # policy_loss = -log_prob_best_action.mean() + policy_loss = -log_prob_best_action + + policy_entropy = dist.entropy().mean() + policy_entropy_loss = -policy_entropy * mask_batch + # Calculate the entropy of the target policy distribution + 
non_masked_indices = torch.nonzero(mask_batch).squeeze(-1) + if len(non_masked_indices) > 0: + target_normalized_visit_count = target_policy.contiguous().view(batch_size * num_unroll_steps, -1) + target_dist = Categorical(target_normalized_visit_count[non_masked_indices]) + target_policy_entropy = target_dist.entropy().mean().item() + else: + target_policy_entropy = 0.0 + + return policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma + + def _calculate_policy_loss_cont(self, outputs, batch: dict) -> Tuple[torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Calculate the policy loss for continuous actions. + + Args: + - outputs: Model outputs containing policy logits. + - batch (:obj:`dict`): Batch data containing target policy, mask and sampled actions. + Returns: + - policy_loss (:obj:`torch.Tensor`): The calculated policy loss. + - policy_entropy_loss (:obj:`torch.Tensor`): The entropy loss of the policy. + - target_policy_entropy (:obj:`float`): The entropy of the target policy distribution. + - target_sampled_actions (:obj:`torch.Tensor`): The actions sampled from the target policy. + - mu (:obj:`torch.Tensor`): The mean of the normal distribution. + - sigma (:obj:`torch.Tensor`): The standard deviation of the normal distribution. + """ + batch_size, num_unroll_steps, action_space_size = outputs.logits_policy.shape[ + 0], self.config.num_unroll_steps, self.config.action_space_size + + policy_logits_all = outputs.logits_policy + mask_batch = batch['mask_padding'] + child_sampled_actions_batch = batch['child_sampled_actions'] + target_policy = batch['target_policy'] + + # Flatten the unroll step dimension for easier vectorized operations + policy_logits_all = policy_logits_all.view(batch_size * num_unroll_steps, -1) + mask_batch = mask_batch.contiguous().view(-1) + child_sampled_actions_batch = child_sampled_actions_batch.contiguous().view(batch_size * num_unroll_steps, -1, + action_space_size) + + mu, sigma = policy_logits_all[:, :action_space_size], policy_logits_all[:, action_space_size:] + mu = mu.unsqueeze(1).expand(-1, child_sampled_actions_batch.shape[1], -1) + sigma = sigma.unsqueeze(1).expand(-1, child_sampled_actions_batch.shape[1], -1) + dist = Independent(Normal(mu, sigma), 1) + + target_normalized_visit_count = target_policy.contiguous().view(batch_size * num_unroll_steps, -1) + target_sampled_actions = child_sampled_actions_batch + + policy_entropy = dist.entropy().mean(dim=1) + policy_entropy_loss = -policy_entropy * mask_batch + + # NOTE: Alternative way to calculate the log probability of the target actions + # y = 1 - target_sampled_actions.pow(2) + # target_sampled_actions_clamped = torch.clamp(target_sampled_actions, -1 + 1e-6, 1 - 1e-6) + # target_sampled_actions_before_tanh = torch.arctanh(target_sampled_actions_clamped) + # log_prob = dist.log_prob(target_sampled_actions_before_tanh) + # log_prob = log_prob - torch.log(y + 1e-6).sum(-1) + # log_prob_sampled_actions = log_prob + + base_dist = Normal(mu, sigma) + tanh_transform = TanhTransform() + dist = TransformedDistribution(base_dist, [tanh_transform]) + dist = Independent(dist, 1) + target_sampled_actions_clamped = torch.clamp(target_sampled_actions, -0.999, 0.999) + # assert torch.all(target_sampled_actions_clamped < 1) and torch.all(target_sampled_actions_clamped > -1), "Actions are not properly clamped." 
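# Minimal standalone sketch (assumed shapes) of the tanh-squashed Gaussian log-prob computed below:
# clamping the target actions away from +/-1 keeps atanh (applied inside TanhTransform.inv) finite.
import torch
from torch.distributions import Independent, Normal, TanhTransform, TransformedDistribution
_mu, _sigma = torch.zeros(4, 2), torch.ones(4, 2)
_dist = Independent(TransformedDistribution(Normal(_mu, _sigma), [TanhTransform()]), 1)
_actions = torch.clamp(torch.tensor([[0.999999, -0.999999]] * 4), -0.999, 0.999)
assert torch.isfinite(_dist.log_prob(_actions)).all()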
+ log_prob = dist.log_prob(target_sampled_actions_clamped) + log_prob_sampled_actions = log_prob + + # KL as projector + target_log_prob_sampled_actions = torch.log(target_normalized_visit_count + 1e-6) + policy_loss = -torch.sum( + torch.exp(target_log_prob_sampled_actions.detach()) * log_prob_sampled_actions, 1 + ) * mask_batch + + # Calculate the entropy of the target policy distribution + non_masked_indices = torch.nonzero(mask_batch).squeeze(-1) + if len(non_masked_indices) > 0: + target_dist = Categorical(target_normalized_visit_count[non_masked_indices]) + target_policy_entropy = target_dist.entropy().mean().item() + else: + target_policy_entropy = 0.0 + + return policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma + + def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): + # Assume outputs is an object with logits attributes like 'rewards', 'policy', and 'value'. + # labels is a target tensor for comparison. batch is a dictionary with a mask indicating valid timesteps. + + logits = getattr(outputs, f'logits_{element}') + + if torch.isnan(logits).any(): + raise ValueError(f"NaN detected in outputs for batch {batch} and element '{element}'") + + if torch.isnan(labels).any(): + raise ValueError(f"NaN detected in labels_value for batch {batch} and element '{element}'") + + # Reshape your tensors + logits = rearrange(logits, 'b t e -> (b t) e') + labels = labels.reshape(-1, labels.shape[-1]) # Assume labels initially have shape [batch, time, dim] + + # Reshape your mask. True indicates valid data. + mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') + + # Compute cross-entropy loss + loss = -(torch.log_softmax(logits, dim=1) * labels).sum(1) + loss = (loss * mask_padding) + + if torch.isnan(loss).any(): + raise ValueError(f"NaN detected in outputs for batch {batch} and element '{element}'") + + if element == 'policy': + # Compute policy entropy loss + policy_entropy = self.compute_policy_entropy_loss(logits, mask_padding) + # Combine losses with specified weight + # print(f"self.policy_entropy_weight:{self.policy_entropy_weight}") + combined_loss = loss - self.policy_entropy_weight * policy_entropy + return combined_loss, loss, policy_entropy + + return loss + + def compute_policy_entropy_loss(self, logits, mask): + # Compute entropy of the policy + probs = torch.softmax(logits, dim=1) + log_probs = torch.log_softmax(logits, dim=1) + entropy = -(probs * log_probs).sum(1) + # Apply mask and return average entropy loss + entropy_loss = (entropy * mask) + return entropy_loss + + def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, + mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # assert torch.all(ends.sum(dim=1) <= 1) # Each sequence sample should have at most one 'done' flag + mask_fill = torch.logical_not(mask_padding) + + # Prepare observation labels + labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] + + # Fill the masked areas of rewards + mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards) + labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) + + # Fill the masked areas of ends + # labels_endgs = ends.masked_fill(mask_fill, -100) + + # return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) + return labels_observations, labels_rewards.view(-1, self.support_size), None + + + def 
compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, + mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ Compute labels for value and policy predictions. """ + mask_fill = torch.logical_not(mask_padding) + + # Fill the masked areas of policy + mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) + labels_policy = target_policy.masked_fill(mask_fill_policy, -100) + + # Fill the masked areas of value + mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value) + labels_value = target_value.masked_fill(mask_fill_value, -100) + + if self.continuous_action_space: + return labels_value.reshape(-1, self.support_size), None + else: + return labels_value.reshape(-1, self.support_size), labels_policy.reshape(-1, self.action_space_size) + + def clear_caches(self): + """ + Clears the caches of the world model. + """ + for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + print(f'Cleared {self.__class__.__name__} past_kv_cache.') + + def __repr__(self) -> str: + return "transformer-based latent world_model of UniZero" diff --git a/lzero/model/unizero_world_models/world_model_bkp20250910.py b/lzero/model/unizero_world_models/world_model_bkp20250910.py new file mode 100644 index 000000000..5239ffe6f --- /dev/null +++ b/lzero/model/unizero_world_models/world_model_bkp20250910.py @@ -0,0 +1,2384 @@ +import logging +from typing import Dict, Union, Optional, List, Tuple, Any + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.distributions import Categorical, Independent, Normal, TransformedDistribution, TanhTransform + +from lzero.model.common import SimNorm, L2Norm +from lzero.model.utils import cal_dormant_ratio +from .kv_caching import KeysValues +from .slicer import Head, PolicyHeadCont +from .tokenizer import Tokenizer +from .transformer import Transformer, TransformerConfig +from .utils import LossWithIntermediateLosses, init_weights, WorldModelOutput, hash_state +from collections import OrderedDict +logging.getLogger().setLevel(logging.DEBUG) + +from collections import OrderedDict, defaultdict +import matplotlib.pyplot as plt +from matplotlib.offsetbox import OffsetImage, AnnotationBbox +from sklearn.manifold import TSNE +# In unizero_world_model.py + +import torch +import torch.nn as nn + +# --- HOOK FUNCTION FOR DEBUGGING --- +def print_intermediate_activation_hook(module, input, output): + """ + A PyTorch hook that prints the mean and std of a module's output. + This function will be registered to a specific layer (e.g., the first Linear layer in a Head). + + Args: + module: The module the hook is registered on. + input: The input to the module's forward pass. + output: The output from the module's forward pass. + """ + # output is the tensor we want to inspect + mean = output.mean().item() + std = output.std().item() + # We add the module name for clarity, to know which layer's output we are seeing. 
+ print(f" [HOOK DEBUG] Layer '{module.__class__.__name__}' Output -> mean: {mean:.6f}, std: {std:.6f}") + + + + +class LRUCache(OrderedDict): + """ + 一个固定容量的、遵循LRU(最近最少使用)原则的有序字典。 + 非常适合用于管理与环形缓冲区同步的缓存映射。 + """ + def __init__(self, capacity: int=2): + """ + 初始化LRU缓存。 + 参数: + - capacity (int): 缓存的最大容量。 + """ + self.capacity = capacity + super().__init__() + + def __setitem__(self, key: Any, value: Any) -> None: + """ + 重写设置条目的方法,以实现LRU逻辑。 + """ + # 如果键已存在,先删除旧条目,以确保后续添加时它会成为最新项。 + if key in self: + self.move_to_end(key) + + # 调用父类的方法来实际设置键值对。 + super().__setitem__(key, value) + + # 检查是否超出容量。如果超出,则删除最旧的条目。 + # popitem(last=False) 会移除并返回字典中第一个(最旧的)条目。 + if len(self) > self.capacity: + self.popitem(last=False) + +class WorldModel(nn.Module): + """ + Overview: + The WorldModel class is responsible for the scalable latent world model of UniZero (https://arxiv.org/abs/2406.10667), + which is used to predict the next latent state, rewards, policy, and value based on the current latent state and action. + The world model consists of three main components: + - a tokenizer, which encodes observations into embeddings, + - a transformer, which processes the input sequences, + - and heads, which generate the logits for observations, rewards, policy, and value. + """ + + def __init__(self, config: TransformerConfig, tokenizer) -> None: + """ + Overview: + Initialize the WorldModel class. + Arguments: + - config (:obj:`TransformerConfig`): The configuration for the transformer. + - tokenizer (:obj:`Tokenizer`): The tokenizer. + """ + super().__init__() + self.tokenizer = tokenizer + self.config = config + self.transformer = Transformer(self.config) + + if self.config.device == 'cpu': + self.device = torch.device('cpu') + else: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + # Move all modules to the specified device + logging.info(f"self.device: {self.device}") + self.to(self.device) + + # Initialize configuration parameters + self._initialize_config_parameters() + + # Initialize patterns for block masks + self._initialize_patterns() + + self.hidden_size = config.embed_dim // config.num_heads + + # Position embedding + if not self.config.rotary_emb: + # self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device) + # TODO(pu) + self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device, max_norm=1.0) + self.precompute_pos_emb_diff_kv() + print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}") + + self.continuous_action_space = self.config.continuous_action_space + + # Initialize action embedding table + if self.continuous_action_space: + # TODO: check the effect of SimNorm + self.act_embedding_table = nn.Sequential( + nn.Linear(config.action_space_size, config.embed_dim, device=self.device, bias=False), + SimNorm(simnorm_dim=self.group_size)) + else: + # for discrete action space + # self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device) + # TODO(pu) + self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device, max_norm=1.0) + + logging.info(f"self.act_embedding_table.weight.device: {self.act_embedding_table.weight.device}") + + self.final_norm_option_in_obs_head = getattr(config, 'final_norm_option_in_obs_head', 'SimNorm') + + # Head modules + self.head_rewards = self._create_head(self.act_tokens_pattern, self.support_size, use_norm_in_head=True) # TODO + self.head_observations = 
self._create_head(self.all_but_last_latent_state_pattern, self.obs_per_embdding_dim, \ + self._get_final_norm(self.final_norm_option_in_obs_head) # NOTE: using the specified normalization method for observations head + ) + if self.continuous_action_space: + self.sigma_type = self.config.sigma_type + self.bound_type = self.config.bound_type + self.head_policy = self._create_head_cont(self.value_policy_tokens_pattern, self.action_space_size) + else: + self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) + self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size, use_norm_in_head=True) + + # # ==================== NEW DEBUGGING CODE VIA HOOKS ==================== + # # We will attach our hook to the first Linear layer inside the head_value and head_rewards modules. + # # The head_module is an nn.Sequential, so its layers can be accessed by index. + # # Index 0: First nn.Linear + # # Index 1: nn.GELU + # # Index 2: Second nn.Linear + + # # Get the first linear layer from the sequential module + # first_linear_layer_value = self.head_value.head_module[0] + # first_linear_layer_rewards = self.head_rewards.head_module[0] + + # # Register the forward hook + # print("--- Attaching DEBUG hooks to head_value and head_rewards ---") + # self.value_hook_handle = first_linear_layer_value.register_forward_hook(print_intermediate_activation_hook) + # self.rewards_hook_handle = first_linear_layer_rewards.register_forward_hook(print_intermediate_activation_hook) + + # # NOTE: It's good practice to store the hook handle so you can remove it later if needed, e.g., during evaluation or after debugging. + # # To remove the hook: self.value_hook_handle.remove() + # # ==================================================================== + + + # Build the set of modules to skip during re-initialization. + # This is compatible with cases where self.tokenizer.encoder does not have 'pretrained_model', + # or self.tokenizer does not have 'decoder_network'. + # NOTE: This step is crucial: without skipping, pretrained modules (e.g., encoder/decoder) would be unintentionally re-initialized + skip_modules = set() + if hasattr(self.tokenizer.encoder, 'pretrained_model'): + skip_modules.update(self.tokenizer.encoder.pretrained_model.modules()) + if hasattr(self.tokenizer, 'decoder_network'): + if self.tokenizer.decoder_network is not None: + skip_modules.update(self.tokenizer.decoder_network.modules()) + + def custom_init(module): + # If the current module is part of the skip list, return without reinitializing + if module in skip_modules: + return + # Otherwise, apply the specified initialization method + init_weights(module, norm_type=self.config.norm_type) + + # Recursively apply `custom_init` to all submodules of the model + self.apply(custom_init) + + # self.apply(init_weights) + + self._initialize_last_layer() + + # for self.kv_cache_init_infer + # In contrast, init_infer only needs to retain the results of the most recent step. + # self.shared_pool_size_init = int(2) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + # For now this is set to game_segment_length so that every entry in self.shared_pool_init_infer holds a valid kv cache. + # TODO: very important, this should be kept identical to segment_length + self.shared_pool_size_init = int(self.config.game_segment_length) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + + # self.shared_pool_size_init = int(20) # NOTE: Will having too many cause incorrect retrieval of the kv cache?
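# --- Illustrative sketch (not part of the patch): the ring-buffer pool + {state_hash -> slot index}
# bookkeeping that the shared_pool_* members around here implement for KV caches, including the
# stale-pointer cleanup discussed in _initialize_cache_structures below. The class and values are
# toy stand-ins; only the bookkeeping idea matches the real code.
class ToyKVPool:
    def __init__(self, capacity: int):
        self.pool = [None] * capacity          # ring buffer holding the cached payloads
        self.key_to_index = {}                 # state_hash -> slot index
        self.index_to_key = [None] * capacity  # slot index -> state_hash (reverse map)
        self.write_index = 0

    def put(self, key, payload):
        # Before overwriting a slot, drop the stale pointer of whatever used to live there;
        # otherwise key_to_index would keep pointing at data that no longer matches the key.
        old_key = self.index_to_key[self.write_index]
        if old_key is not None:
            self.key_to_index.pop(old_key, None)
        self.pool[self.write_index] = payload
        self.key_to_index[key] = self.write_index
        self.index_to_key[self.write_index] = key
        self.write_index = (self.write_index + 1) % len(self.pool)

    def get(self, key):
        idx = self.key_to_index.get(key)
        return None if idx is None else self.pool[idx]

pool = ToyKVPool(capacity=3)
for step in range(5):                          # 5 writes into a 3-slot pool force overwrites
    pool.put(f"state_{step}", f"kv_{step}")
assert pool.get("state_0") is None             # evicted slots are no longer reachable
assert pool.get("state_4") == "kv_4"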
+ # self.shared_pool_size_init = int(200) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + + self.num_simulations = getattr(self.config, 'num_simulations', 50) + + # TODO: should the recurrent kv pool be split into a separate pool per environment? + self.shared_pool_size_recur = int(self.num_simulations*self.env_num) + + # self.shared_pool_size_init = int(50) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + self.stale_pointer_detections = 0 + self.stale_pointer_detections_recur = 0 + # Cache structures + self._initialize_cache_structures() + + # Projection input dimension + self._initialize_projection_input_dim() + + # Hit count and query count statistics + self._initialize_statistics() + + # Initialize keys and values for transformer + self._initialize_transformer_keys_values() + + self.latent_recon_loss = torch.tensor(0., device=self.device) + self.perceptual_loss = torch.tensor(0., device=self.device) + + # TODO: check the size of the shared pool + # for self.kv_cache_recurrent_infer + # If needed, recurrent_infer should store the results of the one MCTS search. + + # self.shared_pool_size_recur = int(2) + + self.shared_pool_recur_infer = [None] * self.shared_pool_size_recur + self.shared_pool_index = 0 + + # for self.kv_cache_init_infer + # In contrast, init_infer only needs to retain the results of the most recent step. + # self.shared_pool_size_init = int(2) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + # TODO + # self.shared_pool_size_init = int(50) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + + # TODO: analyze the self.env_num > 1 case: can different envs share the kv_cache that corresponds to an identical latent-state hash? + self.shared_pool_init_infer = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] + self.shared_pool_index_init_envs = [0 for _ in range(self.env_num)] + + # for self.kv_cache_wm + self.shared_pool_size_wm = int(self.env_num) + self.shared_pool_wm = [None] * self.shared_pool_size_wm + self.shared_pool_index_wm = 0 + + self.reanalyze_phase = False + + # Counter used for the t-SNE visualization + self.tsne_visualization_step = 0 + + # Handles of the registered gradient hooks + self._grad_hooks = [] + + # ======================= Register gradient hooks ======================= + # self.register_gradient_hooks(self.tokenizer.representation_network) + # ============================================================= + + + def register_gradient_hooks(self, model_to_hook: nn.Module): + """ + Recursively register gradient hooks on the learnable parameters of the model. + """ + + def hook_fn(grad, name): + # This hook runs immediately after the gradient of the parameter has been computed + if grad is not None: + grad_norm = grad.norm().item() + grad_mean = grad.mean().item() + grad_std = grad.std().item() + # To avoid information overload, only print statistics for non-zero gradients + if grad_norm > 1e-9: + print(f" [GRAD HOOK] Param: {name}, Shape: {grad.shape} | Norm: {grad_norm:.6f}, Mean: {grad_mean:.6f}, Std: {grad_std:.6f}") + + # Iterate over all named parameters of the model + for name, param in model_to_hook.named_parameters(): + if param.requires_grad: + # Register the hook via .register_hook(); bind `name` as a default argument so each hook + # reports its own parameter instead of the last loop value (avoids the late-binding pitfall) + handle = param.register_hook(lambda grad, name=name: hook_fn(grad, name)) + self._grad_hooks.append(handle) + print(f" [INFO] Registered gradient hook for: {name}") + + def remove_gradient_hooks(self): + """ + Remove all registered gradient hooks; call this during evaluation or deployment. + """ + for handle in self._grad_hooks: + handle.remove() + self._grad_hooks.clear() + print("[INFO] All gradient hooks removed.") + + def _analyze_latent_representation(self, latent_states: torch.Tensor, timesteps: torch.Tensor, game_states: torch.Tensor, step_counter: int): + """ + Analyze and log statistics of the latent states together with a t-SNE visualization. + [New feature]: the corresponding game frames are drawn directly on the t-SNE plot. + + Args: + latent_states (torch.Tensor): Encoder output, shape (B, L, E) or
(B*L, 1, E) + timesteps (torch.Tensor): The corresponding timesteps, shape (B, L) + game_states (torch.Tensor): The raw game observations, shape (B, L, C, H, W) + step_counter (int): Global training step, used to control the visualization frequency + """ + # Make sure latent_states and game_states have shape (N, ...) + if latent_states.dim() > 2: + latent_states = latent_states.reshape(-1, latent_states.shape[-1]) + + # game_states shape is (B, L, C, H, W), reshape to (B*L, C, H, W) + num_c, num_h, num_w = game_states.shape[-3:] + game_states = game_states.reshape(-1, num_c, num_h, num_w) + + # 1. Statistical analysis (Stability Check) - this part is unchanged + with torch.no_grad(): + l2_norm = torch.norm(latent_states, p=2, dim=1).mean() + mean = latent_states.mean() + std = latent_states.std() + abs_max = latent_states.abs().max() + + # Assuming a logger is available + # logger.add_scalar('debug/latent_l2_norm', l2_norm.item(), step_counter) + # ... + print(f"[Step {step_counter}] Latent Stats | L2 Norm: {l2_norm:.4f}, Mean: {mean:.4f}, Std: {std:.4f}, Max Abs: {abs_max:.4f}") + + # 2. t-SNE visualization with images (Discriminability and Consistency Check) + if step_counter % 1000 == 0: + print(f"[Step {step_counter}] Performing t-SNE analysis with images...") + + # Move the data to the CPU + latents_np = latent_states.detach().cpu().numpy() + images_np = game_states.detach().cpu().numpy() + + # Run t-SNE + tsne = TSNE(n_components=2, perplexity=30, n_iter=300, random_state=42) + tsne_results = tsne.fit_transform(latents_np) + + # --- New: draw a scatter plot with the images attached --- + + # For performance and clarity, randomly sample a subset of points to display images + num_points_to_plot = min(len(latents_np), 150) # plot at most 150 images + indices = np.random.choice(len(latents_np), num_points_to_plot, replace=False) + + fig, ax = plt.subplots(figsize=(16, 14)) + + for i in indices: + # Get the t-SNE coordinates + x, y = tsne_results[i] + + # Get the corresponding image + # PyTorch (CHW) -> Matplotlib (HWC) + img = images_np[i].transpose(1, 2, 0) + + # Important: make sure the image data is in a displayable format, e.g. floats in [0, 1] or integers in [0, 255]. + # If the data was normalized (e.g. to [-1, 1]), de-normalize it first. + # Here the data is assumed to lie in the [0, 1] range + img = np.clip(img, 0, 1) + + # Place the image on the plot using OffsetImage and AnnotationBbox + im = OffsetImage(img, zoom=0.5) # zoom controls the image scaling + ab = AnnotationBbox(im, (x, y), frameon=False, pad=0.0) + ax.add_artist(ab) + + # Update the plot limits and autoscale + ax.update_datalim(tsne_results) + ax.autoscale() + + ax.set_title(f't-SNE of Latent States with Images at Step {step_counter}') + ax.set_xlabel('t-SNE dimension 1') + ax.set_ylabel('t-SNE dimension 2') + + # Save the figure + save_path = f'zoo/atari/unizero_mspacman_analyze/tsne_with_images_step_{step_counter}.png' + plt.savefig(save_path) + plt.close() + print(f"t-SNE plot with images saved to {save_path}") + + def _debug_check_for_stale_pointers(self, env_id: int, current_key: Any, index_to_be_written: int): + """ + Debug helper: check whether a stale pointer exists for the index that is about to be written. + """ + # Get the pointer map of the corresponding environment + cache_map = self.past_kv_cache_init_infer_envs[env_id] + + # Iterate over all entries in the map (old hash -> old index) + for old_key, old_index in cache_map.items(): + # Check two conditions: + # 1. old index == the index about to be overwritten + # 2. old hash != the new hash about to be written + if old_index == index_to_be_written and old_key != current_key: + # If both conditions hold, we have found a stale pointer + self.stale_pointer_detections += 1 + + # Print detailed debug information + print("="*60) + print(f"!!! 
INIT BUG CONDITION DETECTED (Detection #{self.stale_pointer_detections}) !!!") + print(f" Environment ID: {env_id}") + print(f" Pool Index to be overwritten: {index_to_be_written}") + print(f" New state hash being written: '{current_key}'") + print(f" Stale pointer found in cache_map: '{old_key}' also points to index {old_index}.") + print(f" This means the data for '{old_key}' is about to be lost, but its pointer remains.") + print(f" Current cache_map size: {len(cache_map)}") + print("="*60) + + # Finding one stale pointer is enough; return early to keep the loop cheap + return + + def _debug_check_for_stale_pointers_recur(self, current_key: Any, index_to_be_written: int): + """ + Debug helper: check whether a stale pointer exists in the recurrent cache. + """ + cache_map = self.past_kv_cache_recurrent_infer + + for old_key, old_index in cache_map.items(): + if old_index == index_to_be_written and old_key != current_key: + self.stale_pointer_detections_recur += 1 + print("="*60) + print(f"!!! RECURRENT BUG DETECTED (Detection #{self.stale_pointer_detections_recur}) !!!") + print(f" Pool Index to be overwritten: {index_to_be_written}") + print(f" New state hash being written: '{current_key}'") + print(f" Stale pointer found: '{old_key}' also points to index {old_index}.") + print("="*60) + return + + def _get_final_norm(self, norm_option: str) -> nn.Module: + """ + Return the corresponding normalization module based on the specified normalization option. + """ + if norm_option == 'LayerNorm': + return nn.LayerNorm(self.config.embed_dim, eps=1e-5) + elif norm_option == 'SimNorm': + return SimNorm(simnorm_dim=self.config.group_size) + elif norm_option == 'L2Norm': + return L2Norm(eps=1e-6) + else: + raise ValueError(f"Unsupported final_norm_option_in_obs_head: {norm_option}") + + def custom_copy_kv_cache_to_shared_init_envs(self, src_kv: KeysValues, env_id) -> int: + """ + Overview: + Efficiently copies the contents of a KeysValues object to the shared pool for a specific environment in the init_infer stage. + Arguments: + - src_kv (:obj:`KeysValues`): The source KeysValues object from which data is copied. + - env_id (:obj:`int`): The identifier of the environment for which the cache is being copied. + Returns: + - index (:obj:`int`): The index in the shared pool where the KeysValues object is stored.
+ """ + src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape + + if self.shared_pool_init_infer[env_id][self.shared_pool_index_init_envs[env_id]] is None: + self.shared_pool_init_infer[env_id][self.shared_pool_index_init_envs[env_id]] = KeysValues( + src_kv_shape[0], # Number of elements (n) + src_kv_shape[1], # Number of attention heads (num_heads) + src_kv_shape[2], # Maximum number of tokens (max_tokens) + src_kv_shape[3] * src_kv_shape[1], # Embedding dimension (embed_dim) + len(src_kv), # Number of layers (num_layers) + src_kv._keys_values[0]._k_cache._cache.device, # Device where the cache is stored + ) + + dst_kv = self.shared_pool_init_infer[env_id][self.shared_pool_index_init_envs[env_id]] + + for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values): + # Copy the key and value caches using torch.copy_() for efficient data transfer + dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) + dst_layer._v_cache._cache.copy_(src_layer._v_cache._cache) + dst_layer._k_cache._size = src_layer._k_cache._size + dst_layer._v_cache._size = src_layer._v_cache._size + + index = self.shared_pool_index_init_envs[env_id] + self.shared_pool_index_init_envs[env_id] = (self.shared_pool_index_init_envs[env_id] + 1) % self.shared_pool_size_init + + return index + + def custom_copy_kv_cache_to_shared_wm(self, src_kv: KeysValues) -> int: + """ + Overview: + Efficiently copies the contents of a KeysValues object to the shared pool for world model usage. + Arguments: + - src_kv (:obj:`KeysValues`): The source KeysValues object from which data is copied. + Returns: + - index (:obj:`int`): The index in the shared pool where the KeysValues object is stored. + """ + src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape + + if self.shared_pool_wm[self.shared_pool_index_wm] is None: + self.shared_pool_wm[self.shared_pool_index_wm] = KeysValues( + src_kv_shape[0], # Number of elements (n) + src_kv_shape[1], # Number of attention heads (num_heads) + src_kv_shape[2], # Maximum number of tokens (max_tokens) + src_kv_shape[3] * src_kv_shape[1], # Embedding dimension (embed_dim) + len(src_kv), # Number of layers (num_layers) + src_kv._keys_values[0]._k_cache._cache.device, # Device where the cache is stored + ) + + dst_kv = self.shared_pool_wm[self.shared_pool_index_wm] + + for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values): + # Copy the key and value caches using torch.copy_() for efficient data transfer + dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) + dst_layer._v_cache._cache.copy_(src_layer._v_cache._cache) + dst_layer._k_cache._size = src_layer._k_cache._size + dst_layer._v_cache._size = src_layer._v_cache._size + + self.shared_pool_index_wm = (self.shared_pool_index_wm + 1) % self.shared_pool_size_wm + + return dst_kv + + def custom_copy_kv_cache_to_shared_recur(self, src_kv: KeysValues) -> int: + """ + Overview: + Efficiently copies the contents of a KeysValues object to the shared pool for recurrent inference. + Arguments: + - src_kv (:obj:`KeysValues`): The source KeysValues object from which data is copied. + Returns: + - index (:obj:`int`): The index in the shared pool where the KeysValues object is stored. 
+ """ + src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape + + if self.shared_pool_recur_infer[self.shared_pool_index] is None: + self.shared_pool_recur_infer[self.shared_pool_index] = KeysValues( + src_kv_shape[0], # Number of elements (n) + src_kv_shape[1], # Number of attention heads (num_heads) + src_kv_shape[2], # Maximum number of tokens (max_tokens) + src_kv_shape[3] * src_kv_shape[1], # Embedding dimension (embed_dim) + len(src_kv), # Number of layers (num_layers) + src_kv._keys_values[0]._k_cache._cache.device, # Device where the cache is stored + ) + + dst_kv = self.shared_pool_recur_infer[self.shared_pool_index] + + for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values): + # Copy the key and value caches using torch.copy_() for efficient data transfer + dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) + dst_layer._v_cache._cache.copy_(src_layer._v_cache._cache) + dst_layer._k_cache._size = src_layer._k_cache._size + dst_layer._v_cache._size = src_layer._v_cache._size + + index = self.shared_pool_index + self.shared_pool_index = (self.shared_pool_index + 1) % self.shared_pool_size_recur + + return index + + def _initialize_config_parameters(self) -> None: + """Initialize configuration parameters.""" + self.policy_entropy_weight = self.config.policy_entropy_weight + self.predict_latent_loss_type = self.config.predict_latent_loss_type + self.group_size = self.config.group_size + self.num_groups = self.config.embed_dim // self.group_size + self.obs_type = self.config.obs_type + self.embed_dim = self.config.embed_dim + self.num_heads = self.config.num_heads + self.gamma = self.config.gamma + self.context_length = self.config.context_length + self.dormant_threshold = self.config.dormant_threshold + self.analysis_dormant_ratio = self.config.analysis_dormant_ratio + self.num_observations_tokens = self.config.tokens_per_block - 1 + self.latent_recon_loss_weight = self.config.latent_recon_loss_weight + self.perceptual_loss_weight = self.config.perceptual_loss_weight + self.support_size = self.config.support_size + self.action_space_size = self.config.action_space_size + self.max_cache_size = self.config.max_cache_size + self.env_num = self.config.env_num + self.num_layers = self.config.num_layers + self.obs_per_embdding_dim = self.config.embed_dim + self.sim_norm = SimNorm(simnorm_dim=self.group_size) + + def _initialize_patterns(self) -> None: + """Initialize patterns for block masks.""" + self.all_but_last_latent_state_pattern = torch.ones(self.config.tokens_per_block) + self.all_but_last_latent_state_pattern[-2] = 0 + self.act_tokens_pattern = torch.zeros(self.config.tokens_per_block) + self.act_tokens_pattern[-1] = 1 + self.value_policy_tokens_pattern = torch.zeros(self.config.tokens_per_block) + self.value_policy_tokens_pattern[-2] = 1 + + # def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head: + # """Create head modules for the transformer.""" + # modules = [ + # nn.Linear(self.config.embed_dim, self.config.embed_dim), + # nn.GELU(approximate='tanh'), + # nn.Linear(self.config.embed_dim, output_dim) + # ] + # if norm_layer: + # modules.append(norm_layer) + # return Head( + # max_blocks=self.config.max_blocks, + # block_mask=block_mask, + # head_module=nn.Sequential(*modules) + # ) + + def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None, use_norm_in_head: bool = False) -> Head: + """Create head modules for the transformer.""" + modules = [ + nn.Linear(self.config.embed_dim, 
self.config.embed_dim), + ] + + # ==================== PROPOSED FIX ==================== + # Add a LayerNorm after the first linear layer and before the activation. + # This stabilizes the activations within the head, preventing drift. + if use_norm_in_head: + modules.append(nn.LayerNorm(self.config.embed_dim)) + # ====================================================== + + modules.extend([ + nn.GELU(approximate='tanh'), + nn.Linear(self.config.embed_dim, output_dim), + nn.LayerNorm(output_dim) + ]) + + if norm_layer: + modules.append(norm_layer) + + return Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + ) + + def _create_head_cont(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head: + """Create head modules for the transformer.""" + from ding.model.common import ReparameterizationHead + self.fc_policy_head = ReparameterizationHead( + input_size=self.config.embed_dim, + output_size=output_dim, + layer_num=2, # TODO: check the effect of layer_num + sigma_type=self.sigma_type, + activation=nn.GELU(approximate='tanh'), + fixed_sigma_value=self.config.fixed_sigma_value if self.sigma_type == 'fixed' else 0.5, + norm_type=None, + bound_type=self.bound_type + ) + return PolicyHeadCont( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=self.fc_policy_head + ) + + def _initialize_last_layer(self) -> None: + """Initialize the last linear layer.""" + last_linear_layer_init_zero = True # TODO + if last_linear_layer_init_zero: + if self.continuous_action_space: + module_to_initialize = [self.head_value, self.head_rewards, self.head_observations] + else: + module_to_initialize = [self.head_policy, self.head_value, self.head_rewards, self.head_observations] + for head in module_to_initialize: + for layer in reversed(head.head_module): + if isinstance(layer, nn.Linear): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + break + + def _initialize_cache_structures(self) -> None: + """Initialize cache structures for past keys and values.""" + from collections import defaultdict + # self.past_kv_cache_recurrent_infer = defaultdict(dict) + # Replace defaultdict with LRUCache and keep the capacity in sync + + # ========================= Core fix and notes (Recurrent Infer) ========================= + # Problem: the recurrent_infer cache suffers from the same mismatch between the LRUCache and the ring-buffer logic. + # + # Fix: + # 1. Change past_kv_cache_recurrent_infer from an LRUCache to a plain dict. + # 2. Introduce the auxiliary list pool_idx_to_key_map_recur_infer to maintain the reverse mapping. + # This ensures that when an entry in the recurrent data pool is overwritten, the old pointer can be removed in the same step. + + self.past_kv_cache_recurrent_infer = {} + self.pool_idx_to_key_map_recur_infer = [None] * self.shared_pool_size_recur + # ========================== End of fix ========================== + + # self.past_kv_cache_init_infer_envs = [defaultdict(dict) for _ in range(self.env_num)] + + # TODO(pu): very important: self.past_kv_cache_init_infer_envs should be made fully consistent with (shared_pool_size_init); + # currently shared_pool_size_init is set to segment_length so that self.past_kv_cache_init_infer_envs is cleared after each collect, + # which avoids keeping stale kv indices inside self.past_kv_cache_init_infer_envs. + + # ========================= Core fix and notes ========================= + # Original implementation: + # self.past_kv_cache_init_infer_envs = [defaultdict(dict) for _ in range(self.env_num)] + # + # Problem: the defaultdict grows without bound and never removes the old "pointers" that refer to + # entries already overwritten in the ring buffer, which pollutes the cache within an episode. + # + # Fix: + # Use the LRUCache defined above, with a capacity exactly equal to the ring-buffer size (shared_pool_size_init). + # + # Effect: + # 1. Automatic eviction: when the (N+1)-th entry is added, the LRUCache automatically drops the oldest one. + # 2. Lifecycle synchronization: this keeps the mappings in the "pointer dict" fully in sync with the data actually + # stored in the data pool. When index 0 of the pool is overwritten with new data, the pointer to the old index 0 has already been removed. + # 3. 
Eliminates pollution: fundamentally resolves the state-hash collision issue within an episode. + + # self.past_kv_cache_init_infer_envs = [LRUCache(self.shared_pool_size_init-1) for _ in range(self.env_num)] + # ========================== End of fix ========================== + + # ========================= Core fix and notes ========================= + # Problem: the LRUCache eviction logic (based on access order) does not match the ring buffer's overwrite logic (based on write order), which leads to stale pointers. + # + # Fix: + # 1. Use a plain dict `past_kv_cache_init_infer_envs` to store {state_hash -> pool_index}. + # 2. Introduce an auxiliary list `pool_idx_to_key_map_init_envs` to maintain the reverse mapping {pool_index -> state_hash}. + # + # Effect: + # Before writing new data to a given ring-buffer index, the auxiliary list immediately gives us the old state_hash that is about to be overwritten, + # so that the stale entry can be removed from the main dict precisely. This keeps the dict and the data pool fully in sync. + + self.past_kv_cache_init_infer_envs = [{} for _ in range(self.env_num)] + # Auxiliary structure for reverse lookup: pool_index -> key + self.pool_idx_to_key_map_init_envs = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] + # ========================== End of fix ========================== + + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + + def _initialize_projection_input_dim(self) -> None: + """Initialize the projection input dimension based on the number of observation tokens.""" + if self.num_observations_tokens == 16: + self.projection_input_dim = 128 + elif self.num_observations_tokens == 1: + self.projection_input_dim = self.obs_per_embdding_dim + + def _initialize_statistics(self) -> None: + """Initialize counters for hit count and query count statistics.""" + self.recur_hit_count = 0 + self.recur_total_query_count = 0 + self.length_largethan_maxminus5_context_cnt = 0 + self.length_largethan_maxminus7_context_cnt = 0 + self.length_largethan_contextminus3_cnt = 0 + + self.root_hit_cnt = 0 + self.root_total_query_cnt = 0 + + def _initialize_transformer_keys_values(self) -> None: + """Initialize keys and values for the transformer.""" + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, + max_tokens=self.context_length) + self.keys_values_wm_single_env_tmp = self.transformer.generate_empty_keys_values(n=1, + max_tokens=self.context_length) + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, + max_tokens=self.context_length) + + def precompute_pos_emb_diff_kv(self): + """ Precompute positional embedding differences for key and value. 
""" + if self.context_length <= 2: + # If context length is 2 or less, no context is present + return + # Precompute positional embedding matrices for inference in collect/eval stages, not for training + self.positional_embedding_k = [ + self._get_positional_embedding(layer, 'key') + for layer in range(self.config.num_layers) + ] + self.positional_embedding_v = [ + self._get_positional_embedding(layer, 'value') + for layer in range(self.config.num_layers) + ] + + # Precompute all possible positional embedding differences + self.pos_emb_diff_k = [] + self.pos_emb_diff_v = [] + + for layer in range(self.config.num_layers): + layer_pos_emb_diff_k = {} + layer_pos_emb_diff_v = {} + + for start in [2]: + for end in [self.context_length - 1]: + original_pos_emb_k = self.positional_embedding_k[layer][:, :, start:end, :] + new_pos_emb_k = self.positional_embedding_k[layer][:, :, :end - start, :] + layer_pos_emb_diff_k[(start, end)] = new_pos_emb_k - original_pos_emb_k + + original_pos_emb_v = self.positional_embedding_v[layer][:, :, start:end, :] + new_pos_emb_v = self.positional_embedding_v[layer][:, :, :end - start, :] + layer_pos_emb_diff_v[(start, end)] = new_pos_emb_v - original_pos_emb_v + + self.pos_emb_diff_k.append(layer_pos_emb_diff_k) + self.pos_emb_diff_v.append(layer_pos_emb_diff_v) + + def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: + """ + Helper function to get positional embedding for a given layer and attention type. + + Arguments: + - layer (:obj:`int`): Layer index. + - attn_type (:obj:`str`): Attention type, either 'key' or 'value'. + + Returns: + - torch.Tensor: The positional embedding tensor. + """ + attn_func = getattr(self.transformer.blocks[layer].attn, attn_type) + if torch.cuda.is_available(): + return attn_func(self.pos_emb.weight).view( + 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads + ).transpose(1, 2).to(self.device).detach() + else: + return attn_func(self.pos_emb.weight).view( + 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads + ).transpose(1, 2).detach() + + def forward( + self, + obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, Tuple]], + past_keys_values: Optional[torch.Tensor] = None, + kvcache_independent: bool = False, + is_init_infer: bool = True, + valid_context_lengths: Optional[torch.Tensor] = None, + start_pos: Union[int, List[int]] = 0, + search_depth: Optional[List[int]] = None + ) -> "WorldModelOutput": + """ + Overview: + Forward pass for the world model. This method processes observation embeddings and/or action tokens, + optionally adds position encodings (with or without rotary position embeddings), passes the resulting + sequences through the transformer, and finally generates logits for observations, rewards, policy, and value. + + Arguments: + - obs_embeddings_or_act_tokens (dict): Dictionary containing one or more of the following keys: + - 'obs_embeddings': torch.Tensor representing observation embeddings. + - 'act_tokens': torch.Tensor representing action tokens. + - 'obs_embeddings_and_act_tokens': Combined data for both observations and actions. + - past_keys_values (Optional[torch.Tensor]): Cached key-value pairs for the transformer. Defaults to None. + - kvcache_independent (bool): Flag to indicate whether key-value caching is independent. Defaults to False. + - is_init_infer (bool): Flag to indicate if this is the initial inference step. Defaults to True. + - valid_context_lengths (Optional[torch.Tensor]): Valid lengths for the context. 
Defaults to None. + - start_pos (int or List[int]): Starting positional index for the current sequence (or batch). Defaults to 0. + - search_depth (Optional[List[int]]): List representing the search depth for each batch element, used for + position encoding adjustment. Defaults to None. + + Returns: + WorldModelOutput: An output instance containing: + - x: Output features from the transformer. + - logits for observations. + - logits for rewards. + - logits_ends (None). + - logits for policy. + - logits for value. + """ + + # Calculate previous steps based on key-value caching configuration + if kvcache_independent: + # If kv caching is independent, compute previous steps for each past key-value pair. + prev_steps = torch.tensor( + [0 if past_keys_values is None else past_kv.size for past_kv in past_keys_values], + device=self.device + ) + else: + # Otherwise, use a single value for previous steps. + prev_steps = 0 if past_keys_values is None else past_keys_values.size + + # Reset valid context lengths during initial inference phase. + if is_init_infer: + valid_context_lengths = None + + # sequences: torch.Tensor # Output sequence to feed into transformer + # num_steps: int # Number of timesteps in the sequence + # start_pos_adjusted: Union[int, List[int]] # Adjusted starting position index for positional encoding + + if not self.config.rotary_emb: + start_pos_adjusted = None + + # Process observation embeddings if available. + if "obs_embeddings" in obs_embeddings_or_act_tokens: + obs_embeddings = obs_embeddings_or_act_tokens["obs_embeddings"] + # If the observation embeddings have 2 dimensions, expand them to include a time dimension. + if len(obs_embeddings.shape) == 2: + obs_embeddings = obs_embeddings.unsqueeze(1) + num_steps = obs_embeddings.size(1) + + if not self.config.rotary_emb: + # Add traditional position embeddings if not using rotary embeddings. + sequences = self._add_position_embeddings( + obs_embeddings, prev_steps, num_steps, kvcache_independent, + is_init_infer, valid_context_lengths + ) + else: + # Keep the observation embeddings unchanged when using rotary embeddings. + sequences = obs_embeddings + + if is_init_infer: + if self.reanalyze_phase: + # During reanalyze phase in initial inference, adjust start_pos: + # Multiply by 2 because timestep only counts observations, + # but the sequence contains both observations and actions. + start_pos_adjusted = start_pos * 2 + if not isinstance(start_pos_adjusted, (int, float)): + # Pad zero if start_pos_adjusted is not a scalar. + padding = np.zeros((start_pos_adjusted.shape[0], 1), dtype=start_pos_adjusted.dtype) + start_pos_adjusted = np.concatenate([start_pos_adjusted, padding], axis=1).reshape(-1) + else: + # For regular initial inference, adjust start_pos accordingly. + if isinstance(start_pos, (int, float)): + start_pos_adjusted = start_pos * 2 + else: + start_pos_adjusted = [pos * 2 for pos in start_pos] + else: + # For recurrent inference (non-init), calculate the correct positional index. + if self.reanalyze_phase: + # In reanalyze phase, start_pos for batch mode might be an array that needs padding. + if not isinstance(start_pos, (int, float)): + padding = np.zeros((start_pos.shape[0], 1), dtype=start_pos.dtype) + start_pos_adjusted = np.concatenate([start_pos, padding], axis=1).reshape(-1) + # Ensure search_depth length matches adjusted start_pos. 
+ assert len(search_depth) == len(start_pos_adjusted) + start_pos_adjusted = [ + (search_depth[i] + pos + 1) * 2 + 1 for i, pos in enumerate(start_pos_adjusted) + ] + else: + start_pos_adjusted = [ + (search_depth[i] + pos) * 2 + 2 for i, pos in enumerate(start_pos) + ] + + # Process action tokens if available. + elif "act_tokens" in obs_embeddings_or_act_tokens: + act_tokens = obs_embeddings_or_act_tokens["act_tokens"] + if self.continuous_action_space: + num_steps = 1 + act_tokens = act_tokens.float() + if len(act_tokens.shape) == 2: + act_tokens = act_tokens.unsqueeze(1) + else: + if len(act_tokens.shape) == 3: + act_tokens = act_tokens.squeeze(1) + num_steps = act_tokens.size(1) + # Convert action tokens to embeddings using the action embedding table. + act_embeddings = self.act_embedding_table(act_tokens) + if not self.config.rotary_emb: + sequences = self._add_position_embeddings( + act_embeddings, prev_steps, num_steps, kvcache_independent, + is_init_infer, valid_context_lengths + ) + else: + sequences = act_embeddings + + if is_init_infer: + if self.reanalyze_phase: + # In reanalyze phase during initial inference, the action tokens represent the current timestep. + start_pos_adjusted = start_pos * 2 + 1 + if not isinstance(start_pos_adjusted, (int, float)): + padding = np.zeros((start_pos_adjusted.shape[0], 1), dtype=start_pos_adjusted.dtype) + start_pos_adjusted = np.concatenate([start_pos_adjusted, padding], axis=1).reshape(-1) + else: + # For regular initial inference using action tokens, adjust start_pos by subtracting 1. + if isinstance(start_pos, (int, float)): + start_pos_adjusted = start_pos * 2 - 1 + else: + start_pos_adjusted = [pos * 2 - 1 for pos in start_pos] + else: + # During recurrent inference for action tokens. + if self.reanalyze_phase: + if not isinstance(start_pos, (int, float)): + padding = np.zeros((start_pos.shape[0], 1), dtype=start_pos.dtype) + start_pos_adjusted = np.concatenate([start_pos, padding], axis=1).reshape(-1) + assert len(search_depth) == len(start_pos_adjusted) + start_pos_adjusted = [ + (search_depth[i] + pos + 1) * 2 + 1 for i, pos in enumerate(start_pos_adjusted) + ] + else: + start_pos_adjusted = [ + (search_depth[i] + pos) * 2 + 1 for i, pos in enumerate(start_pos) + ] + + # Process combined observation embeddings and action tokens. + elif "obs_embeddings_and_act_tokens" in obs_embeddings_or_act_tokens: + # Process combined inputs to calculate either the target value (for training) + # or target policy (for reanalyze phase). + if self.continuous_action_space: + sequences, num_steps = self._process_obs_act_combined_cont(obs_embeddings_or_act_tokens, prev_steps) + else: + sequences, num_steps = self._process_obs_act_combined(obs_embeddings_or_act_tokens, prev_steps) + # Adjust start positions: multiply by 2 as the sequence has both obs and act. + start_pos_adjusted = [pos * 2 for pos in start_pos] + else: + raise ValueError("Input dictionary must contain one of 'obs_embeddings', 'act_tokens', or 'obs_embeddings_and_act_tokens'.") + + # Pass the sequence through the transformer. + x = self._transformer_pass( + sequences, past_keys_values, kvcache_independent, valid_context_lengths, start_pos=start_pos_adjusted + ) + + # print(f"x.mean(): {x.mean().item():.6f}, x.std(): {x.std().item():.6f}") + + # Generate logits for various components. 
+ logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) + logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) + logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) + logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) + + # print(f"logits_observations.mean(): {logits_observations.mean().item():.6f}") + # print(f"logits_rewards.mean(): {logits_rewards.mean().item():.6f}") + # print(f"logits_policy.mean(): {logits_policy.mean().item():.6f}") + # print(f"logits_value.mean(): {logits_value.mean().item():.6f}") + + # The 'logits_ends' is intentionally set to None. + return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value) + + def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, + valid_context_lengths): + """ + Add position embeddings to the input embeddings. + + Arguments: + - embeddings (:obj:`torch.Tensor`): Input embeddings. + - prev_steps (:obj:`torch.Tensor`): Previous steps. + - num_steps (:obj:`int`): Number of steps. + - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. + - is_init_infer (:obj:`bool`): Initialize inference. + - valid_context_lengths (:obj:`torch.Tensor`): Valid context lengths. + Returns: + - torch.Tensor: Embeddings with position information added. + """ + if kvcache_independent: + steps_indices = prev_steps + torch.arange(num_steps, device=embeddings.device) + position_embeddings = self.pos_emb(steps_indices).view(-1, num_steps, embeddings.shape[-1]) + return embeddings + position_embeddings + else: + if is_init_infer: + return embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)) + else: + valid_context_lengths = torch.tensor(self.keys_values_wm_size_list_current, device=self.device) + position_embeddings = self.pos_emb( + valid_context_lengths + torch.arange(num_steps, device=self.device)).unsqueeze(1) + return embeddings + position_embeddings + + def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_steps): + """ + Process combined observation embeddings and action tokens. + + Arguments: + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing combined observation embeddings and action tokens. + - prev_steps (:obj:`torch.Tensor`): Previous steps. + Returns: + - torch.Tensor: Combined observation and action embeddings with position information added. 
+ """ + obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] + if len(obs_embeddings.shape) == 3: + obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, + -1) + + num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) + if self.continuous_action_space: + act_tokens = act_tokens.float() + if len(act_tokens.shape) == 2: # TODO + act_tokens = act_tokens.unsqueeze(-1) + + # B, L, E + act_embeddings = self.act_embedding_table(act_tokens) + + B, L, K, E = obs_embeddings.size() + # B, L*2, E + obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=self.device) + + for i in range(L): + obs = obs_embeddings[:, i, :, :] + act = act_embeddings[:, i, :].unsqueeze(1) + obs_act = torch.cat([obs, act], dim=1) + obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act + + return_result = obs_act_embeddings + if not self.config.rotary_emb: + return_result += self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)) + return return_result, num_steps + + def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps): + """ + Process combined observation embeddings and action tokens. + + Arguments: + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing combined observation embeddings and action tokens. + - prev_steps (:obj:`torch.Tensor`): Previous steps. + Returns: + - torch.Tensor: Combined observation and action embeddings with position information added. + """ + obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] + if len(obs_embeddings.shape) == 3: + obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, + -1) + + num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) + act_embeddings = self.act_embedding_table(act_tokens) + + B, L, K, E = obs_embeddings.size() + obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=self.device) + + for i in range(L): + obs = obs_embeddings[:, i, :, :] + act = act_embeddings[:, i, 0, :].unsqueeze(1) + obs_act = torch.cat([obs, act], dim=1) + obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act + + return_result = obs_act_embeddings + if not self.config.rotary_emb: + return_result += self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)) + return return_result, num_steps + + def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, valid_context_lengths, start_pos: int = 0): + """ + Pass sequences through the transformer. + + Arguments: + - sequences (:obj:`torch.Tensor`): Input sequences. + - past_keys_values (:obj:`Optional[torch.Tensor]`): Previous keys and values for transformer. + - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. + - valid_context_lengths (:obj:`torch.Tensor`): Valid context lengths. + Returns: + - torch.Tensor: Transformer output. 
+ """ + if kvcache_independent: + x = [self.transformer(sequences[k].unsqueeze(0), past_kv, + valid_context_lengths=valid_context_lengths[k].unsqueeze(0), start_pos=start_pos) for k, past_kv in + enumerate(past_keys_values)] + return torch.cat(x, dim=0) + else: + return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths, start_pos=start_pos) + + @torch.no_grad() + def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor, start_pos: int = 0) -> torch.FloatTensor: + """ + Reset the model state based on initial observations and actions. + + Arguments: + - obs_act_dict (:obj:`torch.FloatTensor`): A dictionary containing 'obs', 'action', and 'current_obs'. + Returns: + - torch.FloatTensor: The outputs from the world model and the latent state. + """ + # Extract observations, actions, and current observations from the dictionary. + if isinstance(obs_act_dict, dict): + batch_obs = obs_act_dict['obs'] # obs_act_dict['obs'] is at timestep t + batch_action = obs_act_dict['action'] # obs_act_dict['action'] is at timestep t + batch_current_obs = obs_act_dict['current_obs'] # obs_act_dict['current_obs'] is at timestep t+1 + + # Encode observations to latent embeddings. + obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_obs) + + if batch_current_obs is not None: + # ================ Collect and Evaluation Phase ================ + # Encode current observations to latent embeddings + current_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_current_obs) + # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}") + self.latent_state = current_obs_embeddings + outputs_wm = self.wm_forward_for_initial_infererence(obs_embeddings, batch_action, + current_obs_embeddings, start_pos) + else: + # ================ calculate the target value in Train phase or calculate the target policy in reanalyze phase ================ + self.latent_state = obs_embeddings + outputs_wm = self.wm_forward_for_initial_infererence(obs_embeddings, batch_action, None, start_pos) + + return outputs_wm, self.latent_state + + @torch.no_grad() + def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTensor, + batch_action=None, + current_obs_embeddings=None, start_pos: int = 0) -> torch.FloatTensor: + """ + Refresh key-value pairs with the initial latent state for inference. + + Arguments: + - last_obs_embeddings (:obj:`torch.LongTensor`): The latent state embeddings. + - batch_action (optional): Actions taken. + - current_obs_embeddings (optional): Current observation embeddings. + Returns: + - torch.FloatTensor: The outputs from the world model. + """ + n, num_observations_tokens, _ = last_obs_embeddings.shape + if n <= self.env_num and current_obs_embeddings is not None: + # ================ Collect and Evaluation Phase ================ + if current_obs_embeddings is not None: + # Determine whether it is the first step in an episode. 
+ if self.continuous_action_space: + first_step_flag = not isinstance(batch_action[0], np.ndarray) + else: + first_step_flag = max(batch_action) == -1 + if first_step_flag: + # ------------------------- First Step of an Episode ------------------------- + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=current_obs_embeddings.shape[0], + max_tokens=self.context_length) + # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}") + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, + past_keys_values=self.keys_values_wm, is_init_infer=True, start_pos=start_pos) + + # Copy and store keys_values_wm for a single environment + self.update_cache_context(current_obs_embeddings, is_init_infer=True) + else: + # --------------------- Continuing an Episode (Multi-environment) --------------------- + # current_obs_embeddings is the new latent_state, containing information from ready_env_num environments + ready_env_num = current_obs_embeddings.shape[0] + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + + for i in range(ready_env_num): + # Retrieve latent state for a single environment + # TODO: len(last_obs_embeddings) may smaller than len(current_obs_embeddings), because some environments may have done + + state_single_env = last_obs_embeddings[i] + # Compute hash value using latent state for a single environment + cache_key = hash_state(state_single_env.view(-1).cpu().numpy()) # last_obs_embeddings[i] is torch.Tensor + + # Retrieve cached value + cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[i][cache_index] + else: + matched_value = None + + self.root_total_query_cnt += 1 + if matched_value is not None: + # If a matching value is found, add it to the list + self.root_hit_cnt += 1 + # if self.root_total_query_cnt > 0 and self.root_total_query_cnt % 50 == 0: + # self.root_hit_freq = self.root_hit_cnt / self.root_total_query_cnt + # print('root total_query_count:', self.root_total_query_cnt) + # print('root root_hit_freq:', self.root_hit_freq) + + # NOTE: deepcopy is needed because forward modifies matched_value in place + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + self.keys_values_wm_size_list.append(matched_value.size) + else: + # Reset using zero values + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.context_length) + # If using RoPE positional encoding, then at reset, the pos_embed should use the absolute position start_pos[i]. + outputs_wm = self.forward({'obs_embeddings': state_single_env.unsqueeze(0)}, + past_keys_values=self.keys_values_wm_single_env, + is_init_infer=True, start_pos=start_pos[i].item()) + self.keys_values_wm_list.append(self.keys_values_wm_single_env) + self.keys_values_wm_size_list.append(1) + + # Input self.keys_values_wm_list, output self.keys_values_wm + self.keys_values_wm_size_list_current = self.trim_and_pad_kv_cache(is_init_infer=True) + + start_pos = start_pos[:ready_env_num] + # TODO: len(last_obs_embeddings) may smaller than len(current_obs_embeddings), because some environments may have done + # TODO: the order may be not correct? 
len(batch_action) may smaller than len(current_obs_embeddings), because some environments may have done + batch_action = batch_action[:ready_env_num] + + # TODO: only for debug + # if ready_env_num < self.env_num: + # print(f'init inference ready_env_num: {ready_env_num} < env_num: {self.env_num}') + # print(f"ready_env_num: {ready_env_num}") + # print(f"start_pos: {start_pos}") + # print(f"batch_action: {batch_action}") + # print(f"len(last_obs_embeddings): {len(last_obs_embeddings)}") + # print(f"len(batch_action): {len(batch_action)}") + # print(f"len(current_obs_embeddings): {len(current_obs_embeddings)}") + + if self.continuous_action_space: + act_tokens = torch.from_numpy(np.array(batch_action)).to(last_obs_embeddings.device).unsqueeze(1) + else: + act_tokens = torch.from_numpy(np.array(batch_action)).to(last_obs_embeddings.device).unsqueeze(-1) + + outputs_wm = self.forward({'act_tokens': act_tokens}, past_keys_values=self.keys_values_wm, + is_init_infer=True, start_pos=start_pos) + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, + past_keys_values=self.keys_values_wm, is_init_infer=True, start_pos=start_pos) + + # Copy and store keys_values_wm for a single environment + self.update_cache_context(current_obs_embeddings, is_init_infer=True) + + elif batch_action is not None and current_obs_embeddings is None: + # ================ calculate the target value in Train phase or calculate the target policy in reanalyze phase ================ + # [192, 16, 64] -> [32, 6, 16, 64] + last_obs_embeddings = last_obs_embeddings.contiguous().view(batch_action.shape[0], -1, num_observations_tokens, + self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 + + last_obs_embeddings = last_obs_embeddings[:, :-1, :] + batch_action = torch.from_numpy(batch_action).to(last_obs_embeddings.device) + if self.continuous_action_space: + act_tokens = batch_action + else: + act_tokens = rearrange(batch_action, 'b l -> b l 1') + + # select the last timestep for each sample + # This will select the last column while keeping the dimensions unchanged, and the target policy/value in the final step itself is not used. + last_steps_act = act_tokens[:, -1:, :] + act_tokens = torch.cat((act_tokens, last_steps_act), dim=1) + + # Each sample in the batch (last_obs_embeddings, act_tokens) corresponds to the same time step, and start_pos also corresponds to each sample's respective t. + outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (last_obs_embeddings, act_tokens)}, start_pos=start_pos) + + # select the last timestep for each sample + last_steps_value = outputs_wm.logits_value[:, -1:, :] + outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) + + last_steps_policy = outputs_wm.logits_policy[:, -1:, :] + outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) + + # Reshape your tensors + # outputs_wm.logits_value.shape (B, H, 101) = (B*H, 101) + outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') + outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') + + return outputs_wm + + @torch.no_grad() + def forward_initial_inference(self, obs_act_dict, start_pos: int = 0): + """ + Perform initial inference based on the given observation-action dictionary. + + Arguments: + - obs_act_dict (:obj:`dict`): Dictionary containing observations and actions. + Returns: + - tuple: A tuple containing output sequence, latent state, logits rewards, logits policy, and logits value. 
+        """
+        # UniZero has context in the root node
+        outputs_wm, latent_state = self.reset_for_initial_inference(obs_act_dict, start_pos)
+        # TODO(pu): Because of prediction error, even without clearing, the nodes from the previous MCTS
+        # search would most likely not be retrieved anyway.
+        # Sharing this cache across all collect envs should also be reasonable, since different environments
+        # are unlikely to produce exactly the same predicted latent state.
+        self.past_kv_cache_recurrent_infer.clear()
+
+        return (outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards,
+                outputs_wm.logits_policy, outputs_wm.logits_value)
+
+    @torch.no_grad()
+    def forward_recurrent_inference(self, state_action_history, simulation_index=0,
+                                    search_depth=[], start_pos: int = 0):
+        """
+        Perform recurrent inference based on the state-action history.
+
+        Arguments:
+            - state_action_history (:obj:`list`): List containing tuples of state and action history.
+            - simulation_index (:obj:`int`, optional): Index of the current simulation. Defaults to 0.
+            - search_depth (:obj:`list`, optional): List containing depth of latent states in the search tree.
+        Returns:
+            - tuple: A tuple containing output sequence, updated latent state, reward, logits policy, and logits value.
+        """
+        latest_state, action = state_action_history[-1]
+        ready_env_num = latest_state.shape[0]
+
+        self.keys_values_wm_list = []
+        self.keys_values_wm_size_list = []
+        self.keys_values_wm_size_list = self.retrieve_or_generate_kvcache(latest_state, ready_env_num, simulation_index, start_pos)
+
+        latent_state_list = []
+        if not self.continuous_action_space:
+            token = action.reshape(-1, 1)
+        else:
+            token = action.reshape(-1, self.action_space_size)
+
+        # ======= Print statistics for debugging =============
+        min_size = min(self.keys_values_wm_size_list)
+        # # if min_size >= self.config.max_tokens - 5:
+        # #     self.length_largethan_maxminus5_context_cnt += len(self.keys_values_wm_size_list)
+        # # if min_size >= self.config.max_tokens - 7:
+        # #     self.length_largethan_maxminus7_context_cnt += len(self.keys_values_wm_size_list)
+        # if min_size >= self.config.context_length - 3:
+        #     self.length_largethan_contextminus3_cnt += len(self.keys_values_wm_size_list)
+        # # if self.recur_total_query_count > 0 and self.recur_total_query_count % 10000 == 0:
+        # if self.recur_total_query_count > 0 and self.recur_total_query_count % 1000 == 0:
+        #     self.hit_freq = self.recur_hit_count / self.recur_total_query_count
+        #     print('recur total_query_count:', self.recur_total_query_count)
+        #     # length_largethan_maxminus5_context_cnt_ratio = self.length_largethan_maxminus5_context_cnt / self.recur_total_query_count
+        #     # print('recurrent largethan_maxminus5_context:', self.length_largethan_maxminus5_context_cnt)
+        #     # print('recurrent largethan_maxminus5_context_ratio:', length_largethan_maxminus5_context_cnt_ratio)
+        #     # length_largethan_maxminus7_context_cnt_ratio = self.length_largethan_maxminus7_context_cnt / self.recur_total_query_count
+        #     # print('recurrent largethan_maxminus7_context_ratio:', length_largethan_maxminus7_context_cnt_ratio)
+        #     # print('recurrent largethan_maxminus7_context:', self.length_largethan_maxminus7_context_cnt)
+        #     length_largethan_contextminus3_cnt_ratio = self.length_largethan_contextminus3_cnt / self.recur_total_query_count
+        #     print('recurrent length_largethan_contextminus3_cnt_ratio:', length_largethan_contextminus3_cnt_ratio)
+        #     print('recurrent length_largethan_contextminus3_cnt:', self.length_largethan_contextminus3_cnt)
+
+        # Trim and pad kv_cache: modify self.keys_values_wm in-place
+        self.keys_values_wm_size_list = self.trim_and_pad_kv_cache(is_init_infer=False)
+        self.keys_values_wm_size_list_current = self.keys_values_wm_size_list
+
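+        # The two-pass loop below mirrors the (obs, action) token interleaving used during training:
+        # the first pass feeds the action token and reads off the predicted reward and next latent,
+        # the second pass feeds that predicted latent back as an observation token, so the returned
+        # policy/value logits and the KV cache both end on an observation position.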
+        for k in range(2):
+            # k == 0: feed the action token; k == 1: feed the predicted observation token
+            if k == 0:
+                obs_embeddings_or_act_tokens = {'act_tokens': token}
+            else:
+                obs_embeddings_or_act_tokens = {'obs_embeddings': token}
+
+            # Perform forward pass
+            outputs_wm = self.forward(
+                obs_embeddings_or_act_tokens,
+                past_keys_values=self.keys_values_wm,
+                kvcache_independent=False,
+                is_init_infer=False,
+                start_pos=start_pos,
+                search_depth=search_depth  # List containing depth of latent states in the search tree.
+            )
+
+            self.keys_values_wm_size_list_current = [i + 1 for i in self.keys_values_wm_size_list_current]
+
+            if k == 0:
+                reward = outputs_wm.logits_rewards  # (B,)
+
+            if k < self.num_observations_tokens:
+                token = outputs_wm.logits_observations
+                if len(token.shape) != 3:
+                    token = token.unsqueeze(1)  # (8,1024) -> (8,1,1024)
+                latent_state_list.append(token)
+
+        del self.latent_state  # Very important to minimize cuda memory usage
+        self.latent_state = torch.cat(latent_state_list, dim=1)  # (B, K)
+
+        self.update_cache_context(
+            self.latent_state,
+            is_init_infer=False,
+            simulation_index=simulation_index,
+        )
+
+        return (outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value)
+
+    # TODO: the hard-coded offsets in precompute_pos_emb_diff_kv do not match update_cache_context; with collect_env_num=1 this should not be an issue.
+    def trim_and_pad_kv_cache(self, is_init_infer=True) -> list:
+        """
+        Adjusts the key-value cache for each environment to ensure they all have the same size.
+
+        In a multi-environment setting, the key-value cache (kv_cache) for each environment is stored separately.
+        During recurrent inference, the kv_cache sizes may vary across environments. This method pads each kv_cache
+        to match the largest size found among them, facilitating batch processing in the transformer forward pass.
+
+        Arguments:
+            - is_init_infer (:obj:`bool`): Indicates if this is an initial inference. Default is True.
+        Returns:
+            - list: Updated sizes of the key-value caches.
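+
+        Note:
+            The cache tensors are assumed to have layout ``(batch, num_heads, seq_len, head_dim)``, with the
+            valid entries occupying the first ``size`` positions. For example, with ``max_size = 5`` and an
+            effective size of 3, ``pad_size = 2`` unused slots are trimmed from the tail and the cache is
+            left-padded with zeros via ``F.pad(cache_trimmed, (0, 0, 2, 0))``, so the valid entries of every
+            environment end at the same position and the per-environment caches can be batched together.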
+ """ + # Find the maximum size among all key-value caches + max_size = max(self.keys_values_wm_size_list) + + # Iterate over each layer of the transformer + for layer in range(self.num_layers): + kv_cache_k_list = [] + kv_cache_v_list = [] + + # Enumerate through each environment's key-value pairs + for idx, keys_values in enumerate(self.keys_values_wm_list): + k_cache = keys_values[layer]._k_cache._cache + v_cache = keys_values[layer]._v_cache._cache + + effective_size = self.keys_values_wm_size_list[idx] + pad_size = max_size - effective_size + + # If padding is required, trim the end and pad the beginning of the cache + if pad_size > 0: + k_cache_trimmed = k_cache[:, :, :-pad_size, :] + v_cache_trimmed = v_cache[:, :, :-pad_size, :] + k_cache_padded = F.pad(k_cache_trimmed, (0, 0, pad_size, 0), "constant", 0) + v_cache_padded = F.pad(v_cache_trimmed, (0, 0, pad_size, 0), "constant", 0) + else: + k_cache_padded = k_cache + v_cache_padded = v_cache + + kv_cache_k_list.append(k_cache_padded) + kv_cache_v_list.append(v_cache_padded) + + # Stack the caches along a new dimension and remove any extra dimensions + self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1) + self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1) + + # Update the cache size to the maximum size + self.keys_values_wm._keys_values[layer]._k_cache._size = max_size + self.keys_values_wm._keys_values[layer]._v_cache._size = max_size + + return self.keys_values_wm_size_list + + def update_cache_context(self, latent_state, is_init_infer=True, simulation_index=0, + search_depth=[], valid_context_lengths=None): + """ + Update the cache context with the given latent state. + + Arguments: + - latent_state (:obj:`torch.Tensor`): The latent state tensor. + - is_init_infer (:obj:`bool`): Flag to indicate if this is the initial inference. + - simulation_index (:obj:`int`): Index of the simulation. + - search_depth (:obj:`list`): List of depth indices in the search tree. + - valid_context_lengths (:obj:`list`): List of valid context lengths. + """ + if self.context_length <= 2: + # No context to update if the context length is less than or equal to 2. 
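+            # A two-token context holds at most a single (observation, action) pair, so there is no
+            # reusable history worth caching.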
+ return + for i in range(latent_state.size(0)): + # ============ Iterate over each environment ============ + cache_key = hash_state(latent_state[i].view(-1).cpu().numpy()) # latent_state[i] is torch.Tensor + context_length = self.context_length + + if not is_init_infer: + + + # ============ Internal Node ============ + # Retrieve KV from global KV cache self.keys_values_wm to single environment KV cache self.keys_values_wm_single_env, ensuring correct positional encoding + current_max_context_length = max(self.keys_values_wm_size_list_current) + trim_size = current_max_context_length - self.keys_values_wm_size_list_current[i] + for layer in range(self.num_layers): + # ============ Apply trimming and padding to each layer of kv_cache ============ + # cache shape [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm._keys_values[layer]._k_cache._cache[i] + v_cache_current = self.keys_values_wm._keys_values[layer]._v_cache._cache[i] + + if trim_size > 0: + # Trim invalid leading zeros as per effective length + # Remove the first trim_size zero kv items + k_cache_trimmed = k_cache_current[:, trim_size:, :] + v_cache_trimmed = v_cache_current[:, trim_size:, :] + # If effective length < current_max_context_length, pad the end of cache with 'trim_size' zeros + k_cache_padded = F.pad(k_cache_trimmed, (0, 0, 0, trim_size), "constant", + 0) # Pad with 'trim_size' zeros at end of cache + v_cache_padded = F.pad(v_cache_trimmed, (0, 0, 0, trim_size), "constant", 0) + else: + k_cache_padded = k_cache_current + v_cache_padded = v_cache_current + + # Update cache of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + # Update size of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = \ + self.keys_values_wm_size_list_current[i] + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = \ + self.keys_values_wm_size_list_current[i] + + # ============ NOTE: Very Important ============ + if self.keys_values_wm_single_env._keys_values[layer]._k_cache._size >= context_length - 1: + # Keep only the last self.context_length-3 timesteps of context + # For memory environments, training is for H steps, recurrent_inference might exceed H steps + # Assuming cache dimension is [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache + v_cache_current = self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache + + # Remove the first 2 steps, keep the last self.context_length-3 steps + k_cache_trimmed = k_cache_current[:, :, 2:context_length - 1, :].squeeze(0) + v_cache_trimmed = v_cache_current[:, :, 2:context_length - 1, :].squeeze(0) + + if not self.config.rotary_emb: + # Index pre-computed positional encoding differences + pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)] + pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)] + # ============ NOTE: Very Important ============ + # Apply positional encoding correction to k and v + k_cache_trimmed += pos_emb_diff_k.squeeze(0) + v_cache_trimmed += pos_emb_diff_v.squeeze(0) + + # Pad the last 3 steps along the third dimension with zeros + # F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). 
For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right). + padding_size = (0, 0, 0, 3) + k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0) + v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0) + # Update single environment cache + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 + + else: + # ============ Root Node ============ + # Retrieve KV from global KV cache self.keys_values_wm to single environment KV cache self.keys_values_wm_single_env, ensuring correct positional encoding + + for layer in range(self.num_layers): + # ============ Apply trimming and padding to each layer of kv_cache ============ + + if self.keys_values_wm._keys_values[layer]._k_cache._size < context_length - 1: # Keep only the last self.context_length-1 timesteps of context + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = \ + self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze( + 0) # Shape torch.Size([2, 100, 512]) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = \ + self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = \ + self.keys_values_wm._keys_values[layer]._k_cache._size + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = \ + self.keys_values_wm._keys_values[layer]._v_cache._size + else: + # Assuming cache dimension is [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm._keys_values[layer]._k_cache._cache[i] + v_cache_current = self.keys_values_wm._keys_values[layer]._v_cache._cache[i] + + # Remove the first 2 steps, keep the last self.context_length-3 steps + k_cache_trimmed = k_cache_current[:, 2:context_length - 1, :] + v_cache_trimmed = v_cache_current[:, 2:context_length - 1, :] + + if not self.config.rotary_emb: + # Index pre-computed positional encoding differences + pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)] + pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)] + # ============ NOTE: Very Important ============ + # Apply positional encoding correction to k and v + k_cache_trimmed += pos_emb_diff_k.squeeze(0) + v_cache_trimmed += pos_emb_diff_v.squeeze(0) + + # Pad the last 3 steps along the third dimension with zeros + # F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right). 
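+                        # Shape note: k_cache_trimmed is (num_heads, context_length - 3, features) here, so
+                        # (0, 0, 0, 3) appends three zero timesteps along the sequence dimension; the size
+                        # recorded below (context_length - 3) marks the valid, non-padded prefix.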
+                        padding_size = (0, 0, 0, 3)
+                        k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0)
+                        v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0)
+                        # Update cache of self.keys_values_wm_single_env
+                        self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0)
+                        self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0)
+                        # Update size of self.keys_values_wm_single_env
+                        self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3
+                        self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3
+
+            if is_init_infer:
+                # TODO
+                # ==================== Proactive eviction fix ====================
+                # 1. Get the physical index that is about to be overwritten
+                index_to_write = self.shared_pool_index_init_envs[i]
+                # 2. Use the helper list to look up the old key stored at that index
+                old_key_to_evict = self.pool_idx_to_key_map_init_envs[i][index_to_write]
+                # 3. If an old key exists, delete it from the main cache map
+                if old_key_to_evict is not None:
+                    # Make sure the key to delete actually exists, to avoid unexpected errors
+                    if old_key_to_evict in self.past_kv_cache_init_infer_envs[i]:
+                        del self.past_kv_cache_init_infer_envs[i][old_key_to_evict]
+
+                # It is now safe to write the new data
+                cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i)
+
+                # 4. Update the new mapping in both the main cache map and the helper list
+                self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index
+                self.pool_idx_to_key_map_init_envs[i][index_to_write] = cache_key
+
+                # Call the debug helper to check for stale pointers
+                self._debug_check_for_stale_pointers(env_id=i, current_key=cache_key, index_to_be_written=index_to_write)
+                # ============================================================
+
+                # Store the latest key-value cache for initial inference
+                # cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i)
+                # self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index
+            else:
+                # TODO: compute a unique identifier of the cache to be stored, e.g. the tensor sum
+                # cache_to_store = self.keys_values_wm_single_env._keys_values[0]._k_cache._cache
+                # cache_sum = torch.sum(cache_to_store).item()
+                # cache_shape = cache_to_store.shape
+                # print(f"[CACHE WRITE] Storing for key={cache_key}, cache_shape={cache_shape}, cache_sum={cache_sum:.4f}")
+
+                # ==================== RECURRENT INFER FIX ====================
+                # 1. Get the physical index that is about to be overwritten
+                index_to_write = self.shared_pool_index
+                # 2. Use the helper list to look up the old key stored at that index
+                old_key_to_evict = self.pool_idx_to_key_map_recur_infer[index_to_write]
+                # 3. If an old key exists, delete it from the main cache map
+                if old_key_to_evict is not None:
+                    if old_key_to_evict in self.past_kv_cache_recurrent_infer:
+                        del self.past_kv_cache_recurrent_infer[old_key_to_evict]
+
+                # 4. It is now safe to write the new data
+                cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env)
+
+                # 5. Update the new mapping in both the main cache map and the helper list
+                self.past_kv_cache_recurrent_infer[cache_key] = cache_index
+                self.pool_idx_to_key_map_recur_infer[index_to_write] = cache_key
+                # ============================================================
+
+                # ==================== DEBUG CODE INSERTION ====================
+                # Call the debug helper to check for stale pointers
+                self._debug_check_for_stale_pointers_recur(current_key=cache_key, index_to_be_written=index_to_write)
+                # ============================================================
+
+                # Store the latest key-value cache for recurrent inference
+                # cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env)
+                # self.past_kv_cache_recurrent_infer[cache_key] = cache_index
+
+
+    def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int,
+                                     simulation_index: int = 0, start_pos: int = 0) -> list:
+        """
+        Retrieves or generates key-value caches for each environment based on the latent state.
+
+        For each environment, this method either retrieves a matching cache from the predefined
+        caches if available, or generates a new cache if no match is found. The method updates
+        the internal lists with these caches and their sizes.
+
+        Arguments:
+            - latent_state (:obj:`list`): List of latent states for each environment.
+            - ready_env_num (:obj:`int`): Number of environments ready for processing.
+            - simulation_index (:obj:`int`, optional): Index for simulation tracking. Default is 0.
+        Returns:
+            - list: Sizes of the key-value caches for each environment.
+        """
+        for index in range(ready_env_num):
+            self.recur_total_query_count += 1
+            state_single_env = latent_state[index]  # latent_state[i] is np.array
+            cache_key = hash_state(state_single_env)
+
+            if self.reanalyze_phase:
+                # TODO: check if this is correct
+                matched_value = None
+            else:
+                # Try to retrieve the cached value from past_kv_cache_init_infer_envs
+                cache_index = self.past_kv_cache_init_infer_envs[index].get(cache_key)
+                if cache_index is not None:
+                    matched_value = self.shared_pool_init_infer[index][cache_index]
+
+                    # TODO
+                    # retrieved_cache = matched_value._keys_values[0]._k_cache._cache
+                    # retrieved_sum = torch.sum(retrieved_cache).item()
+                    # retrieved_shape = retrieved_cache.shape
+                    # print(f"[CACHE HIT] Found for key={cache_key}, retrieved_shape={retrieved_shape}, retrieved_sum={retrieved_sum:.4f}")
+
+
+                else:
+                    matched_value = None
+
+                # If not found, try to retrieve from past_kv_cache_recurrent_infer
+                # if matched_value is None:
+                #     matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)]
+
+                # ==================== Core fix ====================
+                # Step 2: only when no match was found in init_infer, try the recurrent_infer cache
+                if matched_value is None:
+                    # 2.1 Safely fetch the index from the dict; it may return None
+                    recur_cache_index = self.past_kv_cache_recurrent_infer.get(cache_key)
+                    # 2.2 Only when the index is valid (not None), use it to retrieve the value from the physical pool
+                    if recur_cache_index is not None:
+                        matched_value = self.shared_pool_recur_infer[recur_cache_index]
+
+                    if recur_cache_index is None:
+                        print(f"[CACHE MISS] Not found for key={cache_key} in recurrent infer. Generating new cache.")
+
+                # =================================================
+                # # TODO
+                # retrieved_cache = matched_value._keys_values[0]._k_cache._cache
+                # retrieved_sum = torch.sum(retrieved_cache).item()
+                # retrieved_shape = retrieved_cache.shape
+                # print(f"[CACHE HIT] Found for key={cache_key}, retrieved_shape={retrieved_shape}, retrieved_sum={retrieved_sum:.4f}")
+
+
+            if matched_value is not None:
+                # If a matching cache is found, add it to the lists
+                self.recur_hit_count += 1
+                # Perform a deep copy because the transformer's forward pass might modify matched_value in-place
+                self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value))
+                self.keys_values_wm_size_list.append(matched_value.size)
+            else:
+                # print(f"[CACHE MISS] Not found for key={cache_key}. Generating new cache.")
+
+                # If no matching cache is found, generate a new one using zero reset
+                self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(
+                    n=1, max_tokens=self.context_length
+                )
+
+                # Determine the absolute start position based on the reanalyze phase flag.
+                if self.reanalyze_phase:
+                    num_rows, num_cols = start_pos.shape  # Original start_pos shape is (batch, num_columns)
+                    total_cols = num_cols + 1  # Each logical row is extended by one column.
+                    row_idx = index // total_cols
+                    col_idx = index % total_cols
+                    # If the column index equals the original number of columns, this indicates the added column; set to 0.
+                    start_pos_adjusted: int = 0 if col_idx == num_cols else int(start_pos[row_idx, col_idx])
+                else:
+                    start_pos_adjusted = int(start_pos[index].item())
+
+                self.forward(
+                    {'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)},
+                    past_keys_values=self.keys_values_wm_single_env, is_init_infer=True, start_pos=start_pos_adjusted
+                )
+                self.keys_values_wm_list.append(self.keys_values_wm_single_env)
+                self.keys_values_wm_size_list.append(1)
+
+        return self.keys_values_wm_size_list
+
+
+    def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar_transform_handle=None,
+                     **kwargs: Any) -> LossWithIntermediateLosses:
+        start_pos = batch['timestep']
+        # Encode observations into latent state representations
+        obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'])
+
+
+        # ======================= Inserted analysis code =======================
+        # Fetch the global step from kwargs, assuming it is passed in from the training loop
+        global_step = kwargs.get('global_step', 0)
+        # To avoid slowing down training, limit how often the analysis runs
+        if global_step % 10 == 0:  # analyze once every 10 training steps
+            self._analyze_latent_representation(
+                latent_states=obs_embeddings,
+                timesteps=batch['timestep'],
+                game_states=batch['observations'],  # pass in the raw images
+                step_counter=global_step
+            )
+        # =================================================================
+
+        # ========= for visual analysis =========
+        # Uncomment the lines below for visual analysis in Pong
+        # self.plot_latent_tsne_each_and_all_for_pong(obs_embeddings, suffix='pong_H10_H4_tsne')
+        # self.save_as_image_with_timestep(batch['observations'], suffix='pong_H10_H4_tsne')
+        # Uncomment the lines below for visual analysis in visual match
+        # self.plot_latent_tsne_each_and_all(obs_embeddings, suffix='visual_match_memlen1-60-15_tsne')
+        # self.save_as_image_with_timestep(batch['observations'], suffix='visual_match_memlen1-60-15_tsne')
+
+
+        # ========= logging for analysis =========
+        if self.analysis_dormant_ratio:
+            # Calculate dormant ratio of the encoder
+            shape = batch['observations'].shape  # (..., C, H, W)
+            inputs = batch['observations'].contiguous().view(-1, *shape[-3:])  # (32,5,3,64,64) -> (160,3,64,64)
+            dormant_ratio_encoder = cal_dormant_ratio(self.tokenizer.representation_network, inputs.detach(),
+                                                      percentage=self.dormant_threshold)
+            self.past_kv_cache_recurrent_infer.clear()
+            self.keys_values_wm_list.clear()
+            torch.cuda.empty_cache()
+        else:
+            dormant_ratio_encoder = torch.tensor(0.)
+
+        # Action tokens
+        if self.continuous_action_space:
+            act_tokens = batch['actions']
+        else:
+            act_tokens = rearrange(batch['actions'], 'b l -> b l 1')
+
+        with torch.no_grad():
+            # Calculate the L2 norm of the latent state roots
+            latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean()
+            # Calculate the L2 norm of the latent action
+            latent_action_l2_norms = torch.norm(self.act_embedding_table(act_tokens), p=2, dim=2).mean()
+
+        if self.config.latent_norm_loss:
+            # ==================== L2 penalty on the latent norm (final fixed version v2) ====================
+            # 1. Compute the squared L2 norm of each latent_state vector.
+            #    From debugging, obs_embeddings has shape (B*L, 1, E),
+            #    so latent_norm_sq has shape (B*L, 1).
+            latent_norm_sq = torch.norm(obs_embeddings, p=2, dim=-1).pow(2)
+            # 2. Get the source mask. From debugging, mask_source has shape (B, L).
+            mask_source = batch['mask_padding']
+            # 3. Reshape the source mask from (B, L) to (B*L, 1) to match latent_norm_sq.
+            #    This is the key step that resolves the shape-mismatch error; view(-1, 1) performs the reshape.
+            correct_mask = mask_source.contiguous().view(-1, 1)
+            # 4. Defensive check that the leading dimensions of the two tensors match after reshaping.
+            if latent_norm_sq.shape[0] != correct_mask.shape[0]:
+                # If the shapes do not match, raise an exception with details to make future issues easier to locate.
+                raise RuntimeError(
+                    f"Shape mismatch for L2 norm loss calculation! "
+                    f"latent_norm_sq shape: {latent_norm_sq.shape}, "
+                    f"but correct_mask shape after reshape is: {correct_mask.shape}. "
+                    f"Original mask_source shape was: {mask_source.shape}"
+                )
+            # 5. Element-wise multiplication; both tensors are now (B*L, 1), so this is safe.
+            masked_latent_norm_sq = latent_norm_sq * correct_mask
+            # 6. Average over the valid entries: the denominator is the number of ones in the mask,
+            #    with a small epsilon (1e-8) to avoid division by zero.
+            latent_norm_loss = masked_latent_norm_sq.sum() / (correct_mask.sum() + 1e-8)
+            # =================================================================
+        else:
+            latent_norm_loss = torch.tensor(0.)
+
+
+        # Forward pass to obtain predictions for observations, rewards, and policies
+        outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, start_pos=start_pos)
+
+        if self.config.use_priority:
+            # ==================== START MODIFICATION 5 ====================
+            # Calculate value_priority, similar to MuZero.
+            with torch.no_grad():
+                # 1. Get the predicted value logits for the first step of the sequence (t=0).
+                #    The shape is (B, support_size).
+                predicted_value_logits_step0 = outputs.logits_value[:, 0, :]
+
+                # 2. Convert the categorical prediction to a scalar value.
+                #    The shape becomes (B, 1).
+                predicted_scalar_value_step0 = inverse_scalar_transform_handle(predicted_value_logits_step0)
+
+                # 3. Get the target scalar value for the first step from the batch.
+                #    The shape is (B, num_unroll_steps), so we take the first column.
+                target_scalar_value_step0 = batch['scalar_target_value'][:, 0]
+
+                # 4. Calculate the L1 loss (absolute difference) between prediction and target.
+                #    This is the priority. We use reduction='none' to get per-sample priorities.
+                value_priority = F.l1_loss(predicted_scalar_value_step0.squeeze(-1), target_scalar_value_step0, reduction='none')
+            # ===================== END MODIFICATION 5 =====================
+        else:
+            value_priority = torch.tensor(0.)
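+
+        # The branches below pick reconstruction and perceptual losses per observation type; for the
+        # 'image' and 'image_memory' cases the decoder-based reconstruction is left commented out, so the
+        # cached placeholder losses (self.latent_recon_loss / self.perceptual_loss) are reused instead.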
+ + if self.obs_type == 'image': + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + + # ========== for visualization ========== + # Uncomment the lines below for visual analysis + # original_images, reconstructed_images = batch['observations'], reconstructed_images + # target_policy = batch['target_policy'] + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ========== for visualization ========== + + # ========== Calculate reconstruction loss and perceptual loss ============ + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + # perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + + latent_recon_loss = self.latent_recon_loss + perceptual_loss = self.perceptual_loss + + elif self.obs_type == 'vector': + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings.reshape(-1, self.embed_dim)) + + # # Calculate reconstruction loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 25), + # reconstructed_images) + latent_recon_loss = self.latent_recon_loss + + elif self.obs_type == 'text': + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=torch.float32) + decode_loss_mode = self.config.decode_loss_mode + + # Reconstruction loss for predicting the next latent (via backbone) + # input -> encoder -> backbone(unizero) -> decoder -> latent_recon_loss + if decode_loss_mode == "after_backbone": + next_latent_state = outputs.logits_observations[:, :-1, :] + next_target_ids = batch['observations'][:, 1:, :] + + latent_recon_loss = self.tokenizer.decode_to_reconstruction_outputs( + embeddings=next_latent_state, + target_ids=next_target_ids, + ).loss + + #Reconstruction loss for predicting the current latent (without using the backbone) + # input -> encoder -> decoder -> latent_recon_loss + elif decode_loss_mode == "before_backbone": + latent_recon_loss = self.tokenizer.decode_to_reconstruction_outputs( + embeddings=obs_embeddings, + target_ids=batch['observations'], + ).loss + + else: + latent_recon_loss = self.latent_recon_loss + + elif self.obs_type == 'image_memory': + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + # original_images, reconstructed_images = batch['observations'], reconstructed_images + + # ========== for visualization ========== + # Uncomment the lines below for visual analysis + # target_policy = batch['target_policy'] + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ========== for visualization 
========== + + # Calculate reconstruction loss and perceptual loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 5, 5), + # reconstructed_images) + latent_recon_loss = self.latent_recon_loss + perceptual_loss = self.perceptual_loss + + # ========= logging for analysis ========= + if self.analysis_dormant_ratio: + # Calculate dormant ratio of the world model + dormant_ratio_world_model = cal_dormant_ratio(self, { + 'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())}, + percentage=self.dormant_threshold) + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + torch.cuda.empty_cache() + else: + dormant_ratio_world_model = torch.tensor(0.) + + # ========== for visualization ========== + # Uncomment the lines below for visualization + # predict_policy = outputs.logits_policy + # predict_policy = F.softmax(outputs.logits_policy, dim=-1) + # predict_value = inverse_scalar_transform_handle(outputs.logits_value.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) + # predict_rewards = inverse_scalar_transform_handle(outputs.logits_rewards.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) + # import pdb; pdb.set_trace() + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=[], suffix='pong_H10_H4_0613') + + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_success_episode') + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_fail_episode') + # ========== for visualization ========== + + with torch.no_grad(): + # For training stability, use target_tokenizer to compute the true next latent state representations + target_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations']) + # target_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations']) + + # Compute labels for observations, rewards, and ends + labels_observations, labels_rewards, _ = self.compute_labels_world_model(target_obs_embeddings, + batch['rewards'], + batch['ends'], + batch['mask_padding']) + + # Reshape the logits and labels for observations + logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') + labels_observations = labels_observations.reshape(-1, self.projection_input_dim) + + # Compute prediction loss for observations. 
Options: MSE and Group KL + if self.predict_latent_loss_type == 'mse': + # MSE loss, directly compare logits and labels + loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations, reduction='none').mean( + -1) + elif self.predict_latent_loss_type == 'group_kl': + # Group KL loss, group features and calculate KL divergence within each group + batch_size, num_features = logits_observations.shape + epsilon = 1e-6 + logits_reshaped = logits_observations.reshape(batch_size, self.num_groups, self.group_size) + epsilon + labels_reshaped = labels_observations.reshape(batch_size, self.num_groups, self.group_size) + epsilon + + loss_obs = F.kl_div(logits_reshaped.log(), labels_reshaped, reduction='none').sum(dim=-1).mean(dim=-1) + + # ========== for debugging ========== + # print('loss_obs:', loss_obs.mean()) + # assert not torch.isnan(loss_obs).any(), "loss_obs contains NaN values" + # assert not torch.isinf(loss_obs).any(), "loss_obs contains Inf values" + # for name, param in self.tokenizer.encoder.named_parameters(): + # print('name, param.mean(), param.std():', name, param.mean(), param.std()) + + # Apply mask to loss_obs + mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) + loss_obs = (loss_obs * mask_padding_expanded) + + # Compute labels for policy and value + labels_value, labels_policy = self.compute_labels_world_model_value_policy(batch['target_value'], + batch['target_policy'], + batch['mask_padding']) + + # Compute losses for rewards, policy, and value + loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') + + if not self.continuous_action_space: + loss_policy, orig_policy_loss, policy_entropy = self.compute_cross_entropy_loss(outputs, labels_policy, + batch, + element='policy') + else: + # NOTE: for continuous action space + if self.config.policy_loss_type == 'simple': + orig_policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma = self._calculate_policy_loss_cont_simple(outputs, batch) + else: + orig_policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma = self._calculate_policy_loss_cont(outputs, batch) + + loss_policy = orig_policy_loss + self.policy_entropy_weight * policy_entropy_loss + policy_entropy = - policy_entropy_loss + + loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') + + # ==== TODO: calculate the new priorities for each transition. 
==== + # value_priority = L1Loss(reduction='none')(labels_value.squeeze(-1), outputs['logits_value'][:, 0]) + # value_priority = value_priority.data.cpu().numpy() + 1e-6 + + # Compute timesteps + timesteps = torch.arange(batch['actions'].shape[1], device=batch['actions'].device) + # Compute discount coefficients for each timestep + discounts = self.gamma ** timesteps + + if batch['mask_padding'].sum() == 0: + assert False, "mask_padding is all zeros" + + # Group losses into first step, middle step, and last step + first_step_losses = {} + middle_step_losses = {} + last_step_losses = {} + # batch['mask_padding'] indicates mask status for future H steps, exclude masked losses to maintain accurate mean statistics + # Group losses for each loss item + for loss_name, loss_tmp in zip( + ['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy', 'orig_policy_loss', 'policy_entropy'], + [loss_obs, loss_rewards, loss_value, loss_policy, orig_policy_loss, policy_entropy] + ): + if loss_name == 'loss_obs': + seq_len = batch['actions'].shape[1] - 1 + # Get the corresponding mask_padding + mask_padding = batch['mask_padding'][:, 1:seq_len] + else: + seq_len = batch['actions'].shape[1] + # Get the corresponding mask_padding + mask_padding = batch['mask_padding'][:, :seq_len] + + # Adjust loss shape to (batch_size, seq_len) + loss_tmp = loss_tmp.view(-1, seq_len) + + # First step loss + first_step_mask = mask_padding[:, 0] + first_step_losses[loss_name] = loss_tmp[:, 0][first_step_mask].mean() + + # Middle step loss + middle_timestep = seq_len // 2 + middle_step_mask = mask_padding[:, middle_timestep] + middle_step_losses[loss_name] = loss_tmp[:, middle_timestep][middle_step_mask].mean() + + # Last step loss + last_step_mask = mask_padding[:, -1] + last_step_losses[loss_name] = loss_tmp[:, -1][last_step_mask].mean() + + # Discount reconstruction loss and perceptual loss + discounted_latent_recon_loss = latent_recon_loss + discounted_perceptual_loss = perceptual_loss + + # Calculate overall discounted loss + discounted_loss_obs = (loss_obs.view(-1, batch['actions'].shape[1] - 1) * discounts[1:]).sum()/ batch['mask_padding'][:,1:].sum() + discounted_loss_rewards = (loss_rewards.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_loss_value = (loss_value.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_loss_policy = (loss_policy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_orig_policy_loss = (orig_policy_loss.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_policy_entropy = (policy_entropy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + + if self.continuous_action_space: + return LossWithIntermediateLosses( + latent_recon_loss_weight=self.latent_recon_loss_weight, + perceptual_loss_weight=self.perceptual_loss_weight, + continuous_action_space=True, + loss_obs=discounted_loss_obs, + loss_rewards=discounted_loss_rewards, + loss_value=discounted_loss_value, + loss_policy=discounted_loss_policy, + latent_recon_loss=discounted_latent_recon_loss, + perceptual_loss=discounted_perceptual_loss, + orig_policy_loss=discounted_orig_policy_loss, + policy_entropy=discounted_policy_entropy, + first_step_losses=first_step_losses, + middle_step_losses=middle_step_losses, + last_step_losses=last_step_losses, + dormant_ratio_encoder=dormant_ratio_encoder, + 
dormant_ratio_world_model=dormant_ratio_world_model, + latent_state_l2_norms=latent_state_l2_norms, + latent_action_l2_norms=latent_action_l2_norms, + policy_mu=mu, + policy_sigma=sigma, + target_sampled_actions=target_sampled_actions, + latent_norm_loss=latent_norm_loss, # 新增 + value_priority=value_priority, + ) + else: + return LossWithIntermediateLosses( + latent_recon_loss_weight=self.latent_recon_loss_weight, + perceptual_loss_weight=self.perceptual_loss_weight, + continuous_action_space=False, + loss_obs=discounted_loss_obs, + loss_rewards=discounted_loss_rewards, + loss_value=discounted_loss_value, + loss_policy=discounted_loss_policy, + latent_recon_loss=discounted_latent_recon_loss, + perceptual_loss=discounted_perceptual_loss, + orig_policy_loss=discounted_orig_policy_loss, + policy_entropy=discounted_policy_entropy, + first_step_losses=first_step_losses, + middle_step_losses=middle_step_losses, + last_step_losses=last_step_losses, + dormant_ratio_encoder=dormant_ratio_encoder, + dormant_ratio_world_model=dormant_ratio_world_model, + latent_state_l2_norms=latent_state_l2_norms, + latent_action_l2_norms=latent_action_l2_norms, + latent_norm_loss=latent_norm_loss, # 新增 + value_priority=value_priority, + + ) + + + # TODO: test correctness + def _calculate_policy_loss_cont_simple(self, outputs, batch: dict): + """ + Simplified policy loss calculation for continuous actions. + + Args: + - outputs: Model outputs containing policy logits. + - batch (:obj:`dict`): Batch data containing target policy, mask and sampled actions. + + Returns: + - policy_loss (:obj:`torch.Tensor`): The simplified policy loss. + """ + batch_size, num_unroll_steps, action_space_size = outputs.logits_policy.shape[ + 0], self.config.num_unroll_steps, self.config.action_space_size + + # Get the policy logits and batch data + policy_logits_all = outputs.logits_policy + mask_batch = batch['mask_padding'].contiguous().view(-1) + target_policy = batch['target_policy'].contiguous().view(batch_size * num_unroll_steps, -1) + target_sampled_actions = batch['child_sampled_actions'].contiguous().view(batch_size * num_unroll_steps, -1, action_space_size) + + # Flatten for vectorized computation + policy_logits_all = policy_logits_all.view(batch_size * num_unroll_steps, -1) + + # Extract mean and standard deviation from logits + mu, sigma = policy_logits_all[:, :action_space_size], policy_logits_all[:, action_space_size:] + dist = Independent(Normal(mu, sigma), 1) # Create the normal distribution + + # Find the indices of the maximum values in the target policy + target_best_action_idx = torch.argmax(target_policy, dim=1) + + # Select the best actions based on the indices + target_best_action = target_sampled_actions[torch.arange(target_best_action_idx.size(0)), target_best_action_idx] + + # Clip the target actions to prevent numerical issues during arctanh + # target_best_action_clamped = torch.clamp(target_best_action, -1 + 1e-6, 1 - 1e-6) + target_best_action_clamped = torch.clamp(target_best_action, -0.999, 0.999) + target_best_action_before_tanh = torch.arctanh(target_best_action_clamped) + + # Calculate the log probability of the best action + log_prob_best_action = dist.log_prob(target_best_action_before_tanh) + + # Mask the log probability with the padding mask + log_prob_best_action = log_prob_best_action * mask_batch + + # Return the negative log probability as the policy loss (we want to maximize log_prob) + # policy_loss = -log_prob_best_action.mean() + policy_loss = -log_prob_best_action + + policy_entropy = 
dist.entropy().mean() + policy_entropy_loss = -policy_entropy * mask_batch + # Calculate the entropy of the target policy distribution + non_masked_indices = torch.nonzero(mask_batch).squeeze(-1) + if len(non_masked_indices) > 0: + target_normalized_visit_count = target_policy.contiguous().view(batch_size * num_unroll_steps, -1) + target_dist = Categorical(target_normalized_visit_count[non_masked_indices]) + target_policy_entropy = target_dist.entropy().mean().item() + else: + target_policy_entropy = 0.0 + + return policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma + + def _calculate_policy_loss_cont(self, outputs, batch: dict) -> Tuple[torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Calculate the policy loss for continuous actions. + + Args: + - outputs: Model outputs containing policy logits. + - batch (:obj:`dict`): Batch data containing target policy, mask and sampled actions. + Returns: + - policy_loss (:obj:`torch.Tensor`): The calculated policy loss. + - policy_entropy_loss (:obj:`torch.Tensor`): The entropy loss of the policy. + - target_policy_entropy (:obj:`float`): The entropy of the target policy distribution. + - target_sampled_actions (:obj:`torch.Tensor`): The actions sampled from the target policy. + - mu (:obj:`torch.Tensor`): The mean of the normal distribution. + - sigma (:obj:`torch.Tensor`): The standard deviation of the normal distribution. + """ + batch_size, num_unroll_steps, action_space_size = outputs.logits_policy.shape[ + 0], self.config.num_unroll_steps, self.config.action_space_size + + policy_logits_all = outputs.logits_policy + mask_batch = batch['mask_padding'] + child_sampled_actions_batch = batch['child_sampled_actions'] + target_policy = batch['target_policy'] + + # Flatten the unroll step dimension for easier vectorized operations + policy_logits_all = policy_logits_all.view(batch_size * num_unroll_steps, -1) + mask_batch = mask_batch.contiguous().view(-1) + child_sampled_actions_batch = child_sampled_actions_batch.contiguous().view(batch_size * num_unroll_steps, -1, + action_space_size) + + mu, sigma = policy_logits_all[:, :action_space_size], policy_logits_all[:, action_space_size:] + mu = mu.unsqueeze(1).expand(-1, child_sampled_actions_batch.shape[1], -1) + sigma = sigma.unsqueeze(1).expand(-1, child_sampled_actions_batch.shape[1], -1) + dist = Independent(Normal(mu, sigma), 1) + + target_normalized_visit_count = target_policy.contiguous().view(batch_size * num_unroll_steps, -1) + target_sampled_actions = child_sampled_actions_batch + + policy_entropy = dist.entropy().mean(dim=1) + policy_entropy_loss = -policy_entropy * mask_batch + + # NOTE: Alternative way to calculate the log probability of the target actions + # y = 1 - target_sampled_actions.pow(2) + # target_sampled_actions_clamped = torch.clamp(target_sampled_actions, -1 + 1e-6, 1 - 1e-6) + # target_sampled_actions_before_tanh = torch.arctanh(target_sampled_actions_clamped) + # log_prob = dist.log_prob(target_sampled_actions_before_tanh) + # log_prob = log_prob - torch.log(y + 1e-6).sum(-1) + # log_prob_sampled_actions = log_prob + + base_dist = Normal(mu, sigma) + tanh_transform = TanhTransform() + dist = TransformedDistribution(base_dist, [tanh_transform]) + dist = Independent(dist, 1) + target_sampled_actions_clamped = torch.clamp(target_sampled_actions, -0.999, 0.999) + # assert torch.all(target_sampled_actions_clamped < 1) and torch.all(target_sampled_actions_clamped > -1), "Actions are not properly clamped." 
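+        # Clamping to +/-0.999 keeps the target actions strictly inside the open support of the tanh
+        # transform; at |a| = 1 the inverse (atanh) used by TransformedDistribution.log_prob diverges.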
+ log_prob = dist.log_prob(target_sampled_actions_clamped) + log_prob_sampled_actions = log_prob + + # KL as projector + target_log_prob_sampled_actions = torch.log(target_normalized_visit_count + 1e-6) + policy_loss = -torch.sum( + torch.exp(target_log_prob_sampled_actions.detach()) * log_prob_sampled_actions, 1 + ) * mask_batch + + # Calculate the entropy of the target policy distribution + non_masked_indices = torch.nonzero(mask_batch).squeeze(-1) + if len(non_masked_indices) > 0: + target_dist = Categorical(target_normalized_visit_count[non_masked_indices]) + target_policy_entropy = target_dist.entropy().mean().item() + else: + target_policy_entropy = 0.0 + + return policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma + + def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): + # Assume outputs is an object with logits attributes like 'rewards', 'policy', and 'value'. + # labels is a target tensor for comparison. batch is a dictionary with a mask indicating valid timesteps. + + logits = getattr(outputs, f'logits_{element}') + + if torch.isnan(logits).any(): + raise ValueError(f"NaN detected in outputs for batch {batch} and element '{element}'") + + if torch.isnan(labels).any(): + raise ValueError(f"NaN detected in labels_value for batch {batch} and element '{element}'") + + # Reshape your tensors + logits = rearrange(logits, 'b t e -> (b t) e') + labels = labels.reshape(-1, labels.shape[-1]) # Assume labels initially have shape [batch, time, dim] + + # Reshape your mask. True indicates valid data. + mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') + + # Compute cross-entropy loss + loss = -(torch.log_softmax(logits, dim=1) * labels).sum(1) + loss = (loss * mask_padding) + + if torch.isnan(loss).any(): + raise ValueError(f"NaN detected in outputs for batch {batch} and element '{element}'") + + if element == 'policy': + # Compute policy entropy loss + policy_entropy = self.compute_policy_entropy_loss(logits, mask_padding) + # Combine losses with specified weight + # print(f"self.policy_entropy_weight:{self.policy_entropy_weight}") + combined_loss = loss - self.policy_entropy_weight * policy_entropy + return combined_loss, loss, policy_entropy + + return loss + + def compute_policy_entropy_loss(self, logits, mask): + # Compute entropy of the policy + probs = torch.softmax(logits, dim=1) + log_probs = torch.log_softmax(logits, dim=1) + entropy = -(probs * log_probs).sum(1) + # Apply mask and return average entropy loss + entropy_loss = (entropy * mask) + return entropy_loss + + def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, + mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # assert torch.all(ends.sum(dim=1) <= 1) # Each sequence sample should have at most one 'done' flag + mask_fill = torch.logical_not(mask_padding) + + # Prepare observation labels + labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] + + # Fill the masked areas of rewards + mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards) + labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) + + # Fill the masked areas of ends + # labels_endgs = ends.masked_fill(mask_fill, -100) + + # return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) + return labels_observations, labels_rewards.view(-1, self.support_size), None + + + def 
compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor,
+                                                    mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """ Compute labels for value and policy predictions. """
+        mask_fill = torch.logical_not(mask_padding)
+
+        # Fill the masked areas of policy
+        mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy)
+        labels_policy = target_policy.masked_fill(mask_fill_policy, -100)
+
+        # Fill the masked areas of value
+        mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value)
+        labels_value = target_value.masked_fill(mask_fill_value, -100)
+
+        if self.continuous_action_space:
+            return labels_value.reshape(-1, self.support_size), None
+        else:
+            return labels_value.reshape(-1, self.support_size), labels_policy.reshape(-1, self.action_space_size)
+
+    def clear_caches(self):
+        """
+        Clears the caches of the world model.
+        """
+        for kv_cache_dict_env in self.past_kv_cache_init_infer_envs:
+            kv_cache_dict_env.clear()
+        self.past_kv_cache_recurrent_infer.clear()
+        self.keys_values_wm_list.clear()
+        print(f'Cleared {self.__class__.__name__} past_kv_cache.')
+
+    def __repr__(self) -> str:
+        return "transformer-based latent world_model of UniZero"
diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py
index c0fb1536a..72e8e3e4b 100644
--- a/lzero/policy/efficientzero.py
+++ b/lzero/policy/efficientzero.py
@@ -134,6 +134,10 @@ class EfficientZeroPolicy(MuZeroPolicy):
         n_episode=8,
         # (float) the number of simulations in MCTS.
         num_simulations=50,
+        # (int) The number of simulations in MCTS for the collect phase.
+        collect_num_simulations=25,
+        # (int) The number of simulations in MCTS for the eval phase.
+        eval_num_simulations=50,
         # (float) Discount factor (gamma) for returns.
         discount_factor=0.997,
         # (int) The number of steps for calculating target q_value.
@@ -471,7 +475,8 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
         # weighted loss with masks (some invalid states which are out of trajectory.)
         loss = (
                 self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss +
-                self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * value_prefix_loss
+                self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * value_prefix_loss +
+                self._cfg.policy_entropy_weight * (-1) * policy_entropy
         )
         weighted_total_loss = (weights * loss).mean()
         # TODO(pu): test the effect of gradient scale.
@@ -529,10 +534,14 @@ def _init_collect(self) -> None:
             Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils.
         """
         self._collect_model = self._model
+        # Create a copy of the config for the collect-phase MCTS and set its own number of simulations
+        mcts_collect_cfg = copy.deepcopy(self._cfg)
+        mcts_collect_cfg.num_simulations = self._cfg.collect_num_simulations
+
         if self._cfg.mcts_ctree:
-            self._mcts_collect = MCTSCtree(self._cfg)
+            self._mcts_collect = MCTSCtree(mcts_collect_cfg)
         else:
-            self._mcts_collect = MCTSPtree(self._cfg)
+            self._mcts_collect = MCTSPtree(mcts_collect_cfg)
         self._collect_mcts_temperature = 1
         self.collect_epsilon = 0.0
@@ -662,10 +671,14 @@ def _init_eval(self) -> None:
             Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils.
         """
         self._eval_model = self._model
+        # Create a copy of the config for the eval-phase MCTS and set its own number of simulations
+        mcts_eval_cfg = copy.deepcopy(self._cfg)
+        mcts_eval_cfg.num_simulations = self._cfg.eval_num_simulations
+
         if self._cfg.mcts_ctree:
-            self._mcts_eval = MCTSCtree(self._cfg)
+            self._mcts_eval = MCTSCtree(mcts_eval_cfg)
         else:
-            self._mcts_eval = MCTSPtree(self._cfg)
+            self._mcts_eval = MCTSPtree(mcts_eval_cfg)
 
     def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: Union[int, List] = [-1],
                       ready_env_id: np.array = None, **kwargs):
         """
diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py
index 7bd2e8d2b..cdd9e76df 100644
--- a/lzero/policy/muzero.py
+++ b/lzero/policy/muzero.py
@@ -163,8 +163,12 @@ class MuZeroPolicy(Policy):
         n_episode=8,
         # (int) The number of num_segments in each collecting stage when use muzero_segment_collector.
         num_segments=8,
-        # (int) the number of simulations in MCTS.
+        # (int) The number of simulations in MCTS for the reanalyze phase.
         num_simulations=50,
+        # (int) The number of simulations in MCTS for the collect phase.
+        collect_num_simulations=25,
+        # (int) The number of simulations in MCTS for the eval phase.
+        eval_num_simulations=50,
         # (float) Discount factor (gamma) for returns.
         discount_factor=0.997,
         # (int) The number of steps for calculating target q_value.
@@ -678,10 +682,15 @@ def _init_collect(self) -> None:
             Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils.
         """
         self._collect_model = self._model
+
+        # Create a copy of the config for the collect-phase MCTS and set its own number of simulations
+        mcts_collect_cfg = copy.deepcopy(self._cfg)
+        mcts_collect_cfg.num_simulations = self._cfg.collect_num_simulations
+
         if self._cfg.mcts_ctree:
-            self._mcts_collect = MCTSCtree(self._cfg)
+            self._mcts_collect = MCTSCtree(mcts_collect_cfg)
         else:
-            self._mcts_collect = MCTSPtree(self._cfg)
+            self._mcts_collect = MCTSPtree(mcts_collect_cfg)
         self._collect_mcts_temperature = 1.
         self.collect_epsilon = 0.0
         if self._cfg.model.model_type == 'conv_context':
@@ -779,9 +788,15 @@ def _forward_collect(
             # normal collect
             # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents
             # the index within the legal action set, rather than the index in the entire action set.
+            # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
+            #     distributions, temperature=self._collect_mcts_temperature, deterministic=False
+            # )
+
+            # Collect buffer data: select actions deterministically (argmax of visit counts) =====================
             action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
-                distributions, temperature=self._collect_mcts_temperature, deterministic=False
+                distributions, temperature=self._collect_mcts_temperature, deterministic=True
             )
+
             # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set.
             action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
             output[env_id] = {
@@ -844,10 +859,16 @@ def _init_eval(self) -> None:
             Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils.
         """
         self._eval_model = self._model
+
+        # Create a copy of the config for the eval-phase MCTS and set its own number of simulations
+        mcts_eval_cfg = copy.deepcopy(self._cfg)
+        mcts_eval_cfg.num_simulations = self._cfg.eval_num_simulations
+
         if self._cfg.mcts_ctree:
-            self._mcts_eval = MCTSCtree(self._cfg)
+            self._mcts_eval = MCTSCtree(mcts_eval_cfg)
         else:
-            self._mcts_eval = MCTSPtree(self._cfg)
+            self._mcts_eval = MCTSPtree(mcts_eval_cfg)
+
         if self._cfg.model.model_type == 'conv_context':
             self.last_batch_obs = torch.zeros([3, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
             self.last_batch_action = [-1 for _ in range(3)]
diff --git a/lzero/policy/sampled_unizero.py b/lzero/policy/sampled_unizero.py
index ec7399fc6..817216c87 100644
--- a/lzero/policy/sampled_unizero.py
+++ b/lzero/policy/sampled_unizero.py
@@ -410,7 +410,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         # Prepare action batch and convert to torch tensor
         if self._cfg.model.continuous_action_space:
             action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(
-                -1)  # For discrete action space
+                -1)  # For continuous action space
         else:
             action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(
                 -1).long()  # For discrete action space
diff --git a/lzero/policy/scaling_transform.py b/lzero/policy/scaling_transform.py
index 19a852f56..bc78f71b2 100644
--- a/lzero/policy/scaling_transform.py
+++ b/lzero/policy/scaling_transform.py
@@ -110,6 +110,7 @@ def visit_count_temperature(
 def phi_transform(
         discrete_support: DiscreteSupport,
         x: torch.Tensor,
+        label_smoothing_eps: float = 0.  # <--- new label-smoothing parameter
 ) -> torch.Tensor:
     """
     Overview:
@@ -163,7 +164,16 @@ def phi_transform(
                          dtype=x.dtype, device=x.device)
     target.scatter_add_(-1, idx, prob)
-    return target
+
+    # --- 5. Apply label smoothing ---
+    if label_smoothing_eps > 0:
+        # Mix the original two-hot target with a uniform distribution over the support
+        smooth_target = (1.0 - label_smoothing_eps) * target + (label_smoothing_eps / size)
+        return smooth_target
+    else:
+        return target
+
+    # return target
 
 def cross_entropy_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py
index f2bfc48f9..b3973c929 100644
--- a/lzero/policy/unizero.py
+++ b/lzero/policy/unizero.py
@@ -16,6 +16,77 @@
     prepare_obs_stack_for_unizero
 from lzero.policy.muzero import MuZeroPolicy
 from .utils import configure_optimizers_nanogpt
+from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters
+import torch.nn.functional as F
+
+def scale_module_weights_vectorized(module: torch.nn.Module, scale_factor: float):
+    """
+    Efficiently scale all weights of a module using a vectorized operation.
+    """
+    if not (0.0 < scale_factor < 1.0):
+        return  # Do nothing if the scale factor is out of the valid range
+
+    # 1. Flatten all parameters of the module into a single vector
+    params_vec = parameters_to_vector(module.parameters())
+
+    # 2. Perform a single in-place multiplication on this vector
+    params_vec.data.mul_(scale_factor)
+
+    # 3. Copy the scaled vector back into the module's individual parameters
+    vector_to_parameters(params_vec, module.parameters())
+
+def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type, betas):
+    """
+    Configure an optimizer with differentiated learning rates for the UniZero model.
+    """
+    # 1. Collect the parameters that require special handling
+    param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
+
+    # 2. Split the parameters into three groups: the Transformer backbone, the Tokenizer, and the Heads
+    transformer_params = {pn: p for pn, p in param_dict.items() if 'transformer' in pn}
+    tokenizer_params = {pn: p for pn, p in param_dict.items() if 'tokenizer' in pn}
+
+    # The Head parameters are those that belong to neither the transformer nor the tokenizer
+    head_params = {
+        pn: p for pn, p in param_dict.items()
+        if 'transformer' not in pn and 'tokenizer' not in pn
+    }
+
+    # 3.
+    # Here we still use AdamW, but with a more sensible learning-rate setup.
+    optim_groups = [
+        {
+            'params': list(transformer_params.values()),
+            'lr': learning_rate,  # 1e-4
+            # 'lr': learning_rate * 0.2,  # use a smaller learning rate for the Transformer backbone, e.g. 1e-5
+            'weight_decay': weight_decay
+            # 'weight_decay': weight_decay * 5.0
+        },
+        {
+            'params': list(tokenizer_params.values()),
+            'lr': learning_rate,  # the tokenizer uses the base learning rate, e.g. 1e-4
+            # 'lr': learning_rate * 0.1,  # use a smaller learning rate for the encoder, e.g. 1e-5
+            # 'weight_decay': weight_decay * 5.0  # <-- 5x weight decay for the encoder: a strong regularizer
+            'weight_decay': weight_decay  # standard weight decay for the encoder
+        },
+        {
+            'params': list(head_params.values()),
+            'lr': learning_rate,  # the heads also use the base learning rate, e.g. 1e-4
+            # 'weight_decay': 0.0  # usually the head weights are not decayed
+            'weight_decay': weight_decay
+        }
+    ]
+
+    print("--- Optimizer Groups ---")
+    print(f"Transformer LR: {learning_rate}")
+    print(f"Tokenizer/Heads LR: {learning_rate}")
+
+    optimizer = torch.optim.AdamW(optim_groups, betas=betas)
+    return optimizer
 
 
 @POLICY_REGISTRY.register('unizero')
@@ -113,6 +184,8 @@ class UniZeroPolicy(MuZeroPolicy):
             perceptual_loss_weight=0.,
             # (float) The weight of the policy entropy loss.
             policy_entropy_weight=0,
+            final_norm_option_in_encoder="SimNorm",  # "SimNorm" pairs with "group_kl", "LayerNorm" pairs with "mse"
+            final_norm_option_in_obs_head="SimNorm",
             # (str) The type of loss for predicting latent variables. Options could be ['group_kl', 'mse'].
             predict_latent_loss_type='group_kl',
             # (str) The type of observation. Options are ['image', 'vector'].
@@ -132,6 +205,23 @@ class UniZeroPolicy(MuZeroPolicy):
                 max_seq_len=8192,
             ),
         ),
+        # (bool) Whether to enable the adaptive policy-entropy weight (alpha).
+        use_adaptive_entropy_weight=True,
+        # (float) Learning rate of the optimizer for the adaptive alpha.
+        adaptive_entropy_alpha_lr=1e-4,
+        # ==================== START: Encoder-Clip Annealing Config ====================
+        # (bool) Whether to anneal the encoder-clip value.
+        use_encoder_clip_annealing=True,
+        # (str) Annealing type. Options are 'linear' or 'cosine'.
+        encoder_clip_anneal_type='cosine',
+        # (float) Starting clip value of the annealing (looser, early in training).
+        encoder_clip_start_value=30.0,
+        # (float) Final clip value of the annealing (stricter, late in training).
+        encoder_clip_end_value=10.0,
+        # (int) Number of training iterations needed to anneal from the start value to the end value.
+        encoder_clip_anneal_steps=200000,  # e.g. the final value is reached after 200k iterations
+        # ===================== END: Encoder-Clip Annealing Config =====================
+
         # ****** common ******
         # (bool) whether to use rnd model.
         use_rnd_model=False,
@@ -198,6 +288,11 @@ class UniZeroPolicy(MuZeroPolicy):
         optim_type='AdamW',
         # (float) Learning rate for training policy network. Initial lr for manually decay schedule.
         learning_rate=0.0001,
+        # ==================== [New] Norm-monitoring frequency ====================
+        # Monitor the norms of the model parameters every this many training iterations. Set to 0 to disable.
+        monitor_norm_freq=5000,
+        # ============================================================
+
         # (int) Frequency of hard target network update.
         target_update_freq=100,
         # (int) Frequency of soft target network update.
@@ -214,8 +309,12 @@ class UniZeroPolicy(MuZeroPolicy):
         n_episode=8,
         # (int) The number of num_segments in each collecting stage when use muzero_segment_collector.
         num_segments=8,
-        # (int) the number of simulations in MCTS.
+        # (int) The number of simulations in MCTS for reanalyze.
         num_simulations=50,
+        # (int) The number of simulations in MCTS for the collect phase.
+        collect_num_simulations=25,
+        # (int) The number of simulations in MCTS for the eval phase.
+        eval_num_simulations=50,
         # (float) Discount factor (gamma) for returns.
         discount_factor=0.997,
         # (int) The number of steps for calculating target q_value.
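Note: the hunk above only declares the annealing knobs (start value, end value, number of steps, and a 'linear'/'cosine' type) that the encoder-clip threshold, and later the adaptive target-entropy ratio, consume inside _forward_learn. As a minimal, self-contained illustration of the intended schedule, here is a sketch under those assumptions; the helper name anneal_value and the demo values are illustrative and not part of this patch.

import numpy as np

def anneal_value(start: float, end: float, step: int, total_steps: int, mode: str = 'cosine') -> float:
    """Interpolate from `start` to `end` over `total_steps`, linearly or with a cosine schedule."""
    progress = min(1.0, step / total_steps)
    if mode == 'cosine':
        # Cosine factor decays smoothly from 1 (at step 0) to 0 (at total_steps).
        cosine_progress = 0.5 * (1.0 + np.cos(np.pi * progress))
        return end + (start - end) * cosine_progress
    # Linear schedule.
    return start * (1.0 - progress) + end * progress

# Example with the encoder-clip values defined above (30.0 -> 10.0 over 200k iterations):
for it in (0, 100_000, 200_000):
    print(it, anneal_value(30.0, 10.0, it, 200_000, mode='cosine'))
# prints 30.0 at iter 0, 20.0 at 100k, 10.0 at 200k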
@@ -300,24 +399,102 @@ def default_model(self) -> Tuple[str, List[str]]: """ return 'UniZeroModel', ['lzero.model.unizero_model'] + # ==================== [新增] 模型范数监控函数 ==================== + def _monitor_model_norms(self) -> Dict[str, float]: + """ + Overview: + 计算并返回模型关键组件(Encoder, Transformer, Heads)的参数矩阵范数。 + 此函数应在 torch.no_grad() 环境下调用,以提高效率。 + Returns: + - norm_metrics (:obj:`Dict[str, float]`): 包含所有范数指标的字典,用于日志记录。 + """ + world_model = self._learn_model.world_model + norm_metrics = {} + + # 定义要监控的模块组 + module_groups = { + 'encoder': world_model.tokenizer.encoder, + 'transformer': world_model.transformer, + 'head_value': world_model.head_value, + 'head_reward': world_model.head_rewards, + 'head_policy': world_model.head_policy, + } + + for group_name, group_module in module_groups.items(): + total_norm_sq = 0.0 + for param_name, param in group_module.named_parameters(): + if param.requires_grad: + # 计算单层参数的L2范数 + param_norm = param.data.norm(2).item() + # 替换点号,使其在TensorBoard中正确显示为层级 + log_name = f'norm/{group_name}/{param_name.replace(".", "/")}' + norm_metrics[log_name] = param_norm + total_norm_sq += param_norm ** 2 + + # 计算整个模块的总范数 + total_group_norm = np.sqrt(total_norm_sq) + norm_metrics[f'norm/{group_name}/_total_norm'] = total_group_norm + + return norm_metrics + # ================================================================= + + def _init_learn(self) -> None: """ Overview: Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. """ - # NOTE: nanoGPT optimizer - self._optimizer_world_model = configure_optimizers_nanogpt( - model=self._model.world_model, - learning_rate=self._cfg.learning_rate, - weight_decay=self._cfg.weight_decay, - device_type=self._cfg.device, - betas=(0.9, 0.95), - ) + if self._cfg.optim_type == 'SGD': + # --- 改为SGD优化器 --- + self._optimizer_world_model = torch.optim.SGD( + self._model.world_model.parameters(), + lr=self._cfg.learning_rate, # 初始学习率,在配置中设为 0.2 + momentum=self._cfg.momentum, # 在配置中设为 0.9 + weight_decay=self._cfg.weight_decay # 在配置中设为 1e-4 + ) + elif self._cfg.optim_type == 'AdamW': + # NOTE: nanoGPT optimizer + self._optimizer_world_model = configure_optimizers_nanogpt( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + elif self._cfg.optim_type == 'AdamW_mix_lr_wdecay': + self._optimizer_world_model = configure_optimizer_unizero( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, # 使用一个合理的AdamW基础学习率 + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + + if self._cfg.cos_lr_scheduler: from torch.optim.lr_scheduler import CosineAnnealingLR # TODO: check the total training steps - self.lr_scheduler = CosineAnnealingLR(self._optimizer_world_model, 1e5, eta_min=0, last_epoch=-1) + # self.lr_scheduler = CosineAnnealingLR(self._optimizer_world_model, 1e5, eta_min=0, last_epoch=-1) + total_iters = self._cfg.get('total_iterations', 500000) # 500k iter + # final_lr = self._cfg.get('final_learning_rate', 0.0) + final_lr = self._cfg.get('final_learning_rate', 1e-6) + + self.lr_scheduler = CosineAnnealingLR( + self._optimizer_world_model, + T_max=total_iters, + eta_min=final_lr + ) + print(f"CosineAnnealingLR enabled: T_max={total_iters}, eta_min={final_lr}") + + + if self._cfg.piecewise_decay_lr_scheduler: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr 
+ # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. + lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer_world_model, lr_lambda=lr_lambda) + # use model_wrapper for specialized demands of different modes self._target_model = copy.deepcopy(self._model) @@ -325,13 +502,22 @@ def _init_learn(self) -> None: assert int(''.join(filter(str.isdigit, torch.__version__))) >= 200, "We need torch version >= 2.0" self._model = torch.compile(self._model) self._target_model = torch.compile(self._target_model) - # NOTE: soft target - self._target_model = model_wrap( - self._target_model, - wrapper_name='target', - update_type='momentum', - update_kwargs={'theta': self._cfg.target_update_theta} - ) + if self._cfg.target_model_update_option=="soft": + # NOTE: soft target + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta} + ) + elif self._cfg.target_model_update_option=="hard": + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq} + ) + self._learn_model = self._model if self._cfg.use_augmentation: @@ -341,8 +527,8 @@ def _init_learn(self) -> None: ) self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) - assert self.value_support.size == self._learn_model.value_support_size # if these assertions fails, somebody introduced... - assert self.reward_support.size == self._learn_model.reward_support_size # ...incoherence between policy and model + # assert self.value_support.size == self._learn_model.value_support_size # if these assertions fails, somebody introduced... + # assert self.reward_support.size == self._learn_model.reward_support_size # ...incoherence between policy and model self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) @@ -357,6 +543,84 @@ def _init_learn(self) -> None: wandb.watch(self._learn_model.representation_network, log="all") self.accumulation_steps = self._cfg.accumulation_steps + # 从配置中获取阈值,如果未设置则使用一个合理的默认值(例如20.0) + # 设置为0或负数则禁用此功能 + self.latent_norm_clip_threshold = self._cfg.get('latent_norm_clip_threshold', 20.0) # TODO + + + # 从配置中获取阈值,例如 15.0 或 20.0 + self.logit_clip_threshold = self._cfg.get('logit_clip_threshold', 10.0) + # 1. 获取 world_model 的引用,方便后续操作 + world_model = self._learn_model.world_model + # 2. 将参数明确地分为两组:预测头 (heads) 和 主干网络 (backbone) + # - a. 获取所有预测头的参数 + self.head_params = list(world_model.head_value.parameters()) + \ + list(world_model.head_rewards.parameters()) + \ + list(world_model.head_policy.parameters()) + # 如果有其他头,也一并加入 + # - b. 为了高效分离,我们使用参数的ID + self.head_param_ids = {id(p) for p in self.head_params} + # - c. 
Collect the backbone parameters (everything whose id is not in head_param_ids).
+        self.backbone_params = [p for p in world_model.parameters() if id(p) not in self.head_param_ids]
+
+        # --- NEW: Policy Label Smoothing Parameters ---
+        self.policy_ls_eps_start = self._cfg.get('policy_ls_eps_start', 0.05)  # TODO: the larger the action space, the larger the smoothing eps should be
+        self.policy_ls_eps_end = self._cfg.get('policy_label_smoothing_eps_end', 0.01)  # TODO: unify the key naming with policy_ls_eps_start
+        self.policy_ls_eps_decay_steps = self._cfg.get('policy_ls_eps_decay_steps', 50000)  # TODO: 50k
+
+        print(f"self.policy_ls_eps_start:{self.policy_ls_eps_start}")
+
+        # ==================== START: Target-entropy regularization init ====================
+        # Read from the config whether adaptive alpha is enabled, with a default value.
+        self.use_adaptive_entropy_weight = self._cfg.get('use_adaptive_entropy_weight', True)
+
+        # Additional configuration read in _init_learn.
+        self.target_entropy_start_ratio = self._cfg.get('target_entropy_start_ratio', 0.98)
+        self.target_entropy_end_ratio = self._cfg.get('target_entropy_end_ratio', 0.7)
+        self.target_entropy_decay_steps = self._cfg.get('target_entropy_decay_steps', 200000)  # e.g. finish annealing within 200k iterations (~2M env steps)
+
+        if self.use_adaptive_entropy_weight:
+            # 1. Set the target entropy. For discrete action spaces, a common heuristic is the negative log of
+            #    the uniform probability over the action space, scaled by a coefficient (e.g. 0.98) that can be
+            #    treated as a hyperparameter.
+            action_space_size = self._cfg.model.action_space_size
+            self.target_entropy = -np.log(1.0 / action_space_size) * 0.98
+
+            # 2. Initialize a learnable log_alpha parameter.
+            #    Initializing it to 0 means the initial alpha = exp(0) = 1.0.
+            self.log_alpha = torch.nn.Parameter(torch.zeros(1, device=self._cfg.device), requires_grad=True)
+
+            # 3. Create a dedicated optimizer for log_alpha.
+            #    A smaller learning rate than the main optimizer (e.g. 1e-4) is usually more stable.
+            alpha_lr = self._cfg.get('adaptive_entropy_alpha_lr', 1e-4)
+            self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr)
+
+            print("=" * 20)
+            print(">>> Target-entropy regularization (adaptive alpha) enabled <<<")
+            print(f"    Target entropy: {self.target_entropy:.4f}")
+            print(f"    Alpha optimizer learning rate: {alpha_lr:.2e}")
+            print("=" * 20)
+        # ===================== END: Target-entropy regularization init =====================
+
+        # ==================== START: Encoder-Clip annealing init ====================
+        self.use_encoder_clip_annealing = self._cfg.get('use_encoder_clip_annealing', False)
+        if self.use_encoder_clip_annealing:
+            self.encoder_clip_anneal_type = self._cfg.get('encoder_clip_anneal_type', 'cosine')
+            self.encoder_clip_start = self._cfg.get('encoder_clip_start_value', 30.0)
+            self.encoder_clip_end = self._cfg.get('encoder_clip_end_value', 10.0)
+            self.encoder_clip_anneal_steps = self._cfg.get('encoder_clip_anneal_steps', 200000)
+
+            print("=" * 20)
+            print(">>> Encoder-Clip annealing enabled <<<")
+            print(f"    Type:  {self.encoder_clip_anneal_type}")
+            print(f"    Range: {self.encoder_clip_start} -> {self.encoder_clip_end}")
+            print(f"    Steps: {self.encoder_clip_anneal_steps}")
+            print("=" * 20)
+        else:
+            # If annealing is disabled, use a fixed clip threshold.
+            self.latent_norm_clip_threshold = self._cfg.get('latent_norm_clip_threshold', 30.0)
+        # ===================== END: Encoder-Clip annealing init =====================
+
     # @profile
     def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
@@ -379,6 +643,14 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch
         target_reward, target_value, target_policy = target_batch
 
+        # --- NEW: Calculate the current label-smoothing epsilon for the policy ---
+        if self.policy_ls_eps_start > 0:
+            progress = min(1.0, train_iter / self.policy_ls_eps_decay_steps)
+            current_policy_label_eps = self.policy_ls_eps_start * (1 - progress) + self.policy_ls_eps_end * progress
+        else:
+            current_policy_label_eps = 0.0
+
         # Prepare observations based on frame stack number
         if self._cfg.model.frame_stack_num > 1:
             obs_batch, obs_target_batch = prepare_obs_stack_for_unizero(obs_batch_ori, self._cfg)
@@ -407,8 +679,11 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         transformed_target_value = scalar_transform(target_value)
 
         # Convert to categorical distributions
-        target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward)
-        target_value_categorical = phi_transform(self.value_support, transformed_target_value)
+        # target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward)
+        # target_value_categorical = phi_transform(self.value_support, transformed_target_value)
+
+        target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward, label_smoothing_eps=self._cfg.label_smoothing_eps)
+        target_value_categorical = phi_transform(self.value_support, transformed_target_value, label_smoothing_eps=self._cfg.label_smoothing_eps)
 
         # Prepare batch for GPT model
         batch_for_gpt = {}
@@ -431,6 +706,11 @@
         batch_for_gpt['target_value'] = target_value_categorical[:, :-1]
         batch_for_gpt['target_policy'] = target_policy[:, :-1]
 
+        # ==================== START MODIFICATION 1 ====================
+        # Pass the original scalar target_value to compute_loss for priority calculation.
+        batch_for_gpt['scalar_target_value'] = target_value
+        # ===================== END MODIFICATION 1 =====================
+
         # Extract valid target policy data and compute entropy
         valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']]
         target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1)
@@ -438,13 +718,54 @@
         # Update world model
         losses = self._learn_model.world_model.compute_loss(
-            batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle
+            batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, global_step=train_iter, current_policy_label_eps=current_policy_label_eps,
         )  # NOTE : compute_loss third argument is now a dead argument. If this changes, it could need adaptation between value_inverse and reward_inverse.
 
-        weighted_total_loss = losses.loss_total
+        # ==================== [Modified] Integrate norm-monitoring logic ====================
+        norm_log_dict = {}
+        # Check whether the monitoring frequency has been reached (monitor_norm_freq=0 disables monitoring,
+        # hence the parentheses around the iteration check).
+        if self._cfg.monitor_norm_freq > 0 and (train_iter == 0 or train_iter % self._cfg.monitor_norm_freq == 0):
+            with torch.no_grad():
+                # 1. Monitor the model parameter norms.
+                param_norm_metrics = self._monitor_model_norms()
+                norm_log_dict.update(param_norm_metrics)
+
+                # 2. 
监控中间张量 x (Transformer的输出) + intermediate_x = losses.intermediate_losses.get('intermediate_tensor_x') + if intermediate_x is not None: + # x 的形状为 (B, T, E) + # 计算每个 token 的 L2 范数 + token_norms = intermediate_x.norm(p=2, dim=-1) + + # 记录这些范数的统计数据 + norm_log_dict['norm/x_token/mean'] = token_norms.mean().item() + norm_log_dict['norm/x_token/std'] = token_norms.std().item() + norm_log_dict['norm/x_token/max'] = token_norms.max().item() + norm_log_dict['norm/x_token/min'] = token_norms.min().item() + # ================================================================= + + + # ==================== START MODIFICATION 2 ==================== + # Extract the calculated value_priority from the returned losses. + value_priority_tensor = losses.intermediate_losses['value_priority'] + # Convert to numpy array for the replay buffer, adding a small epsilon. + value_priority_np = value_priority_tensor.detach().cpu().numpy() + 1e-6 + # ===================== END MODIFICATION 2 ===================== + + # weighted_total_loss = losses.loss_total + # TODO: + weighted_total_loss = (weights * losses.loss_total).mean() + for loss_name, loss_value in losses.intermediate_losses.items(): self.intermediate_losses[f"{loss_name}"] = loss_value + # 从 losses 对象中提取策略熵 + policy_entropy = self.intermediate_losses['policy_entropy'] + + + + + obs_loss = self.intermediate_losses['loss_obs'] reward_loss = self.intermediate_losses['loss_rewards'] policy_loss = self.intermediate_losses['loss_policy'] @@ -453,12 +774,88 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in perceptual_loss = self.intermediate_losses['perceptual_loss'] orig_policy_loss = self.intermediate_losses['orig_policy_loss'] policy_entropy = self.intermediate_losses['policy_entropy'] + + # ==================== START: 目标熵正则化更新逻辑 ==================== + alpha_loss = None + current_alpha = self._cfg.model.world_model_cfg.policy_entropy_weight # 默认使用固定值 + if self.use_adaptive_entropy_weight: + # --- 动态计算目标熵 (这部分逻辑是正确的,予以保留) --- + progress = min(1.0, train_iter / self.target_entropy_decay_steps) + current_ratio = self.target_entropy_start_ratio * (1 - progress) + self.target_entropy_end_ratio * progress + action_space_size = self._cfg.model.action_space_size + # 注意:我们将 target_entropy 定义为正数,更符合直觉 + current_target_entropy = -np.log(1.0 / action_space_size) * current_ratio + + # --- 计算 alpha_loss (已修正符号) --- + # 这是核心修正点:去掉了最前面的负号 + # detach() 仍然是关键,确保 alpha_loss 的梯度只流向 log_alpha + alpha_loss = (self.log_alpha * (policy_entropy.detach() - current_target_entropy)).mean() + + # # --- 更新 log_alpha --- + self.alpha_optimizer.zero_grad() + alpha_loss.backward() + self.alpha_optimizer.step() + # --- [优化建议] 增加 log_alpha 裁剪作为安全措施 --- + with torch.no_grad(): + # 将 alpha 限制在例如 [1e-4, 10.0] 的范围内 + self.log_alpha.clamp_(np.log(1e-4), np.log(10.0)) + + # --- 更新 log_alpha --- + # 仅在需要更新时执行 (与主模型的梯度累积同步) + # if (train_iter + 1) % self.accumulation_steps == 0: + # self.alpha_optimizer.zero_grad() + # alpha_loss.backward() + # self.alpha_optimizer.step() + + # # [可选但推荐] 增加裁剪作为安全措施 + # with torch.no_grad(): + # self.log_alpha.clamp_(np.log(1e-4), np.log(10.0)) # 限制alpha在合理范围 + + + # --- 使用当前更新后的 alpha (截断梯度流) --- + current_alpha = self.log_alpha.exp().detach() + + # 重新计算加权的策略损失和总损失 + # 注意:这里的 policy_entropy 已经是一个batch的平均值 + weighted_policy_loss = orig_policy_loss - current_alpha * policy_entropy + # 重新构建总损失 (不使用 losses.loss_total) + # 确保这里的权重与 LossWithIntermediateLosses 类中的计算方式一致 + self.obs_loss_weight = 10 + self.value_loss_weight = 0.5 + 
self.reward_loss_weight = 1. + self.policy_loss_weight = 1. + self.ends_loss_weight = 0. + total_loss = ( + self.reward_loss_weight * reward_loss + + self.value_loss_weight * value_loss + + self.policy_loss_weight * weighted_policy_loss + + self.obs_loss_weight * obs_loss # 假设 ssl_loss_weight 是 obs_loss 的权重 + # ... 如果还有其他损失项,也加进来 ... + ) + weighted_total_loss = (weights * total_loss).mean() + # ===================== END: 目标熵正则化更新逻辑 ===================== + first_step_losses = self.intermediate_losses['first_step_losses'] middle_step_losses = self.intermediate_losses['middle_step_losses'] last_step_losses = self.intermediate_losses['last_step_losses'] dormant_ratio_encoder = self.intermediate_losses['dormant_ratio_encoder'] dormant_ratio_world_model = self.intermediate_losses['dormant_ratio_world_model'] latent_state_l2_norms = self.intermediate_losses['latent_state_l2_norms'] + latent_action_l2_norms = self.intermediate_losses['latent_action_l2_norms'] + + + logits_value_mean=self.intermediate_losses['logits_value_mean'] + logits_value_max=self.intermediate_losses['logits_value_max'] + logits_value_min=self.intermediate_losses['logits_value_min'] + + logits_policy_mean=self.intermediate_losses['logits_policy_mean'] + logits_policy_max=self.intermediate_losses['logits_policy_max'] + logits_policy_min=self.intermediate_losses['logits_policy_min'] + + + temperature_value=self.intermediate_losses['temperature_value'] + temperature_reward=self.intermediate_losses['temperature_reward'] + temperature_policy=self.intermediate_losses['temperature_policy'] assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values" assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values" @@ -470,8 +867,125 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # Scale the loss by the number of accumulation steps weighted_total_loss = weighted_total_loss / self.accumulation_steps + + if self._cfg.gradient_scale: + # ============================================================== + # START OF THE FIX: Add gradient scaling just like in MuZero + # ============================================================== + # This is the key to stabilizing the latent norm. It averages the gradients + # accumulated over the unroll steps, preventing the exploding gradient problem + # in the recurrent world model (Transformer). + gradient_scale = 1.0 / self._cfg.num_unroll_steps + weighted_total_loss.register_hook(lambda grad: grad * gradient_scale) + # ============================================================== + # END OF THE FIX + # ============================================================== + + weighted_total_loss.backward() + # ================================================================= + # 高效的权重裁剪实现 + # ----------------------------------------------------------------- + # 仍然在 torch.no_grad() 环境下执行 + # ================================================================= + with torch.no_grad(): + # 1. 
Encoder-Clip
+            # ==================== START: dynamically compute the current clip threshold ====================
+            current_clip_value = self.latent_norm_clip_threshold  # fall back to the fixed value by default
+            if self.use_encoder_clip_annealing:
+                progress = min(1.0, train_iter / self.encoder_clip_anneal_steps)
+
+                if self.encoder_clip_anneal_type == 'cosine':
+                    # Cosine schedule: decays smoothly from 1 to 0.
+                    cosine_progress = 0.5 * (1.0 + np.cos(np.pi * progress))
+                    current_clip_value = self.encoder_clip_end + \
+                        (self.encoder_clip_start - self.encoder_clip_end) * cosine_progress
+                else:  # default to a linear schedule
+                    current_clip_value = self.encoder_clip_start * (1 - progress) + \
+                        self.encoder_clip_end * progress
+            # ===================== END: dynamically compute the current clip threshold =====================
+
+            # 1. Encoder-Clip (using the dynamically computed current_clip_value).
+            if current_clip_value > 0 and 'obs_embeddings' in losses.intermediate_losses:
+                obs_embeddings = losses.intermediate_losses['obs_embeddings']
+                if obs_embeddings is not None:
+                    max_latent_norm = obs_embeddings.norm(p=2, dim=-1).max()
+                    if max_latent_norm > current_clip_value:
+                        scale_factor = current_clip_value / max_latent_norm.item()
+                        # Avoid printing too frequently; log only every N iterations.
+                        if train_iter % 1000 == 0:
+                            print(f"[Encoder-Clip Annealing] Iter {train_iter}: Max latent norm {max_latent_norm.item():.2f} > {current_clip_value:.2f}. Scaling by {scale_factor:.4f}.")
+                        scale_module_weights_vectorized(self._model.world_model.tokenizer.encoder, scale_factor)
+
+            # if self.latent_norm_clip_threshold > 0 and 'obs_embeddings' in losses.intermediate_losses:
+            #     obs_embeddings = losses.intermediate_losses['obs_embeddings']
+            #     if obs_embeddings is not None:
+            #         max_latent_norm = obs_embeddings.norm(p=2, dim=-1).max()
+            #         if max_latent_norm > self.latent_norm_clip_threshold:
+            #             scale_factor = self.latent_norm_clip_threshold / max_latent_norm.item()
+            #             print(f"[Encoder-Clip] Max latent norm {max_latent_norm.item():.2f} > {self.latent_norm_clip_threshold}. Scaling encoder weights by {scale_factor:.4f}.")
+            #             # Call the efficient vectorized helper.
+            #             scale_module_weights_vectorized(self._model.world_model.tokenizer.encoder, scale_factor)
+
+            # 2. Value/Reward-Head-Clip
+            if self.logit_clip_threshold > 0:
+                logits_value = losses.intermediate_losses.get('logits_value')
+                logits_reward = losses.intermediate_losses.get('logits_reward')
+                if logits_value is not None and logits_reward is not None:
+                    max_abs_logit = max(logits_value.abs().max(), logits_reward.abs().max())
+                    if max_abs_logit > self.logit_clip_threshold:
+                        scale_factor = self.logit_clip_threshold / max_abs_logit.item()
+                        print(f"[Value-Reward-Head-Clip] Max abs logit {max_abs_logit.item():.2f} > {self.logit_clip_threshold}. Scaling head weights by {scale_factor:.4f}.")
+                        # Scale the two heads separately.
+                        scale_module_weights_vectorized(self._model.world_model.head_value, scale_factor)
+                        scale_module_weights_vectorized(self._model.world_model.head_rewards, scale_factor)
+
+            # 3. Policy-Head-Clip
+            policy_logit_clip_threshold = self._cfg.get('policy_logit_clip_threshold', 5)
+            if policy_logit_clip_threshold > 0:
+                logits_policy = losses.intermediate_losses.get('logits_policy')
+                if logits_policy is not None:
+                    max_policy_logit = logits_policy.max()
+                    if max_policy_logit > policy_logit_clip_threshold:
+                        scale_factor = policy_logit_clip_threshold / max_policy_logit.item()
+                        print(f"[Policy-Head-Clip] Max policy logit {max_policy_logit.item():.4f} > {policy_logit_clip_threshold}. 
Scaling policy head weights by {scale_factor:.4f}.") + scale_module_weights_vectorized(self._model.world_model.head_policy, scale_factor) + + + # # ======================= 学习率真实性检查 ======================= + # if (train_iter % 1000) == 0: + # print("\n--- Optimizer Learning Rate Analysis ---") + # # self._optimizer_world_model 是唯一的优化器 + # for i, param_group in enumerate(self._optimizer_world_model.param_groups): + # # configure_optimizers_nanogpt 可能会创建多个参数组(例如,一个用于带权重衰减的参数,一个用于不带的) + # print(f" Param Group {i}: LR = {param_group['lr']:.6f}") + # # ================================================================= + + # ======================= 梯度检查代码 ======================= + # 我们可以只关注 Encoder 的梯度 + encoder = self._learn_model.world_model.tokenizer.encoder + total_grad_norm = 0.0 + + # if (train_iter % 5000) == 0: + if (train_iter % 10000) == 0: # 10k + # if (train_iter % 1) == 0: + print(f"\n--- Gradient Analysis for Step {train_iter} ---") + for name, param in encoder.named_parameters(): + if param.grad is not None: + grad_norm = param.grad.norm().item() + total_grad_norm += grad_norm ** 2 + + # 打印每一层的梯度范数,以定位问题层 + print(f" Layer: {name} | Grad Norm: {grad_norm:.6f}") + else: + print(f" Layer: {name} | Grad is None") + + total_grad_norm = total_grad_norm ** 0.5 + print(f"--- Total Grad Norm for Encoder: {total_grad_norm:.6f} ---\n") + # ============================================================= + + # Check if the current iteration completes an accumulation cycle if (train_iter + 1) % self.accumulation_steps == 0: # Analyze gradient norms if simulation normalization analysis is enabled @@ -485,6 +999,23 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_( self._learn_model.world_model.parameters(), self._cfg.grad_clip_value ) + # total_grad_norm_before_clip_wm = torch.tensor(0.) + + + # # 3. 对两组参数分别进行梯度裁剪 + # # - a. 对预测头应用一个更严格(更小)的裁剪阈值 + # # 您需要在配置文件中新增 `head_grad_clip_value`,例如设置为 1.0 或 0.5 + # head_grad_norm = torch.nn.utils.clip_grad_norm_( + # self.head_params, self._cfg.get('head_grad_clip_value', 1.0) # 示例:严格的阈值 + # ) + # # - b. 对主干网络应用一个相对宽松的裁剪阈值 + # # 您可以在配置文件中新增 `backbone_grad_clip_value`,例如设置为 10.0 + # backbone_grad_norm = torch.nn.utils.clip_grad_norm_( + # self.backbone_params, self._cfg.get('backbone_grad_clip_value', 10.0) # 示例:标准的阈值 + # ) + head_grad_norm = torch.tensor(0.) + backbone_grad_norm = torch.tensor(0.) + # Synchronize gradients across multiple GPUs if enabled if self._cfg.multi_gpu: @@ -499,6 +1030,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in else: total_grad_norm_before_clip_wm = torch.tensor(0.) + + # 以前clip的位置 ========== + + # Update learning rate scheduler if applicable if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler: self.lr_scheduler.step() @@ -547,21 +1082,57 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in 'target_policy_entropy': average_target_policy_entropy.item(), 'reward_loss': reward_loss.item(), 'value_loss': value_loss.item(), - # 'value_priority_orig': np.zeros(self._cfg.batch_size), # TODO + # ==================== START MODIFICATION 3 ==================== + # Add value_priority to the log dictionary. 
+ 'value_priority': value_priority_np.mean().item(), + 'value_priority_orig': value_priority_np, + # ===================== END MODIFICATION 3 ===================== 'target_reward': target_reward.mean().item(), 'target_value': target_value.mean().item(), 'transformed_target_reward': transformed_target_reward.mean().item(), 'transformed_target_value': transformed_target_value.mean().item(), 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), + "head_grad_norm":head_grad_norm, + "backbone_grad_norm":backbone_grad_norm, + 'analysis/dormant_ratio_encoder': dormant_ratio_encoder.item(), 'analysis/dormant_ratio_world_model': dormant_ratio_world_model.item(), 'analysis/latent_state_l2_norms': latent_state_l2_norms.item(), + 'analysis/latent_action_l2_norms': latent_action_l2_norms.item(), 'analysis/l2_norm_before': self.l2_norm_before, 'analysis/l2_norm_after': self.l2_norm_after, 'analysis/grad_norm_before': self.grad_norm_before, 'analysis/grad_norm_after': self.grad_norm_after, + + "logits_value_mean":logits_value_mean, + "logits_value_max":logits_value_max, + "logits_value_min":logits_value_min, + "logits_policy_mean":logits_policy_mean, + "logits_policy_max":logits_policy_max, + "logits_policy_min":logits_policy_min, + + "temperature_value":temperature_value, + "temperature_reward":temperature_reward, + "temperature_policy":temperature_policy, + + "current_policy_label_eps":current_policy_label_eps, } + # ==================== [修改] 将范数监控结果合并到日志中 ==================== + if norm_log_dict: + return_log_dict.update(norm_log_dict) + # ======================================================================= + if self.use_adaptive_entropy_weight: + return_log_dict['adaptive_alpha'] = current_alpha.item() + return_log_dict['adaptive_target_entropy_ratio'] = current_ratio + + return_log_dict['alpha_loss'] = alpha_loss.item() + + # ==================== START: 添加新日志项 ==================== + if self.use_encoder_clip_annealing: + return_log_dict['current_encoder_clip_value'] = current_clip_value + # ===================== END: 添加新日志项 ===================== + if self._cfg.use_wandb: wandb.log({'learner_step/' + k: v for k, v in return_log_dict.items()}, step=self.env_step) wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step) @@ -583,11 +1154,14 @@ def _init_collect(self) -> None: Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. """ self._collect_model = self._model - + # 为 collect MCTS 创建一个配置副本,并设置特定的模拟次数 + mcts_collect_cfg = copy.deepcopy(self._cfg) + mcts_collect_cfg.num_simulations = self._cfg.collect_num_simulations if self._cfg.mcts_ctree: - self._mcts_collect = MCTSCtree(self._cfg) + self._mcts_collect = MCTSCtree(mcts_collect_cfg) else: - self._mcts_collect = MCTSPtree(self._cfg) + self._mcts_collect = MCTSPtree(mcts_collect_cfg) + self._collect_mcts_temperature = 1. 
self._collect_epsilon = 0.0 self.collector_env_num = self._cfg.collector_env_num @@ -598,6 +1172,21 @@ def _init_collect(self) -> None: self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) self.last_batch_action = [-1 for i in range(self.collector_env_num)] + # def _apply_temperature_scaling_for_inference(self, policy_logits, value_logits, reward_logits): + # """ 辅助函数,在推理时应用可学习的温度 """ + # if self._cfg.model.world_model_cfg.use_temperature_scaling: + # with torch.no_grad(): + # T_policy = 1.0 + F.softplus(self._model.world_model.log_temp_policy) + # T_value = 1.0 + F.softplus(self._model.world_model.log_temp_value) + # T_reward = 1.0 + F.softplus(self._model.world_model.log_temp_reward) + + # policy_logits /= (T_policy + 1e-8) + # value_logits /= (T_value + 1e-8) + # if not isinstance(reward_logits, list): + # reward_logits /= (T_reward + 1e-8) + + # return policy_logits, value_logits, reward_logits + # @profile def _forward_collect( self, @@ -647,6 +1236,15 @@ def _forward_collect( network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + # if self._cfg.model.world_model_cfg.use_temperature_scaling: + # # ==================== 关键修改点 ==================== + # # 在将 logits 送入 MCTS 之前,应用可学习的温度 + # policy_logits, pred_values, reward_roots = self._apply_temperature_scaling_for_inference( + # policy_logits, pred_values, reward_roots + # ) + # # ==================================================== + + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy().tolist() @@ -695,7 +1293,8 @@ def _forward_collect( # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - next_latent_state = next_latent_state_with_env[i][action] + # next_latent_state = next_latent_state_with_env[i][action] # eps_collect have ker bug + # next_latent_state = None if self._cfg.model.world_model_cfg.obs_type == 'text': # Output the plain text content decoded by the decoder from the next latent state @@ -726,10 +1325,10 @@ def _forward_collect( self.last_batch_action = batch_action # ========= TODO: for muzero_segment_collector now ========= - if active_collect_env_num < self.collector_env_num: + if active_collect_env_num < self.collector_env_num: # 先有环境done,再到下一步的forward出现这个这个条件满足 print('==========collect_forward============') print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}') - self._reset_collect(reset_init_data=True) + self._reset_collect(reset_init_data=True) # TODO(pu): 所有环境全部重置是否合理呢? if getattr(self._cfg, 'sample_type', '') == 'episode': print('BUG: sample_type is episode, but len(self.last_batch_obs) < self.collector_env_num') @@ -741,10 +1340,15 @@ def _init_eval(self) -> None: Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. 
""" self._eval_model = self._model + + # 为 eval MCTS 创建一个配置副本,并设置特定的模拟次数 + mcts_eval_cfg = copy.deepcopy(self._cfg) + mcts_eval_cfg.num_simulations = self._cfg.eval_num_simulations + if self._cfg.mcts_ctree: - self._mcts_eval = MCTSCtree(self._cfg) + self._mcts_eval = MCTSCtree(mcts_eval_cfg) else: - self._mcts_eval = MCTSPtree(self._cfg) + self._mcts_eval = MCTSPtree(mcts_eval_cfg) self.evaluator_env_num = self._cfg.evaluator_env_num if self._cfg.model.model_type == 'conv': @@ -789,6 +1393,14 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + # if self._cfg.model.world_model_cfg.use_temperature_scaling: + # # ==================== 关键修改点 ==================== + # # 在将 logits 送入 MCTS 之前,应用可学习的温度 + # policy_logits, pred_values, reward_roots = self._apply_temperature_scaling_for_inference( + # policy_logits, pred_values, reward_roots + # ) + # # ==================================================== + # if not in training, obtain the scalars of the value/reward pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) latent_state_roots = latent_state_roots.detach().cpu().numpy() @@ -826,7 +1438,8 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] # Predict the next latent state based on the selected action and policy - next_latent_state = next_latent_state_with_env[i][action] + # next_latent_state = next_latent_state_with_env[i][action] # eps_collect have ker bug + # next_latent_state = None if self._cfg.model.world_model_cfg.obs_type == 'text': # Output the plain text content decoded by the decoder from the next latent state @@ -871,15 +1484,52 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in ) self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] - # Return immediately if env_id is None or a list - if env_id is None or isinstance(env_id, list): - return + # --- BEGIN ROBUST FIX --- + # This logic handles the crucial end-of-episode cache clearing. + # The collector calls `_policy.reset([env_id])` when an episode is done, + # which results in `current_steps` being None and `env_id` being a list. + + # We must handle both single int and list of ints for env_id. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id + + # The key condition: `current_steps` is None only on the end-of-episode reset call from the collector. + if current_steps is None: + world_model = self._collect_model.world_model + for eid in env_ids_to_reset: + # Clear the specific environment's initial inference cache. + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + + print(f'>>> [Collector] Cleared KV cache for env_id: {eid} at episode end.') + + # TODO + # The recurrent cache is global, which is problematic. + # A full clear is heavy-handed but safer than leaving stale entries. 
+ # world_model.past_kv_cache_recurrent_infer.clear() + # if hasattr(world_model, 'keys_values_wm_list'): + # world_model.keys_values_wm_list.clear() + # torch.cuda.empty_cache() + # --- END ROBUST FIX --- + + + # # Return immediately if env_id is None or a list + # if env_id is None or isinstance(env_id, list): + # return # Determine the clear interval based on the environment's sample type - clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + # TODO:========== + # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 40 + + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length + # Clear caches if the current steps are a multiple of the clear interval - if current_steps % clear_interval == 0: + if current_steps is not None and current_steps % clear_interval == 0: print(f'clear_interval: {clear_interval}') # Clear various caches in the collect model's world model @@ -892,8 +1542,8 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in # Free up GPU memory torch.cuda.empty_cache() - print('collector: collect_model clear()') - print(f'eps_steps_lst[{env_id}]: {current_steps}') + print(f'eps_steps_lst[{env_id}]: {current_steps}, collector: collect_model clear()') + def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None: """ @@ -915,15 +1565,47 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_ ) self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] + # --- BEGIN ROBUST FIX --- + # This logic handles the crucial end-of-episode cache clearing for evaluation. + # The evaluator calls `_policy.reset([env_id])` when an episode is done. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id + + # The key condition: `current_steps` is None only on the end-of-episode reset call from the evaluator. + if current_steps is None: + world_model = self._eval_model.world_model + for eid in env_ids_to_reset: + # Clear the specific environment's initial inference cache. + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + + print(f'>>> [Evaluator] Cleared KV cache for env_id: {eid} at episode end.') + + # The recurrent cache is global. 
+ world_model.past_kv_cache_recurrent_infer.clear() + + if hasattr(world_model, 'keys_values_wm_list'): + world_model.keys_values_wm_list.clear() + + torch.cuda.empty_cache() + return + # --- END ROBUST FIX --- + # Return immediately if env_id is None or a list - if env_id is None or isinstance(env_id, list): - return + # if env_id is None or isinstance(env_id, list): + # return # Determine the clear interval based on the environment's sample type - clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + # TODO:========== + # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 40 + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length - # Clear caches if the current steps are a multiple of the clear interval - if current_steps % clear_interval == 0: + # # Clear caches if the current steps are a multiple of the clear interval + if current_steps is not None and current_steps % clear_interval == 0: print(f'clear_interval: {clear_interval}') # Clear various caches in the eval model's world model @@ -945,10 +1627,12 @@ def _monitor_vars_learn(self) -> List[str]: Register the variables to be monitored in learn mode. The registered variables will be logged in tensorboard according to the return value ``_forward_learn``. """ - return [ + base_vars = [ 'analysis/dormant_ratio_encoder', 'analysis/dormant_ratio_world_model', 'analysis/latent_state_l2_norms', + 'analysis/latent_action_l2_norms', + 'analysis/l2_norm_before', 'analysis/l2_norm_after', 'analysis/grad_norm_before', @@ -986,15 +1670,58 @@ def _monitor_vars_learn(self) -> List[str]: 'reward_loss', 'value_loss', 'consistency_loss', + # ==================== START MODIFICATION 4 ==================== 'value_priority', + # ===================== END MODIFICATION 4 ===================== 'target_reward', 'target_value', 'total_grad_norm_before_clip_wm', + "head_grad_norm", + "backbone_grad_norm", # tokenizer 'commitment_loss', 'reconstruction_loss', 'perceptual_loss', + + + "logits_value_mean", + "logits_value_max", + "logits_value_min", + "logits_policy_mean", + "logits_policy_max", + "logits_policy_min", + + "temperature_value", + "temperature_reward", + "temperature_policy", + "current_policy_label_eps", + 'adaptive_alpha', + "adaptive_target_entropy_ratio", + 'alpha_loss', + "current_encoder_clip_value", + ] + + # ==================== [新增] 添加范数和中间张量监控变量 ==================== + norm_vars = [ + # 模块总范数 + 'norm/encoder/_total_norm', + 'norm/transformer/_total_norm', + 'norm/head_value/_total_norm', + 'norm/head_reward/_total_norm', + 'norm/head_policy/_total_norm', + # 中间张量 x 的统计信息 + 'norm/x_token/mean', + 'norm/x_token/std', + 'norm/x_token/max', + 'norm/x_token/min', ] + # 注意:我们不把每一层的范数都加到这里,因为数量太多会导致日志混乱。 + # 在实践中,如果通过总范数发现问题,可以临时在TensorBoard中搜索特定层的范数, + # 或者在本地打印 `norm_log_dict` 来进行详细分析。 + # wandb等工具可以更好地处理大量的动态指标。 + # ======================================================================== + + return base_vars + norm_vars def _state_dict_learn(self) -> Dict[str, Any]: """ @@ -1003,11 +1730,16 @@ def _state_dict_learn(self) -> Dict[str, Any]: Returns: - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. 
""" - return { + state_dict = { 'model': self._learn_model.state_dict(), 'target_model': self._target_model.state_dict(), 'optimizer_world_model': self._optimizer_world_model.state_dict(), } + # ==================== START: 保存Alpha优化器状态 ==================== + if self.use_adaptive_entropy_weight: + state_dict['alpha_optimizer'] = self.alpha_optimizer.state_dict() + # ===================== END: 保存Alpha优化器状态 ===================== + return state_dict def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: """ @@ -1018,7 +1750,12 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: """ self._learn_model.load_state_dict(state_dict['model']) self._target_model.load_state_dict(state_dict['target_model']) - self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + # self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + + # ==================== START: 加载Alpha优化器状态 ==================== + if self.use_adaptive_entropy_weight and 'alpha_optimizer' in state_dict: + self.alpha_optimizer.load_state_dict(state_dict['alpha_optimizer']) + # ===================== END: 加载Alpha优化器状态 ===================== def recompute_pos_emb_diff_and_clear_cache(self) -> None: """ diff --git a/lzero/policy/unizero_bkp20250819.py b/lzero/policy/unizero_bkp20250819.py new file mode 100644 index 000000000..dd6956be6 --- /dev/null +++ b/lzero/policy/unizero_bkp20250819.py @@ -0,0 +1,1076 @@ +import copy +from collections import defaultdict +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +import wandb +from ding.model import model_wrap +from ding.utils import POLICY_REGISTRY + +from lzero.entry.utils import initialize_zeros_batch +from lzero.mcts import UniZeroMCTSCtree as MCTSCtree +from lzero.model import ImageTransforms +from lzero.policy import scalar_transform, InverseScalarTransform, phi_transform, \ + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, prepare_obs, \ + prepare_obs_stack_for_unizero +from lzero.policy.muzero import MuZeroPolicy +from .utils import configure_optimizers_nanogpt + + +@POLICY_REGISTRY.register('unizero') +class UniZeroPolicy(MuZeroPolicy): + """ + Overview: + The policy class for UniZero, official implementation for paper UniZero: Generalized and Efficient Planning + with Scalable LatentWorld Models. UniZero aims to enhance the planning capabilities of reinforcement learning agents + by addressing the limitations found in MuZero-style algorithms, particularly in environments requiring the + capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667. + """ + + # The default_config for UniZero policy. + config = dict( + type='unizero', + model=dict( + # (str) The model type. For 1-dimensional vector obs, we use mlp model. For the image obs, we use conv model. + model_type='conv', # options={'mlp', 'conv'} + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, + # (tuple) The obs shape. + observation_shape=(3, 64, 64), + # (bool) Whether to use the self-supervised learning loss. + self_supervised_learning_loss=True, + # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. + categorical_distribution=True, + # (int) The image channel in image observation. + image_channel=3, + # (int) The number of frames to stack together. 
+ frame_stack_num=1, + # (int) The number of res blocks in MuZero model. + num_res_blocks=1, + # (int) The number of channels of hidden states in MuZero model. + num_channels=64, + # (tuple) The range of supports used in categorical distribution. + # These variables are only effective when ``model.categorical_distribution=True``. + reward_support_range=(-50., 51., 1.), + value_support_range=(-50., 51., 1.), + # (bool) whether to learn bias in the last linear layer in value and policy head. + bias=True, + # (bool) whether to use res connection in dynamics. + res_connection_in_dynamics=True, + # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'BN'. + norm_type='BN', + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (int) The save interval of the model. + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10000, ), ), ), + world_model_cfg=dict( + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, + # (int) The number of tokens per block. + tokens_per_block=2, + # (int) The maximum number of blocks. + max_blocks=10, + # (int) The maximum number of tokens, calculated as tokens per block multiplied by max blocks. + max_tokens=2 * 10, + # (int) The context length, usually calculated as twice the number of some base unit. + context_length=2 * 4, + # (bool) Whether to use GRU gating mechanism. + gru_gating=False, + # (str) The device to be used for computation, e.g., 'cpu' or 'cuda'. + device='cpu', + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to analyze dormant ratio. + analysis_dormant_ratio=False, + # (int) The shape of the action space. + action_space_size=6, + # (int) The size of the group, related to simulation normalization. + group_size=8, # NOTE: sim_norm + # (str) The type of attention mechanism used. Options could be ['causal']. + attention='causal', + # (int) The number of layers in the model. + num_layers=2, + # (int) The number of attention heads. + num_heads=8, + # (int) The dimension of the embedding. + embed_dim=768, + # (float) The dropout probability for the embedding layer. + embed_pdrop=0.1, + # (float) The dropout probability for the residual connections. + resid_pdrop=0.1, + # (float) The dropout probability for the attention mechanism. + attn_pdrop=0.1, + # (int) The size of the support set for value and reward heads. + support_size=101, + # (int) The maximum size of the cache. + max_cache_size=5000, + # (int) The number of environments. + env_num=8, + # (float) The weight of the latent reconstruction loss. + latent_recon_loss_weight=0., + # (float) The weight of the perceptual loss. + perceptual_loss_weight=0., + # (float) The weight of the policy entropy loss. + policy_entropy_weight=0, + final_norm_option_in_encoder="SimNorm", # "SimNorm"对应"group_kl", "LayerNorm"对应"mse", + final_norm_option_in_obs_head="SimNorm", + # (str) The type of loss for predicting latent variables. Options could be ['group_kl', 'mse']. + predict_latent_loss_type='group_kl', + # (str) The type of observation. Options are ['image', 'vector']. + obs_type='image', + # (float) The discount factor for future rewards. + gamma=1, + # (float) The threshold for a dormant neuron. + dormant_threshold=0.025, + # (bool) Whether to use Rotary Position Embedding (RoPE) for relative position encoding. + # If False, nn.Embedding is used for absolute position encoding. 
+ # For more details on RoPE, refer to the author's blog: https://spaces.ac.cn/archives/8265/ + # TODO: If you want to use rotary_emb in an environment, you need to include the timestep as a return key from the environment. + rotary_emb=False, + # (int) The base value for calculating RoPE angles. Commonly set to 10000. + rope_theta=10000, + # (int) The maximum sequence length for position encoding. + max_seq_len=8192, + ), + ), + # ****** common ****** + # (bool) whether to use rnd model. + use_rnd_model=False, + # (bool) Whether to use multi-gpu training. + multi_gpu=False, + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) + # this variable is used in ``collector``. + sampled_algo=False, + # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero) + gumbel_algo=False, + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. + mcts_ctree=True, + # (bool) Whether to use cuda for network. + cuda=True, + # (int) The number of environments used in collecting data. + collector_env_num=8, + # (int) The number of environments used in evaluating policy. + evaluator_env_num=3, + # (str) The type of environment. Options are ['not_board_games', 'board_games']. + env_type='not_board_games', + # (str) The type of action space. Options are ['fixed_action_space', 'varied_action_space']. + action_type='fixed_action_space', + # (str) The type of battle mode. Options are ['play_with_bot_mode', 'self_play_mode']. + battle_mode='play_with_bot_mode', + # (bool) Whether to monitor extra statistics in tensorboard. + monitor_extra_statistics=True, + # (int) The transition number of one ``GameSegment``. + game_segment_length=400, + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to use the pure policy to collect data. + collect_with_pure_policy=False, + # (int) The evaluation frequency. + eval_freq=int(2e3), + # (str) The sample type. Options are ['episode', 'transition']. + sample_type='transition', + # ****** observation ****** + # (bool) Whether to transform image to string to save memory. + transform2string=False, + # (bool) Whether to use gray scale image. + gray_scale=False, + # (bool) Whether to use data augmentation. + use_augmentation=False, + # (list) The style of augmentation. + augmentation=['shift', 'intensity'], + + # ******* learn ****** + # (bool) Whether to ignore the done flag in the training data. Typically, this value is set to False. + # However, for some environments with a fixed episode length, to ensure the accuracy of Q-value calculations, + # we should set it to True to avoid the influence of the done flag. + ignore_done=False, + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. + # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. + update_per_collect=None, + # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. + replay_ratio=0.25, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. 
+ optim_type='AdamW', + # (float) Learning rate for training policy network. Initial lr for manually decay schedule. + learning_rate=0.0001, + # (int) Frequency of hard target network update. + target_update_freq=100, + # (int) Frequency of soft target network update. + target_update_theta=0.05, + # (int) Frequency of target network update. + target_update_freq_for_intrinsic_reward=1000, + # (float) Weight decay for training policy network. + weight_decay=1e-4, + # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). + momentum=0.9, + # (float) The maximum constraint value of gradient norm clipping. + grad_clip_value=20, + # (int) The number of episodes in each collecting stage when use muzero_collector. + n_episode=8, + # (int) The number of num_segments in each collecting stage when use muzero_segment_collector. + num_segments=8, + # # (int) the number of simulations in MCTS for renalyze. + num_simulations=50, + # (int) The number of simulations in MCTS for the collect phase. + collect_num_simulations=25, + # (int) The number of simulations in MCTS for the eval phase. + eval_num_simulations=50, + # (float) Discount factor (gamma) for returns. + discount_factor=0.997, + # (int) The number of steps for calculating target q_value. + td_steps=5, + # (int) The number of unroll steps in dynamics network. + num_unroll_steps=10, + # (float) The weight of reward loss. + reward_loss_weight=1, + # (float) The weight of value loss. + value_loss_weight=0.25, + # (float) The weight of policy loss. + policy_loss_weight=1, + # (float) The weight of ssl (self-supervised learning) loss. + ssl_loss_weight=0, + # (bool) Whether to use the cosine learning rate decay. + cos_lr_scheduler=False, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. lr: 0.2 -> 0.02 -> 0.002 + piecewise_decay_lr_scheduler=False, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e4), + # (bool) Whether to use manually decayed temperature. + manual_temperature_decay=False, + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(5e4), + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. + use_ture_chance_label_in_chance_encoder=False, + # (int) The number of steps to accumulate gradients before performing an optimization step. + accumulation_steps=1, + + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=False, + # (float) The degree of prioritization to use. A value of 0 means no prioritization, + # while a value of 1 means full prioritization. + priority_prob_alpha=0.6, + # (float) The degree of correction to use. A value of 0 means no correction, + # while a value of 1 means full correction. + priority_prob_beta=0.4, + # (int) The initial Env Steps for training. + train_start_after_envsteps=int(0), + + # ****** UCB ****** + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree. 
+ root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + + # ****** Explore by random collect ****** + # (int) The number of episodes to collect data randomly before training. + random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + # (bool) Whether to use eps greedy exploration in collecting data. + eps_greedy_exploration_in_collect=False, + # (str) The type of decaying epsilon. Options are 'linear', 'exp'. + type='linear', + # (float) The start value of eps. + start=1., + # (float) The end value of eps. + end=0.05, + # (int) The decay steps from start to end eps. + decay=int(1e5), + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default model setting for demonstration. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. + - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. + - import_names (:obj:`List[str]`): The model class path list used in this algorithm. + .. note:: + The user can define and use customized network model but must obey the same interface definition indicated \ + by import_names path. For MuZero, ``lzero.model.unizero_model.MuZeroModel`` + """ + return 'UniZeroModel', ['lzero.model.unizero_model'] + + def _init_learn(self) -> None: + """ + Overview: + Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. + """ + # NOTE: nanoGPT optimizer + self._optimizer_world_model = configure_optimizers_nanogpt( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + + if self._cfg.cos_lr_scheduler: + from torch.optim.lr_scheduler import CosineAnnealingLR + # TODO: check the total training steps + self.lr_scheduler = CosineAnnealingLR(self._optimizer_world_model, 1e5, eta_min=0, last_epoch=-1) + + # use model_wrapper for specialized demands of different modes + self._target_model = copy.deepcopy(self._model) + # Ensure that the installed torch version is greater than or equal to 2.0 + assert int(''.join(filter(str.isdigit, torch.__version__))) >= 200, "We need torch version >= 2.0" + self._model = torch.compile(self._model) + self._target_model = torch.compile(self._target_model) + if self._cfg.target_model_update_option=="soft": + # NOTE: soft target + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta} + ) + elif self._cfg.target_model_update_option=="hard": + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq} + ) + + self._learn_model = self._model + + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + # assert self.value_support.size == self._learn_model.value_support_size # if these assertions fails, somebody introduced... 
+ # assert self.reward_support.size == self._learn_model.reward_support_size # ...incoherence between policy and model + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) + + self.intermediate_losses = defaultdict(float) + self.l2_norm_before = 0. + self.l2_norm_after = 0. + self.grad_norm_before = 0. + self.grad_norm_after = 0. + + if self._cfg.use_wandb: + # TODO: add the model to wandb + wandb.watch(self._learn_model.representation_network, log="all") + + self.accumulation_steps = self._cfg.accumulation_steps + + # @profile + def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward function for learning policy in learn mode, which is the core of the learning process. + The data is sampled from replay buffer. + The loss is calculated by the loss function and the loss is backpropagated to update the model. + Arguments: + - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. + The first tensor is the current_batch, the second tensor is the target_batch. + Returns: + - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ + current learning loss and learning statistics. + """ + self._learn_model.train() + self._target_model.train() + + current_batch, target_batch, train_iter = data + obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch + target_reward, target_value, target_policy = target_batch + + # Prepare observations based on frame stack number + if self._cfg.model.frame_stack_num > 1: + obs_batch, obs_target_batch = prepare_obs_stack_for_unizero(obs_batch_ori, self._cfg) + else: + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) # TODO: optimize + + # Apply augmentations if needed + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # Prepare action batch and convert to torch tensor + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze( + -1).long() # For discrete action space + timestep_batch = torch.from_numpy(timestep_batch).to(self._cfg.device).unsqueeze( + -1).long() + data_list = [mask_batch, target_reward, target_value, target_policy, weights] + mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor(data_list, + self._cfg.device) + target_reward = target_reward.view(self._cfg.batch_size, -1) + target_value = target_value.view(self._cfg.batch_size, -1) + + # Transform rewards and values to their scaled forms + transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + + # Convert to categorical distributions + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # Prepare batch for GPT model + batch_for_gpt = {} + if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape) == 1: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + 
self._cfg.batch_size, -1, self._cfg.model.observation_shape) + elif len(self._cfg.model.observation_shape) == 3: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + self._cfg.batch_size, -1, *self._cfg.model.observation_shape) + + batch_for_gpt['actions'] = action_batch.squeeze(-1) + batch_for_gpt['timestep'] = timestep_batch.squeeze(-1) + + batch_for_gpt['rewards'] = target_reward_categorical[:, :-1] + batch_for_gpt['mask_padding'] = mask_batch == 1.0 # 0 means invalid padding data + batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] + batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] + batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, + device=self._cfg.device) + batch_for_gpt['target_value'] = target_value_categorical[:, :-1] + batch_for_gpt['target_policy'] = target_policy[:, :-1] + + # Extract valid target policy data and compute entropy + valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] + target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) + average_target_policy_entropy = target_policy_entropy.mean() + + # Update world model + losses = self._learn_model.world_model.compute_loss( + batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle + ) # NOTE : compute_loss third argument is now a dead argument. If this changes, it could need adaptation between value_inverse and reward_inverse. + + weighted_total_loss = losses.loss_total + for loss_name, loss_value in losses.intermediate_losses.items(): + self.intermediate_losses[f"{loss_name}"] = loss_value + + obs_loss = self.intermediate_losses['loss_obs'] + reward_loss = self.intermediate_losses['loss_rewards'] + policy_loss = self.intermediate_losses['loss_policy'] + value_loss = self.intermediate_losses['loss_value'] + latent_recon_loss = self.intermediate_losses['latent_recon_loss'] + perceptual_loss = self.intermediate_losses['perceptual_loss'] + orig_policy_loss = self.intermediate_losses['orig_policy_loss'] + policy_entropy = self.intermediate_losses['policy_entropy'] + first_step_losses = self.intermediate_losses['first_step_losses'] + middle_step_losses = self.intermediate_losses['middle_step_losses'] + last_step_losses = self.intermediate_losses['last_step_losses'] + dormant_ratio_encoder = self.intermediate_losses['dormant_ratio_encoder'] + dormant_ratio_world_model = self.intermediate_losses['dormant_ratio_world_model'] + latent_state_l2_norms = self.intermediate_losses['latent_state_l2_norms'] + latent_action_l2_norms = self.intermediate_losses['latent_action_l2_norms'] + + assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values" + assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values" + + # Core learning model update step + # Reset gradients at the start of each accumulation cycle + if (train_iter % self.accumulation_steps) == 0: + self._optimizer_world_model.zero_grad() + + # Scale the loss by the number of accumulation steps + weighted_total_loss = weighted_total_loss / self.accumulation_steps + + if self._cfg.gradient_scale: + # ============================================================== + # START OF THE FIX: Add gradient scaling just like in MuZero + # ============================================================== + # This is the key to stabilizing the latent norm. 
It averages the gradients + # accumulated over the unroll steps, preventing the exploding gradient problem + # in the recurrent world model (Transformer). + gradient_scale = 1.0 / self._cfg.num_unroll_steps + weighted_total_loss.register_hook(lambda grad: grad * gradient_scale) + # ============================================================== + # END OF THE FIX + # ============================================================== + + + weighted_total_loss.backward() + + # Check if the current iteration completes an accumulation cycle + if (train_iter + 1) % self.accumulation_steps == 0: + # Analyze gradient norms if simulation normalization analysis is enabled + if self._cfg.analysis_sim_norm: + # Clear previous analysis results to prevent memory overflow + del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after + self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze() + self._target_model.encoder_hook.clear_data() + + # Clip gradients to prevent exploding gradients + total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_( + self._learn_model.world_model.parameters(), self._cfg.grad_clip_value + ) + + # Synchronize gradients across multiple GPUs if enabled + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) + + # Update model parameters + self._optimizer_world_model.step() + + # Clear CUDA cache if using gradient accumulation + if self.accumulation_steps > 1: + torch.cuda.empty_cache() + else: + total_grad_norm_before_clip_wm = torch.tensor(0.) + + # Update learning rate scheduler if applicable + if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler: + self.lr_scheduler.step() + + # Update the target model with the current model's parameters + self._target_model.update(self._learn_model.state_dict()) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + current_memory_allocated = torch.cuda.memory_allocated() + max_memory_allocated = torch.cuda.max_memory_allocated() + current_memory_allocated_gb = current_memory_allocated / (1024 ** 3) + max_memory_allocated_gb = max_memory_allocated / (1024 ** 3) + else: + current_memory_allocated_gb = 0. + max_memory_allocated_gb = 0. 
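# [Editor's illustrative sketch, not part of this patch] The accumulation + `register_hook` gradient-scaling
# pattern used in `_forward_learn` above, reduced to a minimal standalone example. The toy `net`,
# `unroll_steps`, `accumulation_steps` and `grad_clip_value` are placeholder assumptions; only the
# scale/backward/clip/step structure mirrors the policy code.
import torch

net = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-4)
unroll_steps, accumulation_steps, grad_clip_value = 10, 2, 20.0

for train_iter in range(4):
    if train_iter % accumulation_steps == 0:
        optimizer.zero_grad()                                  # reset at the start of each accumulation cycle
    loss = net(torch.randn(8, 4)).pow(2).mean()
    loss = loss / accumulation_steps                           # average over the accumulated micro-batches
    loss.register_hook(lambda g: g * (1.0 / unroll_steps))     # average gradients over the unroll steps
    loss.backward()
    if (train_iter + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(net.parameters(), grad_clip_value)
        optimizer.step()                                       # only step once per completed cycle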
+ + return_log_dict = { + 'analysis/first_step_loss_value': first_step_losses['loss_value'].item(), + 'analysis/first_step_loss_policy': first_step_losses['loss_policy'].item(), + 'analysis/first_step_loss_rewards': first_step_losses['loss_rewards'].item(), + 'analysis/first_step_loss_obs': first_step_losses['loss_obs'].item(), + + 'analysis/middle_step_loss_value': middle_step_losses['loss_value'].item(), + 'analysis/middle_step_loss_policy': middle_step_losses['loss_policy'].item(), + 'analysis/middle_step_loss_rewards': middle_step_losses['loss_rewards'].item(), + 'analysis/middle_step_loss_obs': middle_step_losses['loss_obs'].item(), + + 'analysis/last_step_loss_value': last_step_losses['loss_value'].item(), + 'analysis/last_step_loss_policy': last_step_losses['loss_policy'].item(), + 'analysis/last_step_loss_rewards': last_step_losses['loss_rewards'].item(), + 'analysis/last_step_loss_obs': last_step_losses['loss_obs'].item(), + + 'Current_GPU': current_memory_allocated_gb, + 'Max_GPU': max_memory_allocated_gb, + 'collect_mcts_temperature': self._collect_mcts_temperature, + 'collect_epsilon': self._collect_epsilon, + 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'], + 'weighted_total_loss': weighted_total_loss.item(), + 'obs_loss': obs_loss.item(), + 'latent_recon_loss': latent_recon_loss.item(), + 'perceptual_loss': perceptual_loss.item(), + 'policy_loss': policy_loss.item(), + 'orig_policy_loss': orig_policy_loss.item(), + 'policy_entropy': policy_entropy.item(), + 'target_policy_entropy': average_target_policy_entropy.item(), + 'reward_loss': reward_loss.item(), + 'value_loss': value_loss.item(), + # 'value_priority_orig': np.zeros(self._cfg.batch_size), # TODO + 'target_reward': target_reward.mean().item(), + 'target_value': target_value.mean().item(), + 'transformed_target_reward': transformed_target_reward.mean().item(), + 'transformed_target_value': transformed_target_value.mean().item(), + 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), + 'analysis/dormant_ratio_encoder': dormant_ratio_encoder.item(), + 'analysis/dormant_ratio_world_model': dormant_ratio_world_model.item(), + 'analysis/latent_state_l2_norms': latent_state_l2_norms.item(), + 'analysis/latent_action_l2_norms': latent_action_l2_norms.item(), + 'analysis/l2_norm_before': self.l2_norm_before, + 'analysis/l2_norm_after': self.l2_norm_after, + 'analysis/grad_norm_before': self.grad_norm_before, + 'analysis/grad_norm_after': self.grad_norm_after, + } + + if self._cfg.use_wandb: + wandb.log({'learner_step/' + k: v for k, v in return_log_dict.items()}, step=self.env_step) + wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step) + + return return_log_dict + + def monitor_weights_and_grads(self, model): + for name, param in model.named_parameters(): + if param.requires_grad: + print(f"Layer: {name} | " + f"Weight mean: {param.data.mean():.4f} | " + f"Weight std: {param.data.std():.4f} | " + f"Grad mean: {param.grad.mean():.4f} | " + f"Grad std: {param.grad.std():.4f}") + + def _init_collect(self) -> None: + """ + Overview: + Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. 
+ """ + self._collect_model = self._model + # 为 collect MCTS 创建一个配置副本,并设置特定的模拟次数 + mcts_collect_cfg = copy.deepcopy(self._cfg) + mcts_collect_cfg.num_simulations = self._cfg.collect_num_simulations + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(mcts_collect_cfg) + else: + self._mcts_collect = MCTSPtree(mcts_collect_cfg) + + self._collect_mcts_temperature = 1. + self._collect_epsilon = 0.0 + self.collector_env_num = self._cfg.collector_env_num + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] + + # @profile + def _forward_collect( + self, + data: torch.Tensor, + action_mask: List = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id: np.ndarray = None, + timestep: List = [0] + ) -> Dict: + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + - timestep (:obj:`list`): The step index of the env in one episode. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + - timestep: :math:`(N, 1)`, where N is the number of collect_env. + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. 
+ """ + self._collect_model.eval() + + self._collect_mcts_temperature = temperature + self._collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + + with torch.no_grad(): + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + # the only difference between collect and eval is the dirichlet noise + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + + next_latent_state_with_env = self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, timestep) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + + if self._cfg.eps.eps_greedy_exploration_in_collect: + # eps greedy collect + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self._collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # normal collect + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + next_latent_state = next_latent_state_with_env[i][action] + + if self._cfg.model.world_model_cfg.obs_type == 'text': + # Output the plain text content decoded by the decoder from the next latent state + predicted_next = self._collect_model.tokenizer.decode_to_plain_text_for_decoder(embeddings=next_latent_state, max_length=256) + else: + predicted_next = None + + # ============== TODO: only for visualization ============== + # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + # distributions, temperature=self._collect_mcts_temperature, deterministic=True + # ) + # action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + # ============== TODO: only for visualization ============== + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + 'timestep': timestep[i], + 'predicted_next_text': predicted_next, + } + batch_action.append(action) + + self.last_batch_obs = data + self.last_batch_action = batch_action + + # ========= TODO: for muzero_segment_collector now ========= + if active_collect_env_num < self.collector_env_num: + print('==========collect_forward============') + print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}') + self._reset_collect(reset_init_data=True) + if getattr(self._cfg, 'sample_type', '') == 'episode': + print('BUG: sample_type is episode, but len(self.last_batch_obs) < self.collector_env_num') + + return output + + def _init_eval(self) -> None: + """ + Overview: + Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. + """ + self._eval_model = self._model + + # Create a copy of the config for eval MCTS and set its own number of simulations. + mcts_eval_cfg = copy.deepcopy(self._cfg) + mcts_eval_cfg.num_simulations = self._cfg.eval_num_simulations + + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(mcts_eval_cfg) + else: + self._mcts_eval = MCTSPtree(mcts_eval_cfg) + self.evaluator_env_num = self._cfg.evaluator_env_num + + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [-1], + ready_env_id: np.array = None, timestep: List = [0]) -> Dict: + """ + Overview: + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. + The action with the highest value (argmax) is chosen rather than sampled during eval mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to eval. + - timestep (:obj:`list`): The step index of the env in one episode. 
+ Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of eval_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of eval_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of eval_env. + - to_play: :math:`(N, 1)`, where N is the number of eval_env. + - ready_env_id: None + - timestep: :math:`(N, 1)`, where N is the number of eval_env. + + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._eval_model.eval() + active_eval_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + # if not in training, obtain the scalars of the value/reward + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_eval_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + next_latent_state_with_env = self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, timestep) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # print("roots_visit_count_distributions:", distributions, "root_value:", value) + + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than + # sampling during the evaluation phase. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the + # entire action set. 
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + # Predict the next latent state based on the selected action and policy + next_latent_state = next_latent_state_with_env[i][action] + + if self._cfg.model.world_model_cfg.obs_type == 'text': + # Output the plain text content decoded by the decoder from the next latent state + predicted_next = self._eval_model.tokenizer.decode_to_plain_text_for_decoder(embeddings=next_latent_state, max_length=256) + else: + predicted_next = None + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + 'timestep': timestep[i], + 'predicted_next_text': predicted_next, + } + batch_action.append(action) + + self.last_batch_obs = data + self.last_batch_action = batch_action + + return output + + def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None: + """ + Overview: + This method resets the collection process for a specific environment. It clears caches and memory + when certain conditions are met, ensuring optimal performance. If reset_init_data is True, the initial data + will be reset. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. If None or list, the function returns immediately. + - current_steps (:obj:`int`, optional): The current step count in the environment. Used to determine + whether to clear caches. + - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. + """ + if reset_init_data: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.collector_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] + + # Return immediately if env_id is None or a list + if env_id is None or isinstance(env_id, list): + return + + # Determine the clear interval based on the environment's sample type + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + + # Clear caches if the current steps are a multiple of the clear interval + if current_steps % clear_interval == 0: + print(f'clear_interval: {clear_interval}') + + # Clear various caches in the collect model's world model + world_model = self._collect_model.world_model + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up GPU memory + torch.cuda.empty_cache() + + print('collector: collect_model clear()') + print(f'eps_steps_lst[{env_id}]: {current_steps}') + + + def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None: + """ + Overview: + This method resets the evaluation process for a specific environment. It clears caches and memory + when certain conditions are met, ensuring optimal performance. If reset_init_data is True, + the initial data will be reset. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. If None or list, the function returns immediately. + - current_steps (:obj:`int`, optional): The current step count in the environment. Used to determine + whether to clear caches. 
+ - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. + """ + if reset_init_data: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.evaluator_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] + + # Return immediately if env_id is None or a list + if env_id is None or isinstance(env_id, list): + return + + # Determine the clear interval based on the environment's sample type + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + + # Clear caches if the current steps are a multiple of the clear interval + if current_steps % clear_interval == 0: + print(f'clear_interval: {clear_interval}') + + # Clear various caches in the eval model's world model + world_model = self._eval_model.world_model + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up GPU memory + torch.cuda.empty_cache() + + print('evaluator: eval_model clear()') + print(f'eps_steps_lst[{env_id}]: {current_steps}') + + def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Register the variables to be monitored in learn mode. The registered variables will be logged in + tensorboard according to the return value ``_forward_learn``. + """ + return [ + 'analysis/dormant_ratio_encoder', + 'analysis/dormant_ratio_world_model', + 'analysis/latent_state_l2_norms', + 'analysis/latent_action_l2_norms', + + 'analysis/l2_norm_before', + 'analysis/l2_norm_after', + 'analysis/grad_norm_before', + 'analysis/grad_norm_after', + + 'analysis/first_step_loss_value', + 'analysis/first_step_loss_policy', + 'analysis/first_step_loss_rewards', + 'analysis/first_step_loss_obs', + + 'analysis/middle_step_loss_value', + 'analysis/middle_step_loss_policy', + 'analysis/middle_step_loss_rewards', + 'analysis/middle_step_loss_obs', + + 'analysis/last_step_loss_value', + 'analysis/last_step_loss_policy', + 'analysis/last_step_loss_rewards', + 'analysis/last_step_loss_obs', + + 'Current_GPU', + 'Max_GPU', + 'collect_epsilon', + 'collect_mcts_temperature', + 'cur_lr_world_model', + 'cur_lr_tokenizer', + + 'weighted_total_loss', + 'obs_loss', + 'policy_loss', + 'orig_policy_loss', + 'policy_entropy', + 'latent_recon_loss', + 'target_policy_entropy', + 'reward_loss', + 'value_loss', + 'consistency_loss', + 'value_priority', + 'target_reward', + 'target_value', + 'total_grad_norm_before_clip_wm', + # tokenizer + 'commitment_loss', + 'reconstruction_loss', + 'perceptual_loss', + ] + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model, target_model and optimizer. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. + """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer_world_model': self._optimizer_world_model.state_dict(), + } + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. 
+ """ + self._learn_model.load_state_dict(state_dict['model']) + self._target_model.load_state_dict(state_dict['target_model']) + self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + + def recompute_pos_emb_diff_and_clear_cache(self) -> None: + """ + Overview: + Clear the caches and precompute positional embedding matrices in the model. + """ + for model in [self._collect_model, self._target_model]: + if not self._cfg.model.world_model_cfg.rotary_emb: + # If rotary_emb is False, nn.Embedding is used for absolute position encoding. + model.world_model.precompute_pos_emb_diff_kv() + model.world_model.clear_caches() + torch.cuda.empty_cache() \ No newline at end of file diff --git a/lzero/policy/unizero_bkp20250917.py b/lzero/policy/unizero_bkp20250917.py new file mode 100644 index 000000000..0cae29611 --- /dev/null +++ b/lzero/policy/unizero_bkp20250917.py @@ -0,0 +1,1448 @@ +import copy +from collections import defaultdict +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +import wandb +from ding.model import model_wrap +from ding.utils import POLICY_REGISTRY + +from lzero.entry.utils import initialize_zeros_batch +from lzero.mcts import UniZeroMCTSCtree as MCTSCtree +from lzero.model import ImageTransforms +from lzero.policy import scalar_transform, InverseScalarTransform, phi_transform, \ + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, prepare_obs, \ + prepare_obs_stack_for_unizero +from lzero.policy.muzero import MuZeroPolicy +from .utils import configure_optimizers_nanogpt + +def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type, betas): + """ + 为UniZero模型配置带有差异化学习率的优化器。 + """ + # 1. 定义需要特殊处理的参数 + param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} + + # 2. 将参数分为三组:Transformer主干、Tokenizer、Heads + transformer_params = {pn: p for pn, p in param_dict.items() if 'transformer' in pn} + tokenizer_params = {pn: p for pn, p in param_dict.items() if 'tokenizer' in pn} + + # Heads的参数是那些既不属于transformer也不属于tokenizer的 + head_params = { + pn: p for pn, p in param_dict.items() + if 'transformer' not in pn and 'tokenizer' not in pn + } + + # 3. 为每组设置不同的优化器参数(特别是学习率) + # 这里我们仍然使用AdamW,但学习率设置更合理 + optim_groups = [ + { + 'params': list(transformer_params.values()), + 'lr': learning_rate * 0.1, # 为Transformer主干设置一个较小的学习率,例如 1e-5 + 'weight_decay': weight_decay + }, + { + 'params': list(tokenizer_params.values()), + 'lr': learning_rate, # Tokenizer使用基础学习率,例如 1e-4 + 'weight_decay': weight_decay + }, + { + 'params': list(head_params.values()), + 'lr': learning_rate, # Heads也使用基础学习率 + 'weight_decay': 0.0 # 通常Heads的权重不做衰减 + } + ] + + print("--- Optimizer Groups ---") + print(f"Transformer LR: {learning_rate * 0.1}") + print(f"Tokenizer/Heads LR: {learning_rate}") + + optimizer = torch.optim.AdamW(optim_groups, betas=betas) + return optimizer + + +@POLICY_REGISTRY.register('unizero') +class UniZeroPolicy(MuZeroPolicy): + """ + Overview: + The policy class for UniZero, official implementation for paper UniZero: Generalized and Efficient Planning + with Scalable LatentWorld Models. UniZero aims to enhance the planning capabilities of reinforcement learning agents + by addressing the limitations found in MuZero-style algorithms, particularly in environments requiring the + capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667. + """ + + # The default_config for UniZero policy. 
+ config = dict( + type='unizero', + model=dict( + # (str) The model type. For 1-dimensional vector obs, we use mlp model. For the image obs, we use conv model. + model_type='conv', # options={'mlp', 'conv'} + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, + # (tuple) The obs shape. + observation_shape=(3, 64, 64), + # (bool) Whether to use the self-supervised learning loss. + self_supervised_learning_loss=True, + # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. + categorical_distribution=True, + # (int) The image channel in image observation. + image_channel=3, + # (int) The number of frames to stack together. + frame_stack_num=1, + # (int) The number of res blocks in MuZero model. + num_res_blocks=1, + # (int) The number of channels of hidden states in MuZero model. + num_channels=64, + # (tuple) The range of supports used in categorical distribution. + # These variables are only effective when ``model.categorical_distribution=True``. + reward_support_range=(-50., 51., 1.), + value_support_range=(-50., 51., 1.), + # (bool) whether to learn bias in the last linear layer in value and policy head. + bias=True, + # (bool) whether to use res connection in dynamics. + res_connection_in_dynamics=True, + # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'BN'. + norm_type='BN', + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (int) The save interval of the model. + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10000, ), ), ), + world_model_cfg=dict( + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, + # (int) The number of tokens per block. + tokens_per_block=2, + # (int) The maximum number of blocks. + max_blocks=10, + # (int) The maximum number of tokens, calculated as tokens per block multiplied by max blocks. + max_tokens=2 * 10, + # (int) The context length, usually calculated as twice the number of some base unit. + context_length=2 * 4, + # (bool) Whether to use GRU gating mechanism. + gru_gating=False, + # (str) The device to be used for computation, e.g., 'cpu' or 'cuda'. + device='cpu', + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to analyze dormant ratio. + analysis_dormant_ratio=False, + # (int) The shape of the action space. + action_space_size=6, + # (int) The size of the group, related to simulation normalization. + group_size=8, # NOTE: sim_norm + # (str) The type of attention mechanism used. Options could be ['causal']. + attention='causal', + # (int) The number of layers in the model. + num_layers=2, + # (int) The number of attention heads. + num_heads=8, + # (int) The dimension of the embedding. + embed_dim=768, + # (float) The dropout probability for the embedding layer. + embed_pdrop=0.1, + # (float) The dropout probability for the residual connections. + resid_pdrop=0.1, + # (float) The dropout probability for the attention mechanism. + attn_pdrop=0.1, + # (int) The size of the support set for value and reward heads. + support_size=101, + # (int) The maximum size of the cache. + max_cache_size=5000, + # (int) The number of environments. + env_num=8, + # (float) The weight of the latent reconstruction loss. + latent_recon_loss_weight=0., + # (float) The weight of the perceptual loss. 
+ perceptual_loss_weight=0., + # (float) The weight of the policy entropy loss. + policy_entropy_weight=0, + final_norm_option_in_encoder="SimNorm", # "SimNorm"对应"group_kl", "LayerNorm"对应"mse", + final_norm_option_in_obs_head="SimNorm", + # (str) The type of loss for predicting latent variables. Options could be ['group_kl', 'mse']. + predict_latent_loss_type='group_kl', + # (str) The type of observation. Options are ['image', 'vector']. + obs_type='image', + # (float) The discount factor for future rewards. + gamma=1, + # (float) The threshold for a dormant neuron. + dormant_threshold=0.025, + # (bool) Whether to use Rotary Position Embedding (RoPE) for relative position encoding. + # If False, nn.Embedding is used for absolute position encoding. + # For more details on RoPE, refer to the author's blog: https://spaces.ac.cn/archives/8265/ + # TODO: If you want to use rotary_emb in an environment, you need to include the timestep as a return key from the environment. + rotary_emb=False, + # (int) The base value for calculating RoPE angles. Commonly set to 10000. + rope_theta=10000, + # (int) The maximum sequence length for position encoding. + max_seq_len=8192, + ), + ), + # ****** common ****** + # (bool) whether to use rnd model. + use_rnd_model=False, + # (bool) Whether to use multi-gpu training. + multi_gpu=False, + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) + # this variable is used in ``collector``. + sampled_algo=False, + # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero) + gumbel_algo=False, + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. + mcts_ctree=True, + # (bool) Whether to use cuda for network. + cuda=True, + # (int) The number of environments used in collecting data. + collector_env_num=8, + # (int) The number of environments used in evaluating policy. + evaluator_env_num=3, + # (str) The type of environment. Options are ['not_board_games', 'board_games']. + env_type='not_board_games', + # (str) The type of action space. Options are ['fixed_action_space', 'varied_action_space']. + action_type='fixed_action_space', + # (str) The type of battle mode. Options are ['play_with_bot_mode', 'self_play_mode']. + battle_mode='play_with_bot_mode', + # (bool) Whether to monitor extra statistics in tensorboard. + monitor_extra_statistics=True, + # (int) The transition number of one ``GameSegment``. + game_segment_length=400, + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to use the pure policy to collect data. + collect_with_pure_policy=False, + # (int) The evaluation frequency. + eval_freq=int(2e3), + # (str) The sample type. Options are ['episode', 'transition']. + sample_type='transition', + # ****** observation ****** + # (bool) Whether to transform image to string to save memory. + transform2string=False, + # (bool) Whether to use gray scale image. + gray_scale=False, + # (bool) Whether to use data augmentation. + use_augmentation=False, + # (list) The style of augmentation. + augmentation=['shift', 'intensity'], + + # ******* learn ****** + # (bool) Whether to ignore the done flag in the training data. Typically, this value is set to False. + # However, for some environments with a fixed episode length, to ensure the accuracy of Q-value calculations, + # we should set it to True to avoid the influence of the done flag. + ignore_done=False, + # (int) How many updates(iterations) to train after collector's one collection. 
+ # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. + # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. + update_per_collect=None, + # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. + replay_ratio=0.25, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. + optim_type='AdamW', + # (float) Learning rate for training policy network. Initial lr for manually decay schedule. + learning_rate=0.0001, + # (int) Frequency of hard target network update. + target_update_freq=100, + # (int) Frequency of soft target network update. + target_update_theta=0.05, + # (int) Frequency of target network update. + target_update_freq_for_intrinsic_reward=1000, + # (float) Weight decay for training policy network. + weight_decay=1e-4, + # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). + momentum=0.9, + # (float) The maximum constraint value of gradient norm clipping. + grad_clip_value=20, + # (int) The number of episodes in each collecting stage when use muzero_collector. + n_episode=8, + # (int) The number of num_segments in each collecting stage when use muzero_segment_collector. + num_segments=8, + # # (int) the number of simulations in MCTS for renalyze. + num_simulations=50, + # (int) The number of simulations in MCTS for the collect phase. + collect_num_simulations=25, + # (int) The number of simulations in MCTS for the eval phase. + eval_num_simulations=50, + # (float) Discount factor (gamma) for returns. + discount_factor=0.997, + # (int) The number of steps for calculating target q_value. + td_steps=5, + # (int) The number of unroll steps in dynamics network. + num_unroll_steps=10, + # (float) The weight of reward loss. + reward_loss_weight=1, + # (float) The weight of value loss. + value_loss_weight=0.25, + # (float) The weight of policy loss. + policy_loss_weight=1, + # (float) The weight of ssl (self-supervised learning) loss. + ssl_loss_weight=0, + # (bool) Whether to use the cosine learning rate decay. + cos_lr_scheduler=False, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. lr: 0.2 -> 0.02 -> 0.002 + piecewise_decay_lr_scheduler=False, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e4), + # (bool) Whether to use manually decayed temperature. + manual_temperature_decay=False, + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(5e4), + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. + use_ture_chance_label_in_chance_encoder=False, + # (int) The number of steps to accumulate gradients before performing an optimization step. 
+ accumulation_steps=1, + + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=False, + # (float) The degree of prioritization to use. A value of 0 means no prioritization, + # while a value of 1 means full prioritization. + priority_prob_alpha=0.6, + # (float) The degree of correction to use. A value of 0 means no correction, + # while a value of 1 means full correction. + priority_prob_beta=0.4, + # (int) The initial Env Steps for training. + train_start_after_envsteps=int(0), + + # ****** UCB ****** + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + + # ****** Explore by random collect ****** + # (int) The number of episodes to collect data randomly before training. + random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + # (bool) Whether to use eps greedy exploration in collecting data. + eps_greedy_exploration_in_collect=False, + # (str) The type of decaying epsilon. Options are 'linear', 'exp'. + type='linear', + # (float) The start value of eps. + start=1., + # (float) The end value of eps. + end=0.05, + # (int) The decay steps from start to end eps. + decay=int(1e5), + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default model setting for demonstration. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. + - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. + - import_names (:obj:`List[str]`): The model class path list used in this algorithm. + .. note:: + The user can define and use customized network model but must obey the same interface definition indicated \ + by import_names path. For MuZero, ``lzero.model.unizero_model.MuZeroModel`` + """ + return 'UniZeroModel', ['lzero.model.unizero_model'] + + def _init_learn(self) -> None: + """ + Overview: + Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. 
+        """
+        if self._cfg.optim_type == 'SGD':
+            # --- Use the SGD optimizer ---
+            self._optimizer_world_model = torch.optim.SGD(
+                self._model.world_model.parameters(),
+                lr=self._cfg.learning_rate,  # initial learning rate, set to 0.2 in the config
+                momentum=self._cfg.momentum,  # set to 0.9 in the config
+                weight_decay=self._cfg.weight_decay  # set to 1e-4 in the config
+            )
+        elif self._cfg.optim_type == 'AdamW':
+            # NOTE: nanoGPT optimizer
+            self._optimizer_world_model = configure_optimizers_nanogpt(
+                model=self._model.world_model,
+                learning_rate=self._cfg.learning_rate,
+                weight_decay=self._cfg.weight_decay,
+                device_type=self._cfg.device,
+                betas=(0.9, 0.95),
+            )
+        elif self._cfg.optim_type == 'AdamW_mix_lr':
+            self._optimizer_world_model = configure_optimizer_unizero(
+                model=self._model.world_model,
+                learning_rate=self._cfg.learning_rate,  # use a reasonable base learning rate for AdamW
+                weight_decay=self._cfg.weight_decay,
+                device_type=self._cfg.device,
+                betas=(0.9, 0.95),
+            )
+
+        if self._cfg.cos_lr_scheduler:
+            from torch.optim.lr_scheduler import CosineAnnealingLR
+            # TODO: check the total training steps
+            self.lr_scheduler = CosineAnnealingLR(self._optimizer_world_model, 1e5, eta_min=0, last_epoch=-1)
+
+        if self._cfg.piecewise_decay_lr_scheduler:
+            from torch.optim.lr_scheduler import LambdaLR
+            max_step = self._cfg.threshold_training_steps_for_final_lr
+            # NOTE: the 1, 0.1, 0.01 are the decay rates, not the lr values.
+            lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01)  # noqa
+            self.lr_scheduler = LambdaLR(self._optimizer_world_model, lr_lambda=lr_lambda)
+
+        # use model_wrapper for specialized demands of different modes
+        self._target_model = copy.deepcopy(self._model)
+        # Ensure that the installed torch version is greater than or equal to 2.0
+        assert int(''.join(filter(str.isdigit, torch.__version__))) >= 200, "We need torch version >= 2.0"
+        self._model = torch.compile(self._model)
+        self._target_model = torch.compile(self._target_model)
+        if self._cfg.target_model_update_option == "soft":
+            # NOTE: soft target
+            self._target_model = model_wrap(
+                self._target_model,
+                wrapper_name='target',
+                update_type='momentum',
+                update_kwargs={'theta': self._cfg.target_update_theta}
+            )
+        elif self._cfg.target_model_update_option == "hard":
+            self._target_model = model_wrap(
+                self._target_model,
+                wrapper_name='target',
+                update_type='assign',
+                update_kwargs={'freq': self._cfg.target_update_freq}
+            )
+
+        self._learn_model = self._model
+
+        if self._cfg.use_augmentation:
+            self.image_transforms = ImageTransforms(
+                self._cfg.augmentation,
+                image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2])
+            )
+        self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device)
+        self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device)
+        # assert self.value_support.size == self._learn_model.value_support_size    # if these assertions fail, somebody introduced...
+        # assert self.reward_support.size == self._learn_model.reward_support_size  # ...incoherence between policy and model
+        self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution)
+        self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution)
+
+        self.intermediate_losses = defaultdict(float)
+        self.l2_norm_before = 0.
+        self.l2_norm_after = 0.
+        self.grad_norm_before = 0.
+        self.grad_norm_after = 0.
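# A minimal, self-contained sketch of the piecewise LR decay configured above,
# assuming the default learning_rate=0.0001 and threshold_training_steps_for_final_lr=int(5e4)
# from this patch; it only reproduces the lambda to show the resulting schedule.
base_lr = 1e-4
max_step = int(5e4)
lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01)
for step in (0, 24_999, 25_000, 49_999, 50_000, 100_000):
    # Expected: 1e-4 for the first half, 1e-5 until max_step, then 1e-6.
    print(f"step={step:>7d}  lr={base_lr * lr_lambda(step):.6g}")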
+
+        if self._cfg.use_wandb:
+            # TODO: add the model to wandb
+            wandb.watch(self._learn_model.representation_network, log="all")
+
+        self.accumulation_steps = self._cfg.accumulation_steps
+        # Get the threshold from the config; if unset, use a reasonable default (e.g. 20.0).
+        # Setting it to 0 or a negative value disables this feature.
+        self.latent_norm_clip_threshold = self._cfg.get('latent_norm_clip_threshold', 20.0)  # TODO
+
+        # Get the threshold from the config, e.g. 15.0 or 20.0.
+        self.logit_clip_threshold = self._cfg.get('logit_clip_threshold', 10.0)
+        # 1. Keep a reference to the world_model for convenience.
+        world_model = self._learn_model.world_model
+        # 2. Explicitly split the parameters into two groups: prediction heads and backbone.
+        #    - a. Collect the parameters of all prediction heads.
+        self.head_params = list(world_model.head_value.parameters()) + \
+                           list(world_model.head_rewards.parameters()) + \
+                           list(world_model.head_policy.parameters())
+        # If there are other heads, add them here as well.
+        #    - b. Use parameter ids for an efficient split.
+        self.head_param_ids = {id(p) for p in self.head_params}
+        #    - c. The backbone parameters are all parameters not in head_param_ids.
+        self.backbone_params = [p for p in world_model.parameters() if id(p) not in self.head_param_ids]
+
+    # @profile
+    def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
+        """
+        Overview:
+            The forward function for learning policy in learn mode, which is the core of the learning process.
+            The data is sampled from the replay buffer.
+            The loss is calculated by the loss function and backpropagated to update the model.
+        Arguments:
+            - data (:obj:`Tuple[torch.Tensor]`): The data sampled from the replay buffer, which is a tuple of tensors.
+                The first tensor is the current_batch, the second tensor is the target_batch.
+        Returns:
+            - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \
+                current learning loss and learning statistics.
+        """
+        self._learn_model.train()
+        self._target_model.train()
+
+        current_batch, target_batch, train_iter = data
+        obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch
+        target_reward, target_value, target_policy = target_batch
+
+        # Prepare observations based on frame stack number
+        if self._cfg.model.frame_stack_num > 1:
+            obs_batch, obs_target_batch = prepare_obs_stack_for_unizero(obs_batch_ori, self._cfg)
+        else:
+            obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg)  # TODO: optimize
+
+        # Apply augmentations if needed
+        if self._cfg.use_augmentation:
+            obs_batch = self.image_transforms.transform(obs_batch)
+            if self._cfg.model.self_supervised_learning_loss:
+                obs_target_batch = self.image_transforms.transform(obs_target_batch)
+
+        # Prepare action batch and convert to torch tensor
+        action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(
+            -1).long()  # For discrete action space
+        timestep_batch = torch.from_numpy(timestep_batch).to(self._cfg.device).unsqueeze(
+            -1).long()
+        data_list = [mask_batch, target_reward, target_value, target_policy, weights]
+        mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor(data_list,
+                                                                                                self._cfg.device)
+        target_reward = target_reward.view(self._cfg.batch_size, -1)
+        target_value = target_value.view(self._cfg.batch_size, -1)
+
+        # Transform rewards and values to their scaled forms
+        transformed_target_reward = scalar_transform(target_reward)
+        transformed_target_value = scalar_transform(target_value)
+
+        # Convert to categorical distributions
+        target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward)
+        target_value_categorical = phi_transform(self.value_support, transformed_target_value)
+
+        # TODO
+        # ==================== Core fix: label smoothing ====================
+        # alpha is the smoothing coefficient, a small hyperparameter, e.g. 0.01 or 0.1.
+        # alpha = 0.1
+        # num_classes = target_value_categorical.shape[-1]
+        # # (1 - alpha) * original_target + alpha / num_classes
+        # target_value_categorical = target_value_categorical * (1 - alpha) + (alpha / num_classes)
+        # target_reward_categorical = target_reward_categorical * (1 - alpha) + (alpha / num_classes)
+        # ===================================================================
+
+        # Prepare batch for GPT model
+        batch_for_gpt = {}
+        if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape) == 1:
+            batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape(
+                self._cfg.batch_size, -1, self._cfg.model.observation_shape)
+        elif len(self._cfg.model.observation_shape) == 3:
+            batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape(
+                self._cfg.batch_size, -1, *self._cfg.model.observation_shape)
+
+        batch_for_gpt['actions'] = action_batch.squeeze(-1)
+        batch_for_gpt['timestep'] = timestep_batch.squeeze(-1)
+
+        batch_for_gpt['rewards'] = target_reward_categorical[:, :-1]
+        batch_for_gpt['mask_padding'] = mask_batch == 1.0  # 0 means invalid padding data
+        batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1]
+        batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1]
+        batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long,
+                                            device=self._cfg.device)
+        batch_for_gpt['target_value'] = target_value_categorical[:, :-1]
+        batch_for_gpt['target_policy'] = target_policy[:, :-1]
+
+        # ==================== START MODIFICATION 1 ====================
+        # Pass the original scalar target_value to compute_loss for priority calculation.
+        batch_for_gpt['scalar_target_value'] = target_value
+        # ===================== END MODIFICATION 1 =====================
+
+        # Extract valid target policy data and compute entropy
+        valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']]
+        target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1)
+        average_target_policy_entropy = target_policy_entropy.mean()
+
+        # Update world model
+        losses = self._learn_model.world_model.compute_loss(
+            batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, global_step=train_iter
+        )  # NOTE: the third argument of compute_loss is now a dead argument. If this changes, it could need adaptation between value_inverse and reward_inverse.
+
+        # ==================== START MODIFICATION 2 ====================
+        # Extract the calculated value_priority from the returned losses.
+        value_priority_tensor = losses.intermediate_losses['value_priority']
+        # Convert to numpy array for the replay buffer, adding a small epsilon.
+        value_priority_np = value_priority_tensor.detach().cpu().numpy() + 1e-6
+        # ===================== END MODIFICATION 2 =====================
+
+        # weighted_total_loss = losses.loss_total
+        # TODO:
+        weighted_total_loss = (weights * losses.loss_total).mean()
+
+        for loss_name, loss_value in losses.intermediate_losses.items():
+            self.intermediate_losses[f"{loss_name}"] = loss_value
+
+        obs_loss = self.intermediate_losses['loss_obs']
+        reward_loss = self.intermediate_losses['loss_rewards']
+        policy_loss = self.intermediate_losses['loss_policy']
+        value_loss = self.intermediate_losses['loss_value']
+        latent_recon_loss = self.intermediate_losses['latent_recon_loss']
+        perceptual_loss = self.intermediate_losses['perceptual_loss']
+        orig_policy_loss = self.intermediate_losses['orig_policy_loss']
+        policy_entropy = self.intermediate_losses['policy_entropy']
+        first_step_losses = self.intermediate_losses['first_step_losses']
+        middle_step_losses = self.intermediate_losses['middle_step_losses']
+        last_step_losses = self.intermediate_losses['last_step_losses']
+        dormant_ratio_encoder = self.intermediate_losses['dormant_ratio_encoder']
+        dormant_ratio_world_model = self.intermediate_losses['dormant_ratio_world_model']
+        latent_state_l2_norms = self.intermediate_losses['latent_state_l2_norms']
+        latent_action_l2_norms = self.intermediate_losses['latent_action_l2_norms']
+
+        logits_value_mean = self.intermediate_losses['logits_value_mean']
+        logits_value_max = self.intermediate_losses['logits_value_max']
+        logits_value_min = self.intermediate_losses['logits_value_min']
+
+        logits_policy_mean = self.intermediate_losses['logits_policy_mean']
+        logits_policy_max = self.intermediate_losses['logits_policy_max']
+        logits_policy_min = self.intermediate_losses['logits_policy_min']
+
+        assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values"
+        assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values"
+
+        # Core learning model update step
+        # Reset gradients at the start of each accumulation cycle
+        if (train_iter % self.accumulation_steps) == 0:
+            self._optimizer_world_model.zero_grad()
+
+        # Scale the loss by the number of accumulation steps
+        weighted_total_loss = weighted_total_loss / self.accumulation_steps
+
+        if self._cfg.gradient_scale:
+            # ==============================================================
+            # START OF THE FIX: Add gradient scaling just like in MuZero
+            # ==============================================================
+            # This is the key to stabilizing the latent norm. It averages the gradients
+            # accumulated over the unroll steps, preventing the exploding gradient problem
+            # in the recurrent world model (Transformer).
+            gradient_scale = 1.0 / self._cfg.num_unroll_steps
+            weighted_total_loss.register_hook(lambda grad: grad * gradient_scale)
+            # ==============================================================
+            # END OF THE FIX
+            # ==============================================================
+
+        weighted_total_loss.backward()
+
+        # # ======================= Learning rate sanity check =======================
+        # if (train_iter % 1000) == 0:
+        #     print("\n--- Optimizer Learning Rate Analysis ---")
+        #     # self._optimizer_world_model is the only optimizer.
+        #     for i, param_group in enumerate(self._optimizer_world_model.param_groups):
+        #         # configure_optimizers_nanogpt may create multiple param groups (e.g., one with weight decay and one without).
+        #         print(f"  Param Group {i}: LR = {param_group['lr']:.6f}")
+        # # ===========================================================================
+
+        # ======================= Gradient inspection =======================
+        # We only inspect the gradients of the encoder here.
+        encoder = self._learn_model.world_model.tokenizer.encoder
+        total_grad_norm = 0.0
+
+        # if (train_iter % 5000) == 0:
+        if (train_iter % 10000) == 0:  # 10k
+            # if (train_iter % 1) == 0:
+            print(f"\n--- Gradient Analysis for Step {train_iter} ---")
+            for name, param in encoder.named_parameters():
+                if param.grad is not None:
+                    grad_norm = param.grad.norm().item()
+                    total_grad_norm += grad_norm ** 2
+                    # Print the gradient norm of each layer to locate problematic layers.
+                    print(f"  Layer: {name} | Grad Norm: {grad_norm:.6f}")
+                else:
+                    print(f"  Layer: {name} | Grad is None")
+            total_grad_norm = total_grad_norm ** 0.5
+            print(f"--- Total Grad Norm for Encoder: {total_grad_norm:.6f} ---\n")
+        # ====================================================================
+
+        # Check if the current iteration completes an accumulation cycle
+        if (train_iter + 1) % self.accumulation_steps == 0:
+            # Analyze gradient norms if simulation normalization analysis is enabled
+            if self._cfg.analysis_sim_norm:
+                # Clear previous analysis results to prevent memory overflow
+                del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after
+                self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze()
+                self._target_model.encoder_hook.clear_data()
+
+            # Clip gradients to prevent exploding gradients
+            total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(
+                self._learn_model.world_model.parameters(), self._cfg.grad_clip_value
+            )
+            # total_grad_norm_before_clip_wm = torch.tensor(0.)
+
+            # # 3. Clip the gradients of the two parameter groups separately.
+            # #    - a. Apply a stricter (smaller) clipping threshold to the prediction heads.
+            # #         Requires a new `head_grad_clip_value` entry in the config, e.g. 1.0 or 0.5.
+            # head_grad_norm = torch.nn.utils.clip_grad_norm_(
+            #     self.head_params, self._cfg.get('head_grad_clip_value', 1.0)  # example: strict threshold
+            # )
+            # #    - b. Apply a relatively loose clipping threshold to the backbone.
+            # #         Requires a new `backbone_grad_clip_value` entry in the config, e.g. 10.0.
+            # backbone_grad_norm = torch.nn.utils.clip_grad_norm_(
+            #     self.backbone_params, self._cfg.get('backbone_grad_clip_value', 10.0)  # example: standard threshold
+            # )
+            head_grad_norm = torch.tensor(0.)
+            backbone_grad_norm = torch.tensor(0.)
+
+            # Synchronize gradients across multiple GPUs if enabled
+            if self._cfg.multi_gpu:
+                self.sync_gradients(self._learn_model)
+
+            # Update model parameters
+            self._optimizer_world_model.step()
+
+            # Clear CUDA cache if using gradient accumulation
+            if self.accumulation_steps > 1:
+                torch.cuda.empty_cache()
+        else:
+            total_grad_norm_before_clip_wm = torch.tensor(0.)
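# A minimal sketch of the gradient-accumulation pattern used above, with a toy model
# and optimizer (the names below are illustrative, not taken from this patch):
# gradients are reset at the start of a cycle, each per-iteration loss is divided by
# accumulation_steps, and clipping plus the optimizer step happen once per cycle.
import torch

accumulation_steps = 4
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for train_iter in range(12):
    x, y = torch.randn(16, 8), torch.randn(16, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    if train_iter % accumulation_steps == 0:
        optimizer.zero_grad()                  # reset at the start of each cycle
    (loss / accumulation_steps).backward()     # scale so the summed gradients average out
    if (train_iter + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 20.0)
        optimizer.step()                       # one optimizer step per accumulation cycle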
+
+        # =================================================================
+        # Encoder-Clip: Inspired by QK-Clip
+        # -----------------------------------------------------------------
+        # Directly bound the norm of the encoder output to prevent unbounded
+        # growth and stabilize training.
+        # =================================================================
+        if self.latent_norm_clip_threshold > 0 and 'obs_embeddings' in losses.intermediate_losses:
+            with torch.no_grad():
+                # 1. Get the detached encoder output from the loss dict.
+                obs_embeddings = losses.intermediate_losses['obs_embeddings']
+                if obs_embeddings is None:
+                    raise ValueError
+
+                # 2. Compute the maximum L2 norm of the encoder outputs in this batch.
+                #    obs_embeddings usually has shape (B*L, 1, E) or (B*L, E);
+                #    the norm is taken over the last (embedding) dimension.
+                latent_norms = obs_embeddings.norm(p=2, dim=-1)
+                max_latent_norm = latent_norms.max()
+
+                # 3. Check whether the maximum norm exceeds the configured threshold.
+                if max_latent_norm > self.latent_norm_clip_threshold:
+                    # 4. Compute the scaling factor.
+                    scale_factor = self.latent_norm_clip_threshold / max_latent_norm.item()
+
+                    # (Optional) Log for debugging.
+                    print(f"[Encoder-Clip] Max latent norm {max_latent_norm.item():.2f} > {self.latent_norm_clip_threshold}. Scaling encoder weights by {scale_factor:.4f}.")
+
+                    # 5. Apply the scaling factor to all encoder weights.
+                    encoder = self._model.world_model.tokenizer.encoder
+                    for param in encoder.parameters():
+                        if param.requires_grad:
+                            param.data.mul_(scale_factor)
+
+        # =================================================================
+        # Head-Clip: directly rescale the prediction-head weights.
+        # -----------------------------------------------------------------
+        # If the absolute value of the value or reward logits grows too large,
+        # rescale the corresponding head weights proportionally.
+        # =================================================================
+        if self.logit_clip_threshold > 0:
+            with torch.no_grad():
+                # Get the raw logits from the model output (the world model's forward or
+                # compute_loss must return them); they are assumed to be stored in
+                # losses.intermediate_losses.
+                logits_value = losses.intermediate_losses.get('logits_value')
+                logits_reward = losses.intermediate_losses.get('logits_reward')
+
+                if logits_value is not None and logits_reward is not None:
+                    # Compute the maximum absolute value over the value and reward logits.
+                    max_abs_logit = max(logits_value.abs().max(), logits_reward.abs().max())
+
+                    # Check whether it exceeds the threshold.
+                    if max_abs_logit > self.logit_clip_threshold:
+                        # Compute the scaling factor.
+                        scale_factor = self.logit_clip_threshold / max_abs_logit.item()
+
+                        print(f"[Value-Reward-Head-Clip] Max abs logit {max_abs_logit.item():.2f} > {self.logit_clip_threshold}. Scaling head weights by {scale_factor:.4f}.")
+
+                        # Get the prediction heads to be clipped.
+                        head_value_module = self._model.world_model.head_value
+                        head_reward_module = self._model.world_model.head_rewards
+
+                        # Apply the scaling factor to all weights of both heads.
+                        for head_module in [head_value_module, head_reward_module]:
+                            for param in head_module.parameters():
+                                if param.requires_grad:
+                                    param.data.mul_(scale_factor)
+        # =================================================================
+
+        # =================================================================
+        # [New] Policy-Head-Clip: rescale the policy head weights to preserve exploration.
+        # -----------------------------------------------------------------
+        # The goal is to prevent the policy from becoming overly deterministic too early,
+        # which would kill exploration. It does so by bounding the maximum positive value
+        # of the policy logits.
+        # =================================================================
+        # Get the policy clipping threshold from the config; default to 0.1 if unset.
+        policy_logit_clip_threshold = self._cfg.get('policy_logit_clip_threshold', 0.1)
+
+        if policy_logit_clip_threshold > 0:
+            with torch.no_grad():
+                # 1. Get the raw policy logits from the model output;
+                #    the world model's compute_loss must return 'logits_policy'.
+                logits_policy = losses.intermediate_losses.get('logits_policy')
+
+                if logits_policy is not None:
+                    # 2. Compute the maximum policy logit (only the positive maximum matters
+                    #    here, not how negative the minimum is).
+                    max_policy_logit = logits_policy.max()
+
+                    # 3. Check whether it exceeds the threshold set for exploration.
+                    if max_policy_logit > policy_logit_clip_threshold:
+                        # 4. Compute the scaling factor.
+                        scale_factor = policy_logit_clip_threshold / max_policy_logit.item()
+
+                        # Log for debugging.
+                        print(f"[Policy-Head-Clip] Max policy logit {max_policy_logit.item():.4f} > {policy_logit_clip_threshold}. Scaling policy head weights by {scale_factor:.4f}.")
+
+                        # 5. Get the policy head module.
+                        head_policy_module = self._model.world_model.head_policy
+
+                        # 6. Apply the scaling factor to all policy head weights.
+                        for param in head_policy_module.parameters():
+                            if param.requires_grad:
+                                param.data.mul_(scale_factor)
+        # =================================================================
+
+        # Update learning rate scheduler if applicable
+        if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler:
+            self.lr_scheduler.step()
+
+        # Update the target model with the current model's parameters
+        self._target_model.update(self._learn_model.state_dict())
+
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+            current_memory_allocated = torch.cuda.memory_allocated()
+            max_memory_allocated = torch.cuda.max_memory_allocated()
+            current_memory_allocated_gb = current_memory_allocated / (1024 ** 3)
+            max_memory_allocated_gb = max_memory_allocated / (1024 ** 3)
+        else:
+            current_memory_allocated_gb = 0.
+            max_memory_allocated_gb = 0.
+
+        return_log_dict = {
+            'analysis/first_step_loss_value': first_step_losses['loss_value'].item(),
+            'analysis/first_step_loss_policy': first_step_losses['loss_policy'].item(),
+            'analysis/first_step_loss_rewards': first_step_losses['loss_rewards'].item(),
+            'analysis/first_step_loss_obs': first_step_losses['loss_obs'].item(),
+
+            'analysis/middle_step_loss_value': middle_step_losses['loss_value'].item(),
+            'analysis/middle_step_loss_policy': middle_step_losses['loss_policy'].item(),
+            'analysis/middle_step_loss_rewards': middle_step_losses['loss_rewards'].item(),
+            'analysis/middle_step_loss_obs': middle_step_losses['loss_obs'].item(),
+
+            'analysis/last_step_loss_value': last_step_losses['loss_value'].item(),
+            'analysis/last_step_loss_policy': last_step_losses['loss_policy'].item(),
+            'analysis/last_step_loss_rewards': last_step_losses['loss_rewards'].item(),
+            'analysis/last_step_loss_obs': last_step_losses['loss_obs'].item(),
+
+            'Current_GPU': current_memory_allocated_gb,
+            'Max_GPU': max_memory_allocated_gb,
+            'collect_mcts_temperature': self._collect_mcts_temperature,
+            'collect_epsilon': self._collect_epsilon,
+            'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'],
+            'weighted_total_loss': weighted_total_loss.item(),
+            'obs_loss': obs_loss.item(),
+            'latent_recon_loss': latent_recon_loss.item(),
+            'perceptual_loss': perceptual_loss.item(),
+            'policy_loss': policy_loss.item(),
+            'orig_policy_loss': orig_policy_loss.item(),
+            'policy_entropy': policy_entropy.item(),
+            'target_policy_entropy': average_target_policy_entropy.item(),
+            'reward_loss': reward_loss.item(),
+            'value_loss': value_loss.item(),
+            # ==================== START MODIFICATION 3 ====================
+            # Add value_priority to the log dictionary.
+            'value_priority': value_priority_np.mean().item(),
+            'value_priority_orig': value_priority_np,
+            # ===================== END MODIFICATION 3 =====================
+            'target_reward': target_reward.mean().item(),
+            'target_value': target_value.mean().item(),
+            'transformed_target_reward': transformed_target_reward.mean().item(),
+            'transformed_target_value': transformed_target_value.mean().item(),
+            'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(),
+            "head_grad_norm": head_grad_norm,
+            "backbone_grad_norm": backbone_grad_norm,
+
+            'analysis/dormant_ratio_encoder': dormant_ratio_encoder.item(),
+            'analysis/dormant_ratio_world_model': dormant_ratio_world_model.item(),
+            'analysis/latent_state_l2_norms': latent_state_l2_norms.item(),
+            'analysis/latent_action_l2_norms': latent_action_l2_norms.item(),
+            'analysis/l2_norm_before': self.l2_norm_before,
+            'analysis/l2_norm_after': self.l2_norm_after,
+            'analysis/grad_norm_before': self.grad_norm_before,
+            'analysis/grad_norm_after': self.grad_norm_after,
+
+            "logits_value_mean": logits_value_mean,
+            "logits_value_max": logits_value_max,
+            "logits_value_min": logits_value_min,
+            "logits_policy_mean": logits_policy_mean,
+            "logits_policy_max": logits_policy_max,
+            "logits_policy_min": logits_policy_min,
+        }
+
+        if self._cfg.use_wandb:
+            wandb.log({'learner_step/' + k: v for k, v in return_log_dict.items()}, step=self.env_step)
+            wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step)
+
+        return return_log_dict
+
+    def monitor_weights_and_grads(self, model):
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                print(f"Layer: {name} | "
+                      f"Weight mean: {param.data.mean():.4f} | "
+                      f"Weight std: {param.data.std():.4f} | "
+                      f"Grad mean: {param.grad.mean():.4f} | "
+                      f"Grad std: {param.grad.std():.4f}")
+
+    def _init_collect(self) -> None:
+        """
+        Overview:
+            Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils.
+        """
+        self._collect_model = self._model
+        # Create a config copy for collect-mode MCTS and set its specific number of simulations.
+        mcts_collect_cfg = copy.deepcopy(self._cfg)
+        mcts_collect_cfg.num_simulations = self._cfg.collect_num_simulations
+        if self._cfg.mcts_ctree:
+            self._mcts_collect = MCTSCtree(mcts_collect_cfg)
+        else:
+            self._mcts_collect = MCTSPtree(mcts_collect_cfg)
+
+        self._collect_mcts_temperature = 1.
+        self._collect_epsilon = 0.0
+        self.collector_env_num = self._cfg.collector_env_num
+        if self._cfg.model.model_type == 'conv':
+            self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
+            self.last_batch_action = [-1 for i in range(self.collector_env_num)]
+        elif self._cfg.model.model_type == 'mlp':
+            self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape]).to(self._cfg.device)
+            self.last_batch_action = [-1 for i in range(self.collector_env_num)]
+
+    # @profile
+    def _forward_collect(
+            self,
+            data: torch.Tensor,
+            action_mask: List = None,
+            temperature: float = 1,
+            to_play: List = [-1],
+            epsilon: float = 0.25,
+            ready_env_id: np.ndarray = None,
+            timestep: List = [0]
+    ) -> Dict:
+        """
+        Overview:
+            The forward function for collecting data in collect mode. Use the model to execute MCTS search.
+            Actions are chosen by sampling during the collect mode.
+        Arguments:
+            - data (:obj:`torch.Tensor`): The input data, i.e. the observation.
+            - action_mask (:obj:`list`): The action mask, i.e. the actions that cannot be selected.
+ - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + - timestep (:obj:`list`): The step index of the env in one episode. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + - timestep: :math:`(N, 1)`, where N is the number of collect_env. + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._collect_model.eval() + + self._collect_mcts_temperature = temperature + self._collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + + with torch.no_grad(): + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + # the only difference between collect and eval is the dirichlet noise + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + + next_latent_state_with_env = self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, timestep) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + + if self._cfg.eps.eps_greedy_exploration_in_collect: + # eps greedy collect + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self._collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # normal collect + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the 
legal action set, rather than the index in the entire action set.
+                    action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
+                        distributions, temperature=self._collect_mcts_temperature, deterministic=False
+                    )
+                    # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set.
+                    action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+
+                next_latent_state = next_latent_state_with_env[i][action]
+
+                if self._cfg.model.world_model_cfg.obs_type == 'text':
+                    # Output the plain text content decoded by the decoder from the next latent state
+                    predicted_next = self._collect_model.tokenizer.decode_to_plain_text_for_decoder(embeddings=next_latent_state, max_length=256)
+                else:
+                    predicted_next = None
+
+                # ============== TODO: only for visualize ==============
+                # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
+                #     distributions, temperature=self._collect_mcts_temperature, deterministic=True
+                # )
+                # action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+                # ============== TODO: only for visualize ==============
+
+                output[env_id] = {
+                    'action': action,
+                    'visit_count_distributions': distributions,
+                    'visit_count_distribution_entropy': visit_count_distribution_entropy,
+                    'searched_value': value,
+                    'predicted_value': pred_values[i],
+                    'predicted_policy_logits': policy_logits[i],
+                    'timestep': timestep[i],
+                    'predicted_next_text': predicted_next,
+                }
+                batch_action.append(action)
+
+            self.last_batch_obs = data
+            self.last_batch_action = batch_action
+
+            # ========= TODO: for muzero_segment_collector now =========
+            if active_collect_env_num < self.collector_env_num:  # An env finishes first; this condition is then met in the next forward call.
+                print('==========collect_forward============')
+                print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}')
+                self._reset_collect(reset_init_data=True)  # TODO(pu): is it reasonable to reset all environments here?
+                if getattr(self._cfg, 'sample_type', '') == 'episode':
+                    print('BUG: sample_type is episode, but len(self.last_batch_obs) < self.collector_env_num')
+
+        return output
+
+    def _init_eval(self) -> None:
+        """
+        Overview:
+            Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils.
+        """
+        self._eval_model = self._model
+
+        # Create a config copy for eval-mode MCTS and set its specific number of simulations.
+        mcts_eval_cfg = copy.deepcopy(self._cfg)
+        mcts_eval_cfg.num_simulations = self._cfg.eval_num_simulations
+
+        if self._cfg.mcts_ctree:
+            self._mcts_eval = MCTSCtree(mcts_eval_cfg)
+        else:
+            self._mcts_eval = MCTSPtree(mcts_eval_cfg)
+        self.evaluator_env_num = self._cfg.evaluator_env_num
+
+        if self._cfg.model.model_type == 'conv':
+            self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
+            self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)]
+        elif self._cfg.model.model_type == 'mlp':
+            self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape]).to(self._cfg.device)
+            self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)]
+
+    def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [-1],
+                      ready_env_id: np.array = None, timestep: List = [0]) -> Dict:
+        """
+        Overview:
+            The forward function for evaluating the current policy in eval mode. Use the model to execute MCTS search.
+            The action with the highest value (argmax) is chosen rather than sampled during the eval mode.
+ Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to eval. + - timestep (:obj:`list`): The step index of the env in one episode. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of eval_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of eval_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of eval_env. + - to_play: :math:`(N, 1)`, where N is the number of eval_env. + - ready_env_id: None + - timestep: :math:`(N, 1)`, where N is the number of eval_env. + + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._eval_model.eval() + active_eval_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + # if not in training, obtain the scalars of the value/reward + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_eval_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + next_latent_state_with_env = self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, timestep) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # print("roots_visit_count_distributions:", distributions, "root_value:", value) + + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than + # sampling during the evaluation phase. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the + # entire action set. 
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + # Predict the next latent state based on the selected action and policy + next_latent_state = next_latent_state_with_env[i][action] + + if self._cfg.model.world_model_cfg.obs_type == 'text': + # Output the plain text content decoded by the decoder from the next latent state + predicted_next = self._eval_model.tokenizer.decode_to_plain_text_for_decoder(embeddings=next_latent_state, max_length=256) + else: + predicted_next = None + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + 'timestep': timestep[i], + 'predicted_next_text': predicted_next, + } + batch_action.append(action) + + self.last_batch_obs = data + self.last_batch_action = batch_action + + return output + + def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None: + """ + Overview: + This method resets the collection process for a specific environment. It clears caches and memory + when certain conditions are met, ensuring optimal performance. If reset_init_data is True, the initial data + will be reset. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. If None or list, the function returns immediately. + - current_steps (:obj:`int`, optional): The current step count in the environment. Used to determine + whether to clear caches. + - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. + """ + if reset_init_data: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.collector_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] + + # --- BEGIN ROBUST FIX --- + # This logic handles the crucial end-of-episode cache clearing. + # The collector calls `_policy.reset([env_id])` when an episode is done, + # which results in `current_steps` being None and `env_id` being a list. + + # We must handle both single int and list of ints for env_id. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id + + # The key condition: `current_steps` is None only on the end-of-episode reset call from the collector. + if current_steps is None: + world_model = self._collect_model.world_model + for eid in env_ids_to_reset: + # Clear the specific environment's initial inference cache. + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + + print(f'>>> [Collector] Cleared KV cache for env_id: {eid} at episode end.') + + # TODO + # The recurrent cache is global, which is problematic. + # A full clear is heavy-handed but safer than leaving stale entries. 
+ # world_model.past_kv_cache_recurrent_infer.clear() + # if hasattr(world_model, 'keys_values_wm_list'): + # world_model.keys_values_wm_list.clear() + # torch.cuda.empty_cache() + # --- END ROBUST FIX --- + + + # # Return immediately if env_id is None or a list + # if env_id is None or isinstance(env_id, list): + # return + + # Determine the clear interval based on the environment's sample type + # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + # TODO:========== + # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 40 + + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length + + + # Clear caches if the current steps are a multiple of the clear interval + if current_steps is not None and current_steps % clear_interval == 0: + print(f'clear_interval: {clear_interval}') + + # Clear various caches in the collect model's world model + world_model = self._collect_model.world_model + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up GPU memory + torch.cuda.empty_cache() + + print(f'eps_steps_lst[{env_id}]: {current_steps}, collector: collect_model clear()') + + + def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None: + """ + Overview: + This method resets the evaluation process for a specific environment. It clears caches and memory + when certain conditions are met, ensuring optimal performance. If reset_init_data is True, + the initial data will be reset. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. If None or list, the function returns immediately. + - current_steps (:obj:`int`, optional): The current step count in the environment. Used to determine + whether to clear caches. + - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. + """ + if reset_init_data: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.evaluator_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] + + # --- BEGIN ROBUST FIX --- + # This logic handles the crucial end-of-episode cache clearing for evaluation. + # The evaluator calls `_policy.reset([env_id])` when an episode is done. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id + + # The key condition: `current_steps` is None only on the end-of-episode reset call from the evaluator. + if current_steps is None: + world_model = self._eval_model.world_model + for eid in env_ids_to_reset: + # Clear the specific environment's initial inference cache. + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + + print(f'>>> [Evaluator] Cleared KV cache for env_id: {eid} at episode end.') + + # The recurrent cache is global. 
+ world_model.past_kv_cache_recurrent_infer.clear() + + if hasattr(world_model, 'keys_values_wm_list'): + world_model.keys_values_wm_list.clear() + + torch.cuda.empty_cache() + return + # --- END ROBUST FIX --- + + # Return immediately if env_id is None or a list + # if env_id is None or isinstance(env_id, list): + # return + + # Determine the clear interval based on the environment's sample type + # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + # TODO:========== + # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 40 + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length + + # # Clear caches if the current steps are a multiple of the clear interval + if current_steps is not None and current_steps % clear_interval == 0: + print(f'clear_interval: {clear_interval}') + + # Clear various caches in the eval model's world model + world_model = self._eval_model.world_model + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up GPU memory + torch.cuda.empty_cache() + + print('evaluator: eval_model clear()') + print(f'eps_steps_lst[{env_id}]: {current_steps}') + + def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Register the variables to be monitored in learn mode. The registered variables will be logged in + tensorboard according to the return value ``_forward_learn``. + """ + return [ + 'analysis/dormant_ratio_encoder', + 'analysis/dormant_ratio_world_model', + 'analysis/latent_state_l2_norms', + 'analysis/latent_action_l2_norms', + + 'analysis/l2_norm_before', + 'analysis/l2_norm_after', + 'analysis/grad_norm_before', + 'analysis/grad_norm_after', + + 'analysis/first_step_loss_value', + 'analysis/first_step_loss_policy', + 'analysis/first_step_loss_rewards', + 'analysis/first_step_loss_obs', + + 'analysis/middle_step_loss_value', + 'analysis/middle_step_loss_policy', + 'analysis/middle_step_loss_rewards', + 'analysis/middle_step_loss_obs', + + 'analysis/last_step_loss_value', + 'analysis/last_step_loss_policy', + 'analysis/last_step_loss_rewards', + 'analysis/last_step_loss_obs', + + 'Current_GPU', + 'Max_GPU', + 'collect_epsilon', + 'collect_mcts_temperature', + 'cur_lr_world_model', + 'cur_lr_tokenizer', + + 'weighted_total_loss', + 'obs_loss', + 'policy_loss', + 'orig_policy_loss', + 'policy_entropy', + 'latent_recon_loss', + 'target_policy_entropy', + 'reward_loss', + 'value_loss', + 'consistency_loss', + # ==================== START MODIFICATION 4 ==================== + 'value_priority', + # ===================== END MODIFICATION 4 ===================== + 'target_reward', + 'target_value', + 'total_grad_norm_before_clip_wm', + "head_grad_norm", + "backbone_grad_norm", + # tokenizer + 'commitment_loss', + 'reconstruction_loss', + 'perceptual_loss', + + + "logits_value_mean", + "logits_value_max", + "logits_value_min", + "logits_policy_mean", + "logits_policy_max", + "logits_policy_min", + ] + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model, target_model and optimizer. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. 
+ """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer_world_model': self._optimizer_world_model.state_dict(), + } + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. + """ + self._learn_model.load_state_dict(state_dict['model']) + self._target_model.load_state_dict(state_dict['target_model']) + # self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + + def recompute_pos_emb_diff_and_clear_cache(self) -> None: + """ + Overview: + Clear the caches and precompute positional embedding matrices in the model. + """ + for model in [self._collect_model, self._target_model]: + if not self._cfg.model.world_model_cfg.rotary_emb: + # If rotary_emb is False, nn.Embedding is used for absolute position encoding. + model.world_model.precompute_pos_emb_diff_kv() + model.world_model.clear_caches() + torch.cuda.empty_cache() \ No newline at end of file diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 4d3b1b740..2b32ff747 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -594,9 +594,17 @@ def collect(self, completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) eps_steps_lst[env_id] += 1 + # orig + # if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: + # # only for UniZero now + # self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) + + # ============ TODO(pu): only for UniZero now ============ if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: - # only for UniZero now - self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) + if eps_steps_lst[env_id]>self.policy_config.game_segment_length: + self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) + # print(f"eps_steps_lst[env_id]>self.policy_config.game_segment_length:{eps_steps_lst[env_id]}>{self.policy_config.game_segment_length}") + total_transitions += 1 diff --git a/lzero/worker/muzero_segment_collector.py b/lzero/worker/muzero_segment_collector.py index 46cc016bc..1e831897f 100644 --- a/lzero/worker/muzero_segment_collector.py +++ b/lzero/worker/muzero_segment_collector.py @@ -600,9 +600,15 @@ def collect(self, completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) eps_steps_lst[env_id] += 1 + + # ============ TODO(pu): only for UniZero now ============ if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: - # ============ only for UniZero now ============ - self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) + if eps_steps_lst[env_id]>self.policy_config.game_segment_length: + self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) + print(f"eps_steps_lst[env_id]>self.policy_config.game_segment_length:{eps_steps_lst[env_id]}>{self.policy_config.game_segment_length}") + + # if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: + # self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) total_transitions += 1 @@ -703,9 +709,9 @@ def collect(self, eps_steps_lst[env_id] = 0 visit_entropies_lst[env_id] = 0 - # Env reset is 
done by env_manager automatically + # TODO Env reset is done by env_manager automatically # NOTE: ============ reset the policy for the env_id. Default reset_init_data=True. ================ - self._policy.reset([env_id]) + self._policy.reset([env_id]) self._reset_stat(env_id) ready_env_id.remove(env_id) diff --git a/zoo/atari/config/atari_efficientzero_segment_config.py b/zoo/atari/config/atari_efficientzero_segment_config.py new file mode 100644 index 000000000..96fe5565c --- /dev/null +++ b/zoo/atari/config/atari_efficientzero_segment_config.py @@ -0,0 +1,149 @@ +from easydict import EasyDict +from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map + +# env_id = 'PongNoFrameskip-v4' # You can specify any Atari game here +env_id = 'MsPacmanNoFrameskip-v4' # You can specify any Atari game here + +action_space_size = atari_env_action_space_map[env_id] + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +update_per_collect = None +# replay_ratio = 0.25 +replay_ratio = 0.1 # 50M envsteps, 5M train iter +# replay_ratio = 0.02 # 50M envsteps, 1M train iter + +num_segments = 8 +game_segment_length = 20 +# game_segment_length = 400 + +collect_num_simulations = 25 +eval_num_simulations = 50 + +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +batch_size = 256 +# max_env_step = int(5e5) +max_env_step = int(50e6) + +reanalyze_ratio = 0. +num_unroll_steps = 5 + +# Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. +# buffer_reanalyze_freq = 1 +buffer_reanalyze_freq = 1/50 +# buffer_reanalyze_freq = 1/10000 +# Each reanalyze process will reanalyze sequences ( transitions per sequence) +reanalyze_batch_size = 160 +# The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
+reanalyze_partition=0.75 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +atari_efficientzero_config = dict( + exp_name=f'data_efficientzero/{env_id[:-14]}_efficientzero_stack4_H{num_unroll_steps}_seed0', + env=dict( + stop_value=int(1e6), + env_id=env_id, + frame_stack_num=4, + # observation_shape=[4, 64, 64], + # gray_scale=True, + observation_shape=(12, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(12, 64, 64), + image_channel=3, + gray_scale=False, + + # observation_shape=[4, 64, 64], + # image_channel=1, + # gray_scale=True, + + # num_res_blocks=1, + # num_channels=64, + num_res_blocks=2, + num_channels=128, + + frame_stack_num=4, + action_space_size=action_space_size, + downsample=True, + self_supervised_learning_loss=True, # default is False + discrete_action_encoding_type='one_hot', + norm_type='BN', + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), + ), + cuda=True, + env_type='not_board_games', + num_segments=num_segments, + train_start_after_envsteps=2000, + game_segment_length=game_segment_length, + use_augmentation=True, + # use_priority=False, + use_priority=True, # TODO(pu): test + replay_ratio=replay_ratio, + update_per_collect=update_per_collect, + batch_size=batch_size, + dormant_threshold=0.025, + optim_type='SGD', + policy_entropy_weight=5e-3, + piecewise_decay_lr_scheduler=True, + learning_rate=0.2, + target_update_freq=100, + num_simulations=50, # for reanalyze + collect_num_simulations=collect_num_simulations, + eval_num_simulations=eval_num_simulations, + reanalyze_ratio=reanalyze_ratio, + ssl_loss_weight=2, + n_episode=n_episode, + eval_freq=int(5e3), + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + + # ============= The key different params for reanalyze ============= + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq=buffer_reanalyze_freq, + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size=reanalyze_batch_size, + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
+ reanalyze_partition=reanalyze_partition, + ), +) +atari_efficientzero_config = EasyDict(atari_efficientzero_config) +main_config = atari_efficientzero_config + +atari_efficientzero_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='efficientzero', + import_names=['lzero.policy.efficientzero'], + ), +) +atari_efficientzero_create_config = EasyDict(atari_efficientzero_create_config) +create_config = atari_efficientzero_create_config + +if __name__ == "__main__": + # Define a list of seeds for multiple runs + # seeds = [0, 1, 2] # You can add more seed values here + seeds = [0] # You can add more seed values here + + for seed in seeds: + # Update exp_name to include the current seed + main_config.exp_name = f'data_efficientzero_20250731/{env_id[:-14]}_efficientzero_stack4_rgb_H{num_unroll_steps}_seed{seed}' + from lzero.entry import train_muzero_segment + train_muzero_segment([main_config, create_config], seed=seed, max_env_step=max_env_step) diff --git a/zoo/atari/config/atari_env_action_space_map.py b/zoo/atari/config/atari_env_action_space_map.py index e2090586d..d40d12f41 100644 --- a/zoo/atari/config/atari_env_action_space_map.py +++ b/zoo/atari/config/atari_env_action_space_map.py @@ -27,4 +27,7 @@ 'SeaquestNoFrameskip-v4': 18, 'BoxingNoFrameskip-v4': 18, 'BreakoutNoFrameskip-v4': 4, + 'SpaceInvadersNoFrameskip-v4': 6, + 'BeamRiderNoFrameskip-v4': 9, + 'GravitarNoFrameskip-v4': 18, }) \ No newline at end of file diff --git a/zoo/atari/config/atari_muzero_segment_config.py b/zoo/atari/config/atari_muzero_segment_config.py index 4289fb957..dfb9aa436 100644 --- a/zoo/atari/config/atari_muzero_segment_config.py +++ b/zoo/atari/config/atari_muzero_segment_config.py @@ -10,23 +10,38 @@ def main(env_id, seed): collector_env_num = 8 num_segments = 8 game_segment_length = 20 - evaluator_env_num = 3 - num_simulations = 50 + # num_simulations = 50 + + collect_num_simulations = 25 + eval_num_simulations = 50 + update_per_collect = None - replay_ratio = 0.25 + # replay_ratio = 0.25 + replay_ratio = 0.1 # 50M envsteps, 5M train iter + # replay_ratio = 0.02 # 50M envsteps, 1M train iter + + # num_unroll_steps = 5 + num_unroll_steps = 10 # TODO - num_unroll_steps = 5 batch_size = 256 - max_env_step = int(5e5) + # batch_size = 1024 # orig 256 + + # max_env_step = int(5e5) + # max_env_step = int(50e6) + max_env_step = int(2e6) + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + # buffer_reanalyze_freq = 1 + # buffer_reanalyze_freq = 1/2 # buffer_reanalyze_freq = 1/10 - buffer_reanalyze_freq = 1/10000 + buffer_reanalyze_freq = 1/50 + # buffer_reanalyze_freq = 1/10000000000 # Each reanalyze process will reanalyze sequences ( transitions per sequence) reanalyze_batch_size = 160 # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
- reanalyze_partition=1 + reanalyze_partition=0.75 # =========== for debug =========== # collector_env_num = 2 @@ -43,9 +58,12 @@ def main(env_id, seed): env=dict( stop_value=int(1e6), env_id=env_id, - observation_shape=(4, 96, 96), frame_stack_num=4, - gray_scale=True, + + # observation_shape=(4, 64, 64), + # gray_scale=True, + observation_shape=(12, 64, 64), + gray_scale=False, collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, @@ -53,22 +71,37 @@ def main(env_id, seed): # TODO: debug # collect_max_episode_steps=int(50), # eval_max_episode_steps=int(50), + # only for breakout + # collect_max_episode_steps=int(2e4), + # eval_max_episode_steps=int(2e4), ), policy=dict( - learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000 + # learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000 + # learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=100000, ), ), ), # 100k + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=50000, ), ), ), # 100k analysis_sim_norm=False, cal_dormant_ratio=False, model=dict( - observation_shape=(4, 96, 96), - image_channel=1, + # observation_shape=(4, 64, 64), + # image_channel=1, + # gray_scale=True, + + observation_shape=(12, 64, 64), + image_channel=3, + gray_scale=False, frame_stack_num=4, - gray_scale=True, + + + # num_res_blocks=1, + # num_channels=64, + num_res_blocks=2, + num_channels=128, + action_space_size=action_space_size, downsample=True, self_supervised_learning_loss=True, # default is False discrete_action_encoding_type='one_hot', norm_type='BN', - use_sim_norm=True, # NOTE use_sim_norm_kl_loss=False, model_type='conv' ), @@ -79,17 +112,24 @@ def main(env_id, seed): game_segment_length=game_segment_length, random_collect_episode_num=0, use_augmentation=True, - use_priority=False, + # use_augmentation=False, + # use_priority=False, + use_priority=True, # TODO(pu): test + priority_prob_alpha=1, + priority_prob_beta=1, replay_ratio=replay_ratio, update_per_collect=update_per_collect, batch_size=batch_size, optim_type='SGD', + policy_entropy_weight=5e-3, td_steps=5, piecewise_decay_lr_scheduler=True, manual_temperature_decay=False, learning_rate=0.2, target_update_freq=100, - num_simulations=num_simulations, + num_simulations=50, # for reanalyze + collect_num_simulations=collect_num_simulations, + eval_num_simulations=eval_num_simulations, ssl_loss_weight=2, eval_freq=int(5e3), replay_buffer_size=int(1e6), @@ -123,8 +163,14 @@ def main(env_id, seed): # ============ use muzero_segment_collector instead of muzero_collector ============= from lzero.entry import train_muzero_segment - main_config.exp_name = f'data_muzero/{env_id[:-14]}/{env_id[:-14]}_mz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}_bs{batch_size}_seed{seed}' - train_muzero_segment([main_config, create_config], seed=seed, max_env_step=max_env_step) + main_config.exp_name = f'data_muzero_20250910_debug/{env_id[:-14]}/{env_id[:-14]}_mz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}_bs{batch_size}_csim{collect_num_simulations}-esim{eval_num_simulations}_rgb_seed{seed}' + + # main_config.exp_name = 
f'data_muzero_20250805/{env_id[:-14]}/{env_id[:-14]}_mz_no-per_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}_bs{batch_size}_csim{collect_num_simulations}-esim{eval_num_simulations}_rgb_seed{seed}' + # train_muzero_segment([main_config, create_config], seed=seed, max_env_step=max_env_step) + + from lzero.entry import train_muzero_segment_save_buffer + train_muzero_segment_save_buffer([main_config, create_config], seed=seed, max_env_step=max_env_step) + if __name__ == "__main__": import argparse @@ -133,4 +179,16 @@ def main(env_id, seed): parser.add_argument('--seed', type=int, help='The seed to use', default=0) args = parser.parse_args() - main(args.env, args.seed) \ No newline at end of file + args.env = 'MsPacmanNoFrameskip-v4' + # args.env = 'QbertNoFrameskip-v4' + # args.env = 'SeaquestNoFrameskip-v4' + # args.env = 'BreakoutNoFrameskip-v4' + + args.seed = 0 + main(args.env, args.seed) + + """ + export CUDA_VISIBLE_DEVICES=1 + cd /mnt/nfs/zhangjinouwen/puyuan/LightZero + python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_muzero_segment_config.py + """ diff --git a/zoo/atari/config/atari_muzero_segment_gray_config.py b/zoo/atari/config/atari_muzero_segment_gray_config.py new file mode 100644 index 000000000..806c9be38 --- /dev/null +++ b/zoo/atari/config/atari_muzero_segment_gray_config.py @@ -0,0 +1,167 @@ +from easydict import EasyDict +from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map + +def main(env_id, seed): + action_space_size = atari_env_action_space_map[env_id] + + # ============================================================== + # begin of the most frequently changed config specified by the user + # ============================================================== + collector_env_num = 8 + num_segments = 8 + game_segment_length = 20 + # game_segment_length = 400 + + evaluator_env_num = 3 + # num_simulations = 50 + + collect_num_simulations = 25 + eval_num_simulations = 50 + + update_per_collect = None + # replay_ratio = 0.25 + replay_ratio = 0.1 # 50M envsteps, 5M train iter + # replay_ratio = 0.02 # 50M envsteps, 1M train iter + + num_unroll_steps = 5 + batch_size = 256 + # batch_size = 1024 # orig 256 + + # max_env_step = int(5e5) + max_env_step = int(50e6) + + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + # buffer_reanalyze_freq = 1 + buffer_reanalyze_freq = 1/50 + # buffer_reanalyze_freq = 1/10000 + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size = 160 + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
+ reanalyze_partition=0.75 + + # =========== for debug =========== + # collector_env_num = 2 + # num_segments = 2 + # evaluator_env_num = 2 + # num_simulations = 2 + # update_per_collect = 2 + # batch_size = 5 + # ============================================================== + # end of the most frequently changed config specified by the user + # ============================================================== + + atari_muzero_config = dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + frame_stack_num=4, + + observation_shape=(4, 64, 64), + gray_scale=True, + # observation_shape=(12, 64, 64), + # gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + # TODO: debug + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000 + analysis_sim_norm=False, + cal_dormant_ratio=False, + model=dict( + observation_shape=(4, 64, 64), + image_channel=1, + gray_scale=True, + + # observation_shape=(12, 64, 64), + # image_channel=3, + # gray_scale=False, + + # num_res_blocks=1, + # num_channels=64, + num_res_blocks=2, + num_channels=128, + + frame_stack_num=4, + action_space_size=action_space_size, + downsample=True, + self_supervised_learning_loss=True, # default is False + discrete_action_encoding_type='one_hot', + norm_type='BN', + # use_sim_norm=True, # NOTE + use_sim_norm_kl_loss=False, + model_type='conv' + ), + cuda=True, + env_type='not_board_games', + num_segments=num_segments, + train_start_after_envsteps=2000, + game_segment_length=game_segment_length, + random_collect_episode_num=0, + use_augmentation=True, + # use_priority=False, + use_priority=True, # TODO(pu): test + priority_prob_alpha=1, + priority_prob_beta=1, + replay_ratio=replay_ratio, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + policy_entropy_weight=5e-3, + td_steps=5, + piecewise_decay_lr_scheduler=True, + manual_temperature_decay=False, + learning_rate=0.2, + target_update_freq=100, + num_simulations=50, # for reanalyze + collect_num_simulations=collect_num_simulations, + eval_num_simulations=eval_num_simulations, + ssl_loss_weight=2, + eval_freq=int(5e3), + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq=buffer_reanalyze_freq, + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size=reanalyze_batch_size, + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
+ reanalyze_partition=reanalyze_partition, + ), + ) + atari_muzero_config = EasyDict(atari_muzero_config) + main_config = atari_muzero_config + + atari_muzero_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), + ) + atari_muzero_create_config = EasyDict(atari_muzero_create_config) + create_config = atari_muzero_create_config + + # ============ use muzero_segment_collector instead of muzero_collector ============= + from lzero.entry import train_muzero_segment + main_config.exp_name = f'data_muzero_20250731/{env_id[:-14]}/{env_id[:-14]}_mz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}_bs{batch_size}_csim{collect_num_simulations}-esim{eval_num_simulations}_gray_seed{seed}' + train_muzero_segment([main_config, create_config], seed=seed, max_env_step=max_env_step) + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='Process different environments and seeds.') + parser.add_argument('--env', type=str, help='The environment to use', default='PongNoFrameskip-v4') + parser.add_argument('--seed', type=int, help='The seed to use', default=0) + args = parser.parse_args() + + args.env = 'MsPacmanNoFrameskip-v4' + main(args.env, args.seed) \ No newline at end of file diff --git a/zoo/atari/config/atari_muzero_segment_stack1_config.py b/zoo/atari/config/atari_muzero_segment_stack1_config.py new file mode 100644 index 000000000..f6ce307d5 --- /dev/null +++ b/zoo/atari/config/atari_muzero_segment_stack1_config.py @@ -0,0 +1,224 @@ +from easydict import EasyDict +from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map + +def main(env_id, seed): + action_space_size = atari_env_action_space_map[env_id] + + # ============================================================== + # begin of the most frequently changed config specified by the user + # ============================================================== + collector_env_num = 8 + num_segments = 8 + + # collector_env_num = 3 + # num_segments = 3 + + game_segment_length = 20 + evaluator_env_num = 3 + # num_simulations = 50 + + # collect_num_simulations = 25 + + collect_num_simulations = 50 + + eval_num_simulations = 50 + + update_per_collect = None + # replay_ratio = 0.25 + replay_ratio = 0.1 # 50M envsteps, 5M train iter + # replay_ratio = 0.02 # 50M envsteps, 1M train iter + + # num_unroll_steps = 5 + num_unroll_steps = 10 # TODO + + batch_size = 256 + # batch_size = 1024 # orig 256 + + # max_env_step = int(5e5) + # max_env_step = int(50e6) + # max_env_step = int(5e6) + + max_env_step = int(1e5) + + + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + # buffer_reanalyze_freq = 1 + # buffer_reanalyze_freq = 1/2 + # buffer_reanalyze_freq = 1/10 + # buffer_reanalyze_freq = 1/50 + buffer_reanalyze_freq = 1/10000000000 + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size = 160 + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
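+    # --- Added consistency check (illustrative, not part of the original patch): these configs
+    # alternate between gray stack-4 (4, 64, 64), RGB stack-4 (12, 64, 64) and RGB stack-1
+    # (3, 64, 64) observations. The channel dimension of `observation_shape` is always
+    # image_channel * frame_stack_num, so a small helper keeps the env and model sections in sync. ---
+    def _obs_shape(image_channel: int, frame_stack_num: int, resolution: int = 64):
+        return (image_channel * frame_stack_num, resolution, resolution)
+
+    assert _obs_shape(1, 4) == (4, 64, 64)   # gray, stack 4
+    assert _obs_shape(3, 4) == (12, 64, 64)  # RGB, stack 4
+    assert _obs_shape(3, 1) == (3, 64, 64)   # RGB, no stacking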
+ reanalyze_partition=0.75 + + # =========== for debug =========== + # collector_env_num = 2 + # num_segments = 2 + # evaluator_env_num = 2 + # num_simulations = 2 + # update_per_collect = 2 + # batch_size = 5 + # ============================================================== + # end of the most frequently changed config specified by the user + # ============================================================== + + atari_muzero_config = dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + + # frame_stack_num=4, + # # observation_shape=(4, 64, 64), + # # gray_scale=True, + # frame_stack_num=4, + # observation_shape=(12, 64, 64), + # gray_scale=False, + + frame_stack_num=1, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + # TODO: debug + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + # only for breakout + # collect_max_episode_steps=int(2e4), + # eval_max_episode_steps=int(2e4), + ), + policy=dict( + # learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000 + # learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=100000, ), ), ), # 100k + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=50000, ), ), ), # 100k + analysis_sim_norm=False, + cal_dormant_ratio=False, + model=dict( + # observation_shape=(4, 64, 64), + # image_channel=1, + # gray_scale=True, + + # observation_shape=(12, 64, 64), + # image_channel=3, + # gray_scale=False, + # frame_stack_num=4, + + frame_stack_num=1, + observation_shape=(3, 64, 64), + gray_scale=False, + image_channel=3, + + + # num_res_blocks=1, + # num_channels=64, + num_res_blocks=2, + num_channels=128, + + action_space_size=action_space_size, + downsample=True, + self_supervised_learning_loss=True, # default is False + discrete_action_encoding_type='one_hot', + norm_type='BN', + use_sim_norm_kl_loss=False, + model_type='conv' + ), + # model_path="/mnt/nfs/zhangjinouwen/puyuan/LightZero/data_muzero_20250910_save_buffer/MsPacman/MsPacman_mz_brf0.02-rbs160-rp0.75_numsegments-8_gsl20_rr0.1_Htrain10_bs256_csim25-esim50_rgb_seed0/ckpt/iteration_100000.pth.tar", + model_path="/mnt/nfs/zhangjinouwen/puyuan/LightZero/data_muzero_20250910_save_buffer/MsPacman/MsPacman_mz_brf0.02-rbs160-rp0.75_numsegments-8_gsl20_rr0.1_Htrain10_bs256_csim25-esim50_rgb_seed0/ckpt/ckpt_best.pth.tar", + + cuda=True, + env_type='not_board_games', + num_segments=num_segments, + train_start_after_envsteps=2000, + game_segment_length=game_segment_length, + num_unroll_steps=num_unroll_steps, # TODO + random_collect_episode_num=0, + use_augmentation=True, + # use_augmentation=False, + # use_priority=False, + use_priority=True, # TODO(pu): test + priority_prob_alpha=1, + priority_prob_beta=1, + replay_ratio=replay_ratio, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + policy_entropy_weight=5e-3, + td_steps=5, + piecewise_decay_lr_scheduler=True, + manual_temperature_decay=False, + learning_rate=0.2, + target_update_freq=100, + num_simulations=50, # for reanalyze + collect_num_simulations=collect_num_simulations, + eval_num_simulations=eval_num_simulations, + ssl_loss_weight=2, + eval_freq=int(5e3), + # replay_buffer_size=int(1e6), + + replay_buffer_size=int(5e5), + + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + # Defines 
the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq=buffer_reanalyze_freq, + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size=reanalyze_batch_size, + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. + reanalyze_partition=reanalyze_partition, + ), + ) + atari_muzero_config = EasyDict(atari_muzero_config) + main_config = atari_muzero_config + + atari_muzero_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), + ) + atari_muzero_create_config = EasyDict(atari_muzero_create_config) + create_config = atari_muzero_create_config + + # ============ use muzero_segment_collector instead of muzero_collector ============= + from lzero.entry import train_muzero_segment + main_config.exp_name = f'data_muzero_20250917_save_buffer/{env_id[:-14]}/{env_id[:-14]}_mz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}_bs{batch_size}_csim{collect_num_simulations}-esim{eval_num_simulations}_rgb_seed{seed}' + + # main_config.exp_name = f'data_muzero_20250805/{env_id[:-14]}/{env_id[:-14]}_mz_no-per_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}_bs{batch_size}_csim{collect_num_simulations}-esim{eval_num_simulations}_rgb_seed{seed}' + # train_muzero_segment([main_config, create_config], seed=seed, max_env_step=max_env_step) + + # from lzero.entry import train_muzero_segment_save_buffer + # train_muzero_segment_save_buffer([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step) + + + from lzero.entry import train_muzero_segment_save_buffer_from_ckpt + train_muzero_segment_save_buffer_from_ckpt([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='Process different environments and seeds.') + parser.add_argument('--env', type=str, help='The environment to use', default='PongNoFrameskip-v4') + parser.add_argument('--seed', type=int, help='The seed to use', default=0) + args = parser.parse_args() + + args.env = 'MsPacmanNoFrameskip-v4' + # args.env = 'QbertNoFrameskip-v4' + # args.env = 'SeaquestNoFrameskip-v4' + # args.env = 'BreakoutNoFrameskip-v4' + + args.seed = 0 + main(args.env, args.seed) + + """ + export CUDA_VISIBLE_DEVICES=1 + cd /mnt/nfs/zhangjinouwen/puyuan/LightZero + python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_muzero_segment_stack1_config.py + """ diff --git a/zoo/atari/config/atari_unizero_config.py b/zoo/atari/config/atari_unizero_config.py index 2c68c80fe..02a49bbe6 100644 --- a/zoo/atari/config/atari_unizero_config.py +++ b/zoo/atari/config/atari_unizero_config.py @@ -10,27 +10,61 @@ def main(env_id='PongNoFrameskip-v4', seed=0): # begin of the most frequently changed config specified by the user # ============================================================== collector_env_num = 8 - game_segment_length = 20 + n_episode = 8 
evaluator_env_num = 3 + + collector_env_num = 1 + n_episode = 1 + evaluator_env_num = 1 + num_simulations = 50 - max_env_step = int(5e5) - batch_size = 64 + collect_num_simulations = 25 + # collect_num_simulations = 50 + eval_num_simulations = 50 + max_env_step = int(5e6) + # max_env_step = int(50e6) + batch_size = 256 + # batch_size = 64 # debug + # batch_size = 4 # debug + + num_layers = 2 + # replay_ratio = 0.25 + replay_ratio = 0.1 + + game_segment_length = 20 num_unroll_steps = 10 infer_context_length = 4 - num_layers = 2 - replay_ratio = 0.25 - # TODO: only for debug + # game_segment_length = 40 + # num_unroll_steps = 20 + # infer_context_length = 8 + + # game_segment_length = 200 + # num_unroll_steps = 16 + # infer_context_length = 8 + + # num_unroll_steps = 4 # TODO + # infer_context_length = 2 + + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + # buffer_reanalyze_freq = 1/50 + # buffer_reanalyze_freq = 1/10 + # buffer_reanalyze_freq = 1/1000000000000 + + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size = 160 + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. + reanalyze_partition = 0.75 + + # norm_type ="BN" + norm_type ="LN" + + # ====== only for debug ===== # collector_env_num = 2 - # game_segment_length = 20 + # num_segments = 2 # evaluator_env_num = 2 - # num_simulations = 2 - # max_env_step = int(5e5) - # batch_size = 10 - # num_unroll_steps = 5 - # infer_context_length = 2 - # num_layers = 1 - # replay_ratio = 0.1 + # num_simulations = 10 + # batch_size = 5 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -49,12 +83,25 @@ def main(env_id='PongNoFrameskip-v4', seed=0): # eval_max_episode_steps=int(50), ), policy=dict( - learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000 + # learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000 + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=100000, ), ), ), # 100k + model=dict( observation_shape=(3, 64, 64), action_space_size=action_space_size, + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), world_model_cfg=dict( - policy_entropy_weight=1e-4, + game_segment_length=game_segment_length, + + norm_type=norm_type, + num_res_blocks=2, + num_channels=128, + # num_res_blocks=1, # TODO + # num_channels=64, + support_size=601, + policy_entropy_weight=5e-3, + # policy_entropy_weight=5e-2, # TODO(pu) continuous_action_space=False, max_blocks=num_unroll_steps, max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action @@ -65,20 +112,76 @@ def main(env_id='PongNoFrameskip-v4', seed=0): num_heads=8, embed_dim=768, obs_type='image', + encoder_type="resnet", env_num=max(collector_env_num, evaluator_env_num), + num_simulations=num_simulations, rotary_emb=False, + # rotary_emb=True, + # final_norm_option_in_encoder='LayerNorm_Tanh', + # final_norm_option_in_obs_head="LayerNorm", + # predict_latent_loss_type='mse', + + # final_norm_option_in_encoder='L2Norm', + # final_norm_option_in_obs_head="L2Norm", + # predict_latent_loss_type='mse', + + final_norm_option_in_encoder="LayerNorm", + final_norm_option_in_obs_head="LayerNorm", + 
predict_latent_loss_type='mse', + + # final_norm_option_in_encoder="SimNorm", + # final_norm_option_in_obs_head="SimNorm", + # predict_latent_loss_type='group_kl', + + # weight_decay=1e-2, + latent_norm_loss=True, + + # latent_norm_loss=False, + weight_decay=1e-4, # TODO + + use_priority=True, # TODO(pu): test ), ), + # gradient_scale=True, #TODO + gradient_scale=False, #TODO + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. model_path=None, + use_augmentation=False, # TODO + + use_priority=True, # TODO(pu): test + priority_prob_alpha=1, + priority_prob_beta=1, + + manual_temperature_decay=False, + threshold_training_steps_for_final_temperature=int(2.5e4), num_unroll_steps=num_unroll_steps, + update_per_collect=None, replay_ratio=replay_ratio, batch_size=batch_size, + optim_type='AdamW', + # target_model_update_option="hard", + target_update_freq=100, + + target_model_update_option="soft", + # target_update_theta=0.005, # TODO + # target_update_theta=0.01, + target_update_theta=0.05, + learning_rate=0.0001, - num_simulations=num_simulations, - train_start_after_envsteps=2000, - # train_start_after_envsteps=0, # TODO: only for debug + # learning_rate=0.0003, # TODO + + num_simulations=50, # for reanalyze + collect_num_simulations=collect_num_simulations, + eval_num_simulations=eval_num_simulations, + # num_segments=num_segments, + n_episode=n_episode, + td_steps=5, + train_start_after_envsteps=0, + # train_start_after_envsteps=2000, # TODO game_segment_length=game_segment_length, - replay_buffer_size=int(1e6), + # replay_buffer_size=int(1e6), + replay_buffer_size=int(1e5), # TODO + eval_freq=int(5e3), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, @@ -101,7 +204,7 @@ def main(env_id='PongNoFrameskip-v4', seed=0): atari_unizero_create_config = EasyDict(atari_unizero_create_config) create_config = atari_unizero_create_config - main_config.exp_name = f'data_lz/data_unizero/{env_id[:-14]}/{env_id[:-14]}_uz_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' + main_config.exp_name = f'data_unizero_longrun_20250904/{env_id[:-14]}/{env_id[:-14]}_uz_episode_rbs1e5_envnum{collector_env_num}_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' from lzero.entry import train_unizero train_unizero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step) @@ -113,5 +216,30 @@ def main(env_id='PongNoFrameskip-v4', seed=0): parser.add_argument('--seed', type=int, help='The seed to use', default=0) args = parser.parse_args() + + + + # args.env = 'PongNoFrameskip-v4' + + args.env = 'MsPacmanNoFrameskip-v4' + + # args.env = 'QbertNoFrameskip-v4' + # args.env = 'SeaquestNoFrameskip-v4' + + # args.env = 'SpaceInvadersNoFrameskip-v4' + # args.env = 'BeamRiderNoFrameskip-v4' + # args.env = 'GravitarNoFrameskip-v4' + + # args.env = 'BreakoutNoFrameskip-v4' + + + args.seed = 0 + + main(args.env, args.seed) + """ + export CUDA_VISIBLE_DEVICES=2 + cd /mnt/nfs/zhangjinouwen/puyuan/LightZero + python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_config.py + """ diff --git a/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py b/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py new file mode 100644 index 000000000..88fa46a88 --- /dev/null +++ 
b/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py
@@ -0,0 +1,481 @@
+from easydict import EasyDict
+
+import math
+
+def compute_batch_config(env_id_list, effective_batch_size):
+    n = len(env_id_list)
+
+    # Set the effective batch size and the per-environment maximum micro batch size
+    # according to the number of environments.
+    gpu_num = 8
+    max_micro_batch_one_gpu = 400
+    max_micro_batch = int(max_micro_batch_one_gpu / (n // gpu_num))
+
+    # Compute the batch size each environment should theoretically receive.
+    theoretical_env_batch = effective_batch_size / n
+
+    if theoretical_env_batch > max_micro_batch:
+        # If the evenly split per-environment batch exceeds the allowed maximum micro batch,
+        # fix the actual micro batch size of each environment to max_micro_batch.
+        micro_batch_size = max_micro_batch
+        # Gradient accumulation steps = ceil(theoretical per-environment batch size / maximum micro batch size).
+        grad_accumulate_steps = math.ceil(theoretical_env_batch / max_micro_batch)
+    else:
+        # Otherwise, use the theoretical batch size directly (floored to keep it an integer).
+        micro_batch_size = int(theoretical_env_batch)
+        grad_accumulate_steps = 1
+
+    # Assign the same micro batch size to every environment.
+    batch_size = [micro_batch_size for _ in range(n)]
+
+    # Print some debugging information (could also be written to the log).
+    print("Number of environments: {}".format(n))
+    print("Effective total batch size: {}".format(effective_batch_size))
+    print("Theoretical batch size per environment: {:.2f}".format(theoretical_env_batch))
+    print("Micro batch size per environment: {}".format(micro_batch_size))
+    print("Gradient accumulation steps: {}".format(grad_accumulate_steps))
+
+    return batch_size, grad_accumulate_steps
+
+
+def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode,
+                  num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length,
+                  norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments,
+                  total_batch_size):
+    return EasyDict(dict(
+        env=dict(
+            stop_value=int(1e6),
+            env_id=env_id,
+            observation_shape=(3, 64, 64),
+            gray_scale=False,
+            collector_env_num=collector_env_num,
+            evaluator_env_num=evaluator_env_num,
+            n_evaluator_episode=evaluator_env_num,
+            manager=dict(shared_memory=False),
+            full_action_space=True,
+            collect_max_episode_steps=int(5e3),
+            eval_max_episode_steps=int(5e3),
+            # ===== only for debug =====
+            # collect_max_episode_steps=int(40),
+            # eval_max_episode_steps=int(40),
+        ),
+        policy=dict(
+            multi_gpu=True,  # Very important for DDP
+            only_use_moco_stats=False,
+            use_moco=False,  # ==============TODO==============
+            # use_moco=True,  # ==============TODO==============
+            learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))),
+            grad_correct_params=dict(
+                MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0,
+                calpha=0.5, rescale=1,
+            ),
+            total_task_num=len(env_id_list),
+            task_num=len(env_id_list),
+            task_id=0,
+            model=dict(
+                observation_shape=(3, 64, 64),
+                action_space_size=action_space_size,
+                norm_type=norm_type,
+                num_res_blocks=2,
+                num_channels=256,
+                # num_channels=512,  # ==============TODO==============
+                continuous_action_space=False,
+                world_model_cfg=dict(
+                    use_global_pooling=False,
+
+                    final_norm_option_in_obs_head='LayerNorm',
+                    final_norm_option_in_encoder='LayerNorm',
+                    predict_latent_loss_type='mse',  # TODO: for latent state layer_norm
+
+                    # final_norm_option_in_obs_head='SimNorm',
+                    # final_norm_option_in_encoder='SimNorm',
+                    # predict_latent_loss_type='group_kl',  # TODO: only for latent state sim_norm
+
+                    # predict_latent_loss_type='group_kl',  # TODO: only for latent state sim_norm
+                    # share_head=True,  # TODO
+                    share_head=False,  # TODO
+
+                    # analysis_dormant_ratio_weight_rank=True,  # TODO
+                    analysis_dormant_ratio_weight_rank=False,  # TODO
+                    dormant_threshold=0.025,
+                    continuous_action_space=False,
+
task_embed_option=None, # ==============TODO: none ============== + use_task_embed=False, # ==============TODO============== + + # task_embed_option='concat_task_embed', # ==============TODO: none ============== + # use_task_embed=True, # ==============TODO============== + # task_embed_dim=128, + # # task_embed_dim=96, + + use_shared_projection=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + # batch_size=64 8games训练时,每张卡大约占 12*3=36G cuda显存 + # num_layers=12, + # num_heads=24, + + num_layers=4, # TODO======= + # num_layers=8, + num_heads=24, + + # ===== only for debug ===== + # num_layers=1, + # num_heads=8, + + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + + encoder_type='vit', + # encoder_type='resnet', + + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + + moe_in_transformer=False, + # multiplication_moe_in_transformer=False, + multiplication_moe_in_transformer=True, # TODO======= + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + + # LoRA 参数: + # moe_use_lora=False, # TODO + moe_use_lora=True, # TODO + + curriculum_stage_num=curriculum_stage_num, + lora_target_modules=["attn", "feed_forward"], + lora_r=64, # TODO + lora_alpha=32, + # lora_r=128, # TODO + # lora_alpha=128, + lora_dropout=0.1, + lora_scale_init=1, + + min_stage0_iters=50000, # 50k + max_stage_iters=20000, # 20k + + # ==================== 新增的控制参数 ==================== + # 设置为 False,则课程学习和LoRA冻结将只应用于Transformer Backbone + # 设置为 True 或不设置此项,则同时应用于ViT Encoder和Transformer Backbone + apply_curriculum_to_encoder=False, + ), + ), + use_task_exploitation_weight=False, # TODO + # use_task_exploitation_weight=True, # TODO + target_return =target_return_dict[env_id], + balance_pipeline=True, + # task_complexity_weight=False, # TODO + task_complexity_weight=True, # TODO: 这个选项打开时统计所有环境的norm mean + + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), # TODO: DEBUG + # train_start_after_envsteps=int(2000), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + # update_per_collect=80, # TODO + update_per_collect=40, # TODO + + # update_per_collect=2, # TODO + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + # cos_lr_scheduler=True, + cos_lr_scheduler=False, + num_segments=num_segments, + # (int) the number of simulations in MCTS for renalyze. + num_simulations=num_simulations, + # (int) The number of simulations in MCTS for the collect phase. + collect_num_simulations=25, + # (int) The number of simulations in MCTS for the eval phase. 
+ eval_num_simulations=50, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + # eval_freq=int(1e4), + eval_freq=int(1e4), + # eval_freq=int(2), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size): + configs = [] + # ===== only for debug ===== + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250509/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_vit-encoder-ps8_trans-nlayer8_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250509/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_no-encoder-scale_cnn-encoder_moe8_trans-nlayer8_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250514/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_vit-ln_moe8_trans-nlayer4_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + # exp_name_prefix = f'data_unizero_atari_mt_balance_20250730/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_stage-50k-20k_vit-small-ln_trans-nlayer4-moe8_attn-mlp-lora_no-lora-scale_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + # exp_name_prefix = f'data_unizero_atari_mt_balance_20250730/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_stage-50k-20k_vit-small-ln_trans-nlayer4-moe8_encoder-backbone-attn-mlp-lora_no-lora-scale_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + exp_name_prefix = f'data_unizero_atari_mt_balance_20250814/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_stage-50k-20k_vit-small-ln_trans-nlayer4-moe8_backbone-attn-mlp-lora_no-lora-scale_brf{buffer_reanalyze_freq}-rr025_collect25_lora64_not-share-head_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, + reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. 
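+        Batch-allocation note (added for clarity; the figures below are worked examples, not
+        settings taken from this patch): compute_batch_config splits effective_batch_size evenly
+        across tasks and only falls back to gradient accumulation when the per-task share exceeds
+        the per-GPU micro-batch cap of 400 divided by tasks-per-GPU:
+            8 games,  effective 512  -> cap 400, per-task 64.0 -> micro 64,  grad_accumulate 1
+            26 games, effective 512  -> cap 133, per-task 19.7 -> micro 19,  grad_accumulate 1
+            8 games,  effective 4096 -> cap 400, per-task 512  -> micro 400, grad_accumulate 2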
+ Run the following command to launch the script: + cd /mnt/nfs/zhangjinouwen/puyuan/LightZero + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/202507/uz_mt_nlayer4_atari8_balance-totalstage5_backbone_brf002_collect25_lora64-32.log + + cd /cpfs04/user/puyuan/code/LightZero/ + python -m torch.distributed.launch --nproc_per_node=6 --master_port=29502 /cpfs04/user/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/20250522_cpfs/uz_mt_nlayer4_atari8_vit-small_moe8-lora_balance-totalstage5_stage-50k-20k_s0.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29502 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/20250509/uz_mt_nlayer4_atari26_vit-ln_moe8_balance-totalstage9.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_balance_atari26_vit-ln_moe8_totalstage5.log + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/20250509/uz_mt_nlayer8_atari8_vit-ln_moe8_balance-totalstage5.log + + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_balance_atari8_no-encoder-grad-scale_cnn-encoder_moe8_totalstage5_20250509.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_atari26_cnn-encoder_totalstage9_balance20250505.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_atari8_vit-base-encoder-ps8_totalstage3_balance_20250501_debug.log + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_atari26_vit-large-encoder-ps8-simnorm_totalstage5_balance20250501.log + + """ + + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import os + + + # env_id_list = [ + # 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + # 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + # ] + # # List of Atari games used for multi-task learning + # env_id_list = [ + # 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + # 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + # 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + # 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', 'FreewayNoFrameskip-v4', + # 'FrostbiteNoFrameskip-v4', 
'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', + # 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + # 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + # ] + + def get_atari_target_return_dict(ratio=1.0): + """ + 根据 Human 分数和传入的比例参数 ratio 计算每个 Atari 游戏的 target_return。 + + 参数: + ratio: 控制 target_return 大小的比例因子,默认为 1.0 + + 返回: + 包含 Atari 游戏 target_return 的字典,key 为环境名称,value 为计算后的目标分数(整数)。 + """ + human_scores = { + # 8games + 'PongNoFrameskip-v4': 14.6, # 0 + 'MsPacmanNoFrameskip-v4': 6951.6, # 1 + 'SeaquestNoFrameskip-v4': 42054.7, # 2 + 'BoxingNoFrameskip-v4': 12.1, # 3 + 'AlienNoFrameskip-v4': 7127.7, # 4 + 'ChopperCommandNoFrameskip-v4': 7387.8, # 5 + 'HeroNoFrameskip-v4': 30826.4, # 6 + 'RoadRunnerNoFrameskip-v4': 7845.0, # 7 + # 后续 Atari 26games 的额外项 + 'AmidarNoFrameskip-v4': 1719.5, # 8 + 'AssaultNoFrameskip-v4': 742.0, # 9 + 'AsterixNoFrameskip-v4': 8503.3, # 10 + 'BankHeistNoFrameskip-v4': 753.1, # 11 + 'BattleZoneNoFrameskip-v4': 37187.5, # 12 + 'CrazyClimberNoFrameskip-v4': 35829.4, # 13 + 'DemonAttackNoFrameskip-v4': 1971.0, # 14 + 'FreewayNoFrameskip-v4': 29.6, # 15 + 'FrostbiteNoFrameskip-v4': 4334.7, # 16 + 'GopherNoFrameskip-v4': 2412.5, # 17 + 'JamesbondNoFrameskip-v4': 302.8, # 18 + 'KangarooNoFrameskip-v4': 3035.0, # 19 + 'KrullNoFrameskip-v4': 2665.5, # 20 + 'KungFuMasterNoFrameskip-v4': 22736.3, # 21 + 'PrivateEyeNoFrameskip-v4': 69571.3, # 22 + 'UpNDownNoFrameskip-v4': 11693.2, # 23 + 'QbertNoFrameskip-v4': 13455.0, # 24 + 'BreakoutNoFrameskip-v4': 30.5, # 25 + } + + # target score + target_scores = { + # 8games + # 'PongNoFrameskip-v4': 14.6, # 0 expert + 'PongNoFrameskip-v4': 20, # 0 expert + # 'MsPacmanNoFrameskip-v4': 1500.6, # 1 + 'MsPacmanNoFrameskip-v4': 6951.6, # 1 + # 'SeaquestNoFrameskip-v4': 1000.7, # 2 + 'SeaquestNoFrameskip-v4': 42054.7, # 2 expert + 'BoxingNoFrameskip-v4': 12.1, # 3 expert + # 'AlienNoFrameskip-v4': 1000.7, # 4 + 'AlienNoFrameskip-v4': 7127.7, # 4 expert + # 'ChopperCommandNoFrameskip-v4': 3000.8, # 5 + # 'HeroNoFrameskip-v4': 3082.4, # 6 + 'ChopperCommandNoFrameskip-v4': 7387.8, # 5 expert + 'HeroNoFrameskip-v4': 30826.4, # 6 expert + 'RoadRunnerNoFrameskip-v4': 7845.0, # 7 expert + # 后续 Atari 26games 的额外项 + 'AmidarNoFrameskip-v4': 100.5, # 8 + 'AssaultNoFrameskip-v4': 742.0, # 9 + 'AsterixNoFrameskip-v4': 1503.3, # 10 + 'BankHeistNoFrameskip-v4': 753.1, # 11 + 'BattleZoneNoFrameskip-v4': 12187.5, # 12 + 'CrazyClimberNoFrameskip-v4': 15829.4, # 13 + 'DemonAttackNoFrameskip-v4': 1971.0, # 14 + 'FreewayNoFrameskip-v4': 29.6, # 15 + 'FrostbiteNoFrameskip-v4': 334.7, # 16 + 'GopherNoFrameskip-v4': 2412.5, # 17 + 'JamesbondNoFrameskip-v4': 302.8, # 18 + 'KangarooNoFrameskip-v4': 3035.0, # 19 + 'KrullNoFrameskip-v4': 2665.5, # 20 + 'KungFuMasterNoFrameskip-v4': 12736.3, # 21 + 'PrivateEyeNoFrameskip-v4': 1001.3, # 22 + 'UpNDownNoFrameskip-v4': 11693.2, # 23 + 'QbertNoFrameskip-v4': 13455.0, # 24 + 'BreakoutNoFrameskip-v4': 30.5, # 25 + } + + + # 计算每个游戏的 target_return + # return {env: int(round(score * ratio)) for env, score in human_scores.items()} + return {env: int(round(score * ratio)) for env, score in target_scores.items()} + + + global target_return_dict + # global BENCHMARK_NAME + # BENCHMARK_NAME='atari' + + # 示例:以 ratio=1 使用 + target_return_dict = get_atari_target_return_dict(ratio=1) + # target_return_dict = get_atari_target_return_dict(ratio=0.5) + num_games = 8 # 26 # 8 + + # 分别定义 Atari 游戏列表(8games 和 26games) + if num_games==3: + env_id_list = 
[ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4' + ] + elif num_games==8: + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + ] + elif num_games==26: + # List of Atari games used for multi-task learning + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', 'FreewayNoFrameskip-v4', + 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', + 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + ] + + global curriculum_stage_num + # TODO ============== + # curriculum_stage_num=3 + curriculum_stage_num=5 + # curriculum_stage_num=9 + + action_space_size = 18 + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(4e5) + reanalyze_ratio = 0.0 + + if len(env_id_list) == 8: + effective_batch_size = 512 + elif len(env_id_list) == 26: + # effective_batch_size = 832 # cnn-encoder + effective_batch_size = 512 # base-vit-encoder + # effective_batch_size = 256 # base-vit-encoder large-vit-encoder + elif len(env_id_list) == 18: + effective_batch_size = 512 * 3 # 1536 + else: + raise ValueError("不支持的环境数量: {}".format(n)) + + batch_sizes, grad_acc_steps = compute_batch_config(env_id_list, effective_batch_size) + total_batch_size = effective_batch_size # 当前无效 + + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + # buffer_reanalyze_freq = 1 / 50 + # buffer_reanalyze_freq = 1 / 10 + buffer_reanalyze_freq = 1 / 2 + + + # buffer_reanalyze_freq = 1 / 1000000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ======== TODO: only for debug ======== + # collector_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # evaluator_env_num = 2 + # num_simulations = 1 + # reanalyze_batch_size = 2 + # num_unroll_steps = 5 + # infer_context_length = 2 + # batch_sizes = [4 for _ in range(len(env_id_list))] + + from lzero.entry import train_unizero_multitask_balance_segment_ddp + + for seed in [0]: + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_sizes, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size) + + with DDPContext(): + train_unizero_multitask_balance_segment_ddp(configs, seed=seed, max_env_step=max_env_step, benchmark_name="atari") + # ======== TODO: only for debug ======== + # train_unizero_multitask_segment_ddp(configs[:2], seed=seed, max_env_step=max_env_step) # train on the first four tasks \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_segment_config.py b/zoo/atari/config/atari_unizero_segment_config.py index dec2ee4d2..a58023492 100644 --- a/zoo/atari/config/atari_unizero_segment_config.py +++ b/zoo/atari/config/atari_unizero_segment_config.py @@ -1,6 +1,7 @@ from easydict 
import EasyDict from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map - +# # 在main文件开始,通过全局变量来控制是否处于调试状态 +# global DEBUG_ENABLED;DEBUG_ENABLED = True def main(env_id, seed): action_space_size = atari_env_action_space_map[env_id] @@ -10,23 +11,59 @@ def main(env_id, seed): # ============================================================== collector_env_num = 8 num_segments = 8 - game_segment_length = 20 - evaluator_env_num = 10 + evaluator_env_num = 3 + + # collector_env_num = 1 + # num_segments = 1 + # evaluator_env_num = 1 + num_simulations = 50 - max_env_step = int(5e5) - batch_size = 64 + collect_num_simulations = 25 + # collect_num_simulations = 50 + eval_num_simulations = 50 + max_env_step = int(10e6) #TODO mspacman======== + # max_env_step = int(2e6)#TODO======== + + # max_env_step = int(50e6) + batch_size = 256 + + # batch_size = 64 # encoder_type="dinov2", #TODO======== + + # batch_size = 16 # debug + # batch_size = 4 # debug + num_layers = 2 - replay_ratio = 0.25 + # replay_ratio = 0.25 + replay_ratio = 0.1 + + game_segment_length = 20 num_unroll_steps = 10 infer_context_length = 4 + # game_segment_length = 40 + # num_unroll_steps = 20 + # infer_context_length = 8 + + # game_segment_length = 200 + # num_unroll_steps = 16 + # infer_context_length = 8 + + # num_unroll_steps = 4 # TODO + # infer_context_length = 2 + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. - buffer_reanalyze_freq = 1/50 + # buffer_reanalyze_freq = 1/50 + # buffer_reanalyze_freq = 1/10 + buffer_reanalyze_freq = 1/1000000000000 + # Each reanalyze process will reanalyze sequences ( transitions per sequence) reanalyze_batch_size = 160 # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
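+    # --- Added illustrative sketch (an assumption about the behaviour, not the exact
+    # `calculate_update_per_collect` helper used by the training entry): when update_per_collect
+    # is left as None, the number of gradient steps per collection is typically derived from the
+    # replay ratio and the amount of newly collected data, roughly: ---
+    def _approx_update_per_collect(transitions_per_collect: int, ratio: float) -> int:
+        return max(1, int(transitions_per_collect * ratio))
+
+    # e.g. 8 segments * 20 steps = 160 new transitions at replay_ratio 0.1 -> ~16 updates per collect.
+    assert _approx_update_per_collect(8 * 20, 0.1) == 16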
reanalyze_partition = 0.75 + # norm_type ="BN" + norm_type ="LN" + # ====== only for debug ===== # collector_env_num = 2 # num_segments = 2 @@ -42,25 +79,47 @@ def main(env_id, seed): stop_value=int(1e6), env_id=env_id, observation_shape=(3, 64, 64), + # observation_shape=(3, 96, 96), + gray_scale=False, collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, manager=dict(shared_memory=False, ), + collect_max_episode_steps=int(10000), + eval_max_episode_steps=int(10000), # TODO: only for debug # collect_max_episode_steps=int(50), # eval_max_episode_steps=int(50), ), policy=dict( - learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000 + # learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000 + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=100000, ), ), ), # 100k + # learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=50000, ), ), ), # 50k + # sample_type='episode', # NOTE: very important for memory env model=dict( observation_shape=(3, 64, 64), + # observation_shape=(3, 96, 96), action_space_size=action_space_size, reward_support_range=(-300., 301., 1.), value_support_range=(-300., 301., 1.), + norm_type=norm_type, + num_res_blocks=2, + num_channels=128, + # num_res_blocks=1, # TODO + # num_channels=64, world_model_cfg=dict( + game_segment_length=game_segment_length, + + encoder_type="resnet", #TODO======== + # encoder_type="dinov2", #TODO======== + + norm_type=norm_type, + support_size=601, policy_entropy_weight=5e-3, + # policy_entropy_weight=5e-2, # TODO(pu) mspacman + continuous_action_space=False, max_blocks=num_unroll_steps, max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action @@ -74,28 +133,226 @@ def main(env_id, seed): env_num=max(collector_env_num, evaluator_env_num), num_simulations=num_simulations, rotary_emb=False, + # rotary_emb=True, + # final_norm_option_in_encoder='LayerNorm_Tanh', + # final_norm_option_in_obs_head="LayerNorm", + # predict_latent_loss_type='mse', + + # final_norm_option_in_encoder='L2Norm', + # final_norm_option_in_obs_head="L2Norm", + # predict_latent_loss_type='mse', + + final_norm_option_in_encoder="LayerNorm", + final_norm_option_in_obs_head="LayerNorm", + predict_latent_loss_type='mse', + + # final_norm_option_in_encoder="SimNorm", + # final_norm_option_in_obs_head="SimNorm", + # predict_latent_loss_type='group_kl', + + # weight_decay=1e-2, + + # latent_norm_loss=True, + latent_norm_loss=False, + + # optim_type='AdamW_mix_lr_wdecay', + # # optim_type='AdamW', + # # weight_decay=1e-4, # TODO orig + # weight_decay=1e-3, # TODO: encoder 5*wd + + + use_priority=True, # TODO(pu): test + + # optim_type='SGD', + # learning_rate=0.01, + + # learning_rate=0.001, + learning_rate=0.0001, + + # entry_norm=True, # TODO======== + entry_norm=False, # TODO======== + + # use_temperature_scaling=True, + + use_temperature_scaling=False, # TODO======== + + # res_alha=True, + res_alha=False, # TODO======== + + + optim_type='AdamW_mix_lr_wdecay', # only for tsne plot + + # optim_type='AdamW_mix_lr', + # learning_rate=0.001, + ), ), + + eps=dict( + # (bool) Whether to use eps greedy exploration in collecting data. + # eps_greedy_exploration_in_collect=True, + eps_greedy_exploration_in_collect=False, + + # (str) The type of decaying epsilon. Options are 'linear', 'exp'. + type='linear', + # (float) The start value of eps. + start=1., + # (float) The end value of eps. 
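+            # (Added illustrative note) With the linear schedule configured here, epsilon at
+            # collected env step t is roughly
+            #     eps(t) = max(end, start - (start - end) * t / decay)
+            # i.e. it anneals from 1.0 down to 0.05 over about 2e4 env steps and then stays there.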
+ end=0.05, + # (int) The decay steps from start to end eps. + # decay=int(1e5), # 100k=1e5 + decay=int(2e4), # 20k=2e4 envsteps + # decay=int(2e5), # 200k=2e5 envsteps + + ), + + # (bool) 是否启用自适应策略熵权重 (alpha) + use_adaptive_entropy_weight=True, + # (float) 自适应alpha优化器的学习率 + adaptive_entropy_alpha_lr=1e-4, + target_entropy_start_ratio =0.98, + # target_entropy_end_ratio =0.9, + # target_entropy_end_ratio =0.7, + # target_entropy_decay_steps = 100000, # 例如,在100k次迭代后达到最终值 + target_entropy_end_ratio =0.5, # TODO===== + target_entropy_decay_steps = 400000, # 例如,在400k次迭代后达到最终值 + + + # ==================== START: Encoder-Clip Annealing Config ==================== + # (bool) 是否启用 encoder-clip 值的退火。 + use_encoder_clip_annealing=True, + # (str) 退火类型。可选 'linear' 或 'cosine'。 + encoder_clip_anneal_type='cosine', + # (float) 退火的起始 clip 值 (训练初期,较宽松)。 + encoder_clip_start_value=30.0, + # (float) 退火的结束 clip 值 (训练后期,较严格)。 + encoder_clip_end_value=10.0, + # (int) 完成从起始值到结束值的退火所需的训练迭代步数。 + encoder_clip_anneal_steps=100000, # 例如,在100k次迭代后达到最终值 + # encoder_clip_anneal_steps=400000, # 例如,在100k次迭代后达到最终值 + + + # policy_ls_eps_start=0.5, #TODO============= + # policy_ls_eps_start=0.1, #TODO============= + policy_ls_eps_start=0.05, #TODO============= good start in Pong and MsPacman + # policy_ls_eps_start=1, #TODO=========== + policy_ls_eps_end=0.01, + policy_ls_eps_decay_steps=50000, # 50k + + label_smoothing_eps=0.1, #TODO============= + + # label_smoothing_eps=0., + # policy_ls_eps_start=0.0, #TODO============= + + + # gradient_scale=True, #TODO + gradient_scale=False, #TODO # (str) The path of the pretrained model. If None, the model will be initialized by the default model. - model_path=None, - use_augmentation=False, + model_path=None, # TODO======= + # model_path="/mnt/nfs/zhangjinouwen/puyuan/LightZero/data_unizero_longrun_20250923/Qbert/Qbert_uz_targetentropy-alpha-098-07-100k-fix_encoder-clip30-10-100k_adamw1e-4_wd1e-2-encoder5times-tranwd-headnodecay_envnum8_brf1e-12-rbs160-rp0.75_nlayer2_numsegments-8_gsl20_rr0.1_Htrain10-Hinfer4_bs256_c25_seed0/ckpt/ckpt_best.pth.tar", + # model_path="/mnt/nfs/zhangjinouwen/puyuan/LightZero/data_unizero_longrun_20250923/Seaquest/Seaquest_uz_targetentropy-alpha-098-07-100k-fix_encoder-clip30-10-100k_adamw1e-4_wd1e-2-encoder5times-tranwd-headnodecay_envnum8_brf1e-12-rbs160-rp0.75_nlayer2_numsegments-8_gsl20_rr0.1_Htrain10-Hinfer4_bs256_c25_seed0/ckpt/ckpt_best.pth.tar", + use_augmentation=False, # TODO + # use_augmentation=True, # TODO======= + + + use_priority=True, # TODO(pu): test + priority_prob_alpha=1, + priority_prob_beta=1, + + # manual_temperature_decay=True, + # threshold_training_steps_for_final_temperature=int(5e4), # 50k iter 对应 500k envsteps + manual_temperature_decay=False, threshold_training_steps_for_final_temperature=int(2.5e4), - use_priority=False, + num_unroll_steps=num_unroll_steps, update_per_collect=None, replay_ratio=replay_ratio, batch_size=batch_size, - optim_type='AdamW', - learning_rate=0.0001, - num_simulations=num_simulations, + + piecewise_decay_lr_scheduler=False, + + cos_lr_scheduler=False, # TODO======== + # cos_lr_scheduler=True, # TODO======== + total_iterations=500000, + final_learning_rate=1e-6, + + optim_type='AdamW_mix_lr_wdecay', + # optim_type='AdamW', + # weight_decay=1e-4, # TODO orig + # weight_decay=1e-3, # TODO: encoder 5*wd + weight_decay=1e-2, # TODO: encoder 5*wd + + + # optim_type='AdamW', + # # learning_rate=0.001, + # learning_rate=0.0001, + + # ==================== [新增] 范数监控频率 ==================== + # 
每隔多少个训练迭代步数,监控一次模型参数的范数。设置为0则禁用。 + monitor_norm_freq=10000, + # monitor_norm_freq=2, + + # ============================================================ + + # latent_norm_clip_threshold=3, # 768dim + # latent_norm_clip_threshold=5, # 768dim latent encoder + latent_norm_clip_threshold=30, # 768dim latent encoder # for pong + + # latent_norm_clip_threshold=10, # 768dim latent encoder + + # latent_norm_clip_threshold=25, # 768dim latent encoder + + # latent_norm_clip_threshold=20, # 768dim + # latent_norm_clip_threshold=30, # 768dim + + # logit_clip_threshold=5, # value reward + # policy_logit_clip_threshold=1, # policy + + logit_clip_threshold=9999, # value reward + policy_logit_clip_threshold=99999, # policy + + # piecewise_decay_lr_scheduler=False, + # optim_type='AdamW_mix_lr', + # learning_rate=0.001, + + + # optim_type='SGD', # TODO + # piecewise_decay_lr_scheduler=True, + # # learning_rate=0.2, + # learning_rate=0.01, + + + # target_model_update_option="hard", + target_update_freq=100, + + target_model_update_option="soft", + # target_update_theta=0.005, # TODO + # target_update_theta=0.01, + target_update_theta=0.05, + + + + num_simulations=50, # for reanalyze + collect_num_simulations=collect_num_simulations, + eval_num_simulations=eval_num_simulations, num_segments=num_segments, td_steps=5, train_start_after_envsteps=0, + # train_start_after_envsteps=2000, # TODO game_segment_length=game_segment_length, grad_clip_value=5, - replay_buffer_size=int(1e6), - eval_freq=int(5e3), + + backbone_grad_clip_value=5, + # head_grad_clip_value=0.5, + head_grad_clip_value=5, # TODO + + # replay_buffer_size=int(1e6), + replay_buffer_size=int(5e5), # TODO + + # eval_freq=int(5e3), + eval_freq=int(1e4), # TODO + # eval_freq=int(2e4), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, # ============= The key different params for reanalyze ============= @@ -125,8 +382,36 @@ def main(env_id, seed): create_config = atari_unizero_create_config # ============ use muzero_segment_collector instead of muzero_collector ============= + from lzero.entry import train_unizero_segment - main_config.exp_name = f'data_lz/data_unizero/{env_id[:-14]}/{env_id[:-14]}_uz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' + main_config.exp_name = f'data_unizero_longrun_20251010/{env_id[:-14]}/{env_id[:-14]}_uz_targetentropy-alpha-098-05-400k-fix_encoder-clip30-10-100k_adamw1e-4_wd1e-2-encoder1-tran1-head1_envnum{collector_env_num}_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20251010/{env_id[:-14]}/{env_id[:-14]}_uz_targetentropy-alpha-098-05-400k-fix_encoder-clip30-10-400k_adamw1e-4_wd1e-2-encoder1-tran1-head1_envnum{collector_env_num}_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = 
f'data_unizero_longrun_20250923/{env_id[:-14]}/{env_id[:-14]}_uz_targetentropy-alpha-098-05-400k-fix_encoder-clip30-10-400k_adamw1e-4_wd1e-2-encoder5times-tranwd-headnodecay_envnum{collector_env_num}_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20250923/{env_id[:-14]}/{env_id[:-14]}_uz_targetentropy-alpha-098-07-100k-fix_encoder-clip30-10-100k_adamw1e-4_wd1e-2-encoder5times-tranwd-headnodecay_envnum{collector_env_num}_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20250923/{env_id[:-14]}/{env_id[:-14]}_uz_targetentropy-alpha-098-07-100k-fix_encoder-clip30-10-100k_adamw1e-4_wd1e-2-encoder5times-tranwd-headnodecay_envnum{collector_env_num}_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20250923/{env_id[:-14]}/{env_id[:-14]}_uz_2mckpt_targetentropy-alpha-098-07-100k-fix_encoder-clip30-10-100k_adamw1e-4_wd1e-2-encoder5times-tranwd-headnodecay_envnum{collector_env_num}_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20250923/{env_id[:-14]}/{env_id[:-14]}_uz_targetentropy-alpha-098-09-100k-fix_encoder-clip30-10-100k_adamw1e-4_wd1e-2-encoder5times-tranwd-headnodecay_envnum{collector_env_num}_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20250923/{env_id[:-14]}/{env_id[:-14]}_uz_targetentropy-alpha-200k-1-07_adamw1e-4_wd1e-2-encoder5times-tranwd-headnodecay_encoder-clip10_label-smooth-valuereward01-policy-005_envnum{collector_env_num}_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20250922/{env_id[:-14]}/{env_id[:-14]}_uz_pew0005_adamw1e-4_wd1e-2-encoder5times-tranwd-headnodecay_encoder-clip10_label-smooth-valuereward01-policy-005_envnum{collector_env_num}_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = 
f'data_unizero_longrun_20250922/{env_id[:-14]}/{env_id[:-14]}_uz_encoder-01lr-tran02lr_pew0005_adamw1e-4_wd1e-2-encoder5times-tranwd-headnodecay_encoder-clip10_label-smooth-valuereward01-policy-005_envnum{collector_env_num}_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
+    # main_config.exp_name = f'data_unizero_longrun_20250922/{env_id[:-14]}/{env_id[:-14]}_uz_lw-v1_pew0005_adamw1e-4_wd1e-2-encoder5times-tranwd-headnodecay_encoder-clip10_label-smooth-valuereward01-policy-005_envnum{collector_env_num}_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
+
+    # main_config.exp_name = f'data_unizero_longrun_20250922/{env_id[:-14]}/{env_id[:-14]}_uz_pew005_adamw1e-4_cosdecay500k-1e-6_wd1e-2-encoder5times-tranwd-headnodecay_encoder-clip10_label-smooth-valuereward01-policy-005_envnum{collector_env_num}_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
+
+    # main_config.exp_name = f'data_unizero_longrun_20250922/{env_id[:-14]}/{env_id[:-14]}_uz_eps20k_pew0005_adamw1e-4_wd1e-2-encoder5times-tranwd-headnodecay_encoder-clip10_label-smooth-valuereward01-policy-005_envnum{collector_env_num}_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
+
+
     train_unizero_segment([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
@@ -137,4 +422,71 @@ def main(env_id, seed):
     parser.add_argument('--seed', type=int, help='The seed to use', default=0)
     args = parser.parse_args()
+    # Four base environments from the Atari-8 test set
+    # args.env = 'PongNoFrameskip-v4'  # reactive environment, dense reward
+    # args.env = 'MsPacmanNoFrameskip-v4'  # memory/planning environment, sparse reward
+
+    # args.env = 'SeaquestNoFrameskip-v4'  # memory/planning environment, sparse reward
+    # args.env = 'HeroNoFrameskip-v4'  # memory/planning environment, sparse reward
+
+    # Two representative environments outside the Atari-8 set
+    args.env = 'QbertNoFrameskip-v4'  # memory/planning environment, sparse reward
+    # args.env = 'SpaceInvadersNoFrameskip-v4'  # memory/planning environment, sparse reward
+
+    # Environments that already perform well
+    # args.env = 'BoxingNoFrameskip-v4'  # reactive environment, dense reward
+    # args.env = 'ChopperCommandNoFrameskip-v4'
+
+    # args.env = 'AlienNoFrameskip-v4'
+    # args.env = 'RoadRunnerNoFrameskip-v4'
+
+    # args.env = 'BeamRiderNoFrameskip-v4'
+    # args.env = 'GravitarNoFrameskip-v4'
+    # args.env = 'BreakoutNoFrameskip-v4'
+
+    args.seed = 0
+
+    main(args.env, args.seed)
+
+    """
+    export CUDA_VISIBLE_DEVICES=0
+    cd /mnt/nfs/zhangjinouwen/puyuan/LightZero
+    python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py > /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/logs/wd1e-2-encoder1-trans1-head1_targetentropy-alpha-400k-098-05-encoder-clip30-10-100k-qbert.log 2>&1
+
+    python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py > /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/logs/unizero_adamw1e-4_64_encoderdinov2_mspac.log 2>&1
+
+    python 
/mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py > /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/logs/unizero_adamw1e-4_64_clip-encoder5-value5-policy1_fix-clip_msp.log 2>&1 + + + python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py > /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/logs/unizero_adamw1e-4_64_encoder-clip-30_fix-clip_resalpha_pong.log 2>&1 + + python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py > /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/logs/unizero_adamw1e-4_64_encoder-clip-30_fix-clip_reinit-value-reward-policy-50k_pong.log 2>&1 + + python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py > /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/logs/unizero_adamw1e-4_64_encoder-clip-5_fix-clip_temp-scale-softplus-fixcollecteval_reinit-value-reward-policy-50k_pong.log 2>&1 + + + python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py > /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/logs/unizero_adamw1e-4_64_encoder-clip-5-true_head-clip10-pol5_fix-clip_mz-head_reinit-value-reward-policy-50k_msp.log 2>&1 + + python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py > /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/logs/unizero_adamw1e-4_64_encoder-clip-5_head-clip10-pol5_fix-clip_reinit-value-reward-policy-50k.log 2>&1 + + python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py > /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/logs/unizero_adamw1e-4_64_encoder-clip-5_entry-norm_clipgrad-backbone5-head5_reinit-value-reward-policy-50k_head-clip10-pol05_msp.log 2>&1 + + python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py > /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/logs/unizero_adamw1e-4_64_encoder-clip-5_entry-norm_clipgrad-backbone5-head05_grad-scale_reinit-value-reward-policy-50k_head-clip10-pol05_pong.log 2>&1 + + python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py > /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/logs/unizero_adamw1e-4_96.log 2>&1 + + python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py > /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/logs/unizero_sgd_02-0002.log 2>&1 + python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py > /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/logs/unizero_sgd_001.log 2>&1 + + python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py > /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/logs/unizero_adamw1e-3.log 2>&1 + python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py > /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/logs/unizero_adamw1e-4_96.log 2>&1 + + python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py > /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/logs/unizero_adamw-mix-1e-3.log 2>&1 + + + """ diff --git a/zoo/atari/config/atari_unizero_segment_from_buffer_config.py b/zoo/atari/config/atari_unizero_segment_from_buffer_config.py new file mode 100644 index 000000000..94032955d --- /dev/null +++ b/zoo/atari/config/atari_unizero_segment_from_buffer_config.py @@ -0,0 +1,314 @@ +from easydict import EasyDict +from 
zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map +# # 在main文件开始,通过全局变量来控制是否处于调试状态 +# global DEBUG_ENABLED;DEBUG_ENABLED = True + +def main(env_id, seed): + action_space_size = atari_env_action_space_map[env_id] + + # ============================================================== + # begin of the most frequently changed config specified by the user + # ============================================================== + collector_env_num = 8 + num_segments = 8 + evaluator_env_num = 3 + + # collector_env_num = 1 + # num_segments = 1 + # evaluator_env_num = 1 + + num_simulations = 50 + collect_num_simulations = 25 + # collect_num_simulations = 50 + eval_num_simulations = 50 + max_env_step = int(5e6) + # max_env_step = int(50e6) + batch_size = 256 + # batch_size = 64 # debug + # batch_size = 4 # debug + + num_layers = 2 + # replay_ratio = 0.25 + replay_ratio = 0.1 + + game_segment_length = 20 + num_unroll_steps = 10 + infer_context_length = 4 + + # game_segment_length = 40 + # num_unroll_steps = 20 + # infer_context_length = 8 + + # game_segment_length = 200 + # num_unroll_steps = 16 + # infer_context_length = 8 + + # num_unroll_steps = 4 # TODO + # infer_context_length = 2 + + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + # buffer_reanalyze_freq = 1/50 + # buffer_reanalyze_freq = 1/10 + buffer_reanalyze_freq = 1/1000000000000 + + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size = 160 + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. + reanalyze_partition = 0.75 + + # norm_type ="BN" + norm_type ="LN" + + # ====== only for debug ===== + # collector_env_num = 2 + # num_segments = 2 + # evaluator_env_num = 2 + # num_simulations = 10 + # batch_size = 5 + # ============================================================== + # end of the most frequently changed config specified by the user + # ============================================================== + + atari_unizero_config = dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + # TODO: only for debug + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + ), + policy=dict( + # learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000 + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=100000, ), ), ), # 100k + # sample_type='episode', # NOTE: very important for memory env + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), + norm_type=norm_type, + + world_model_cfg=dict( + encoder_type="resnet", #TODO======== + game_segment_length=game_segment_length, + + norm_type=norm_type, + num_res_blocks=2, + num_channels=128, + # num_res_blocks=1, # TODO + # num_channels=64, + support_size=601, + policy_entropy_weight=5e-3, + # policy_entropy_weight=5e-2, # TODO(pu) + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + 
num_layers=num_layers, + num_heads=8, + embed_dim=768, + obs_type='image', + env_num=max(collector_env_num, evaluator_env_num), + num_simulations=num_simulations, + rotary_emb=False, + # rotary_emb=True, + # final_norm_option_in_encoder='LayerNorm_Tanh', + # final_norm_option_in_obs_head="LayerNorm", + # predict_latent_loss_type='mse', + + # final_norm_option_in_encoder='L2Norm', + # final_norm_option_in_obs_head="L2Norm", + # predict_latent_loss_type='mse', + + final_norm_option_in_encoder="LayerNorm", + final_norm_option_in_obs_head="LayerNorm", + predict_latent_loss_type='mse', + + # final_norm_option_in_encoder="SimNorm", + # final_norm_option_in_obs_head="SimNorm", + # predict_latent_loss_type='group_kl', + + # weight_decay=1e-2, + # latent_norm_loss=True, + latent_norm_loss=False, + + + # latent_norm_loss=False, + weight_decay=1e-4, # TODO + + use_priority=True, # TODO(pu): test + # entry_norm=True, # TODO======== + entry_norm=False, # TODO======== + use_temperature_scaling=False, # TODO======== + res_alha=False, # TODO======== + + ), + ), + # policy_ls_eps_start=0.5, #TODO============= + # policy_ls_eps_start=0.1, #TODO============= + # policy_ls_eps_start=0.0, #TODO============= + policy_ls_eps_start=0.05, #TODO============= + + policy_ls_eps_end=0.01, + policy_ls_eps_decay_steps=50000, # 50k + + label_smoothing_eps=0.1, #TODO============= + # label_smoothing_eps=0., + + # gradient_scale=True, #TODO + gradient_scale=False, #TODO + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. + model_path=None, + use_augmentation=False, # TODO + + use_priority=True, # TODO(pu): test + priority_prob_alpha=1, + priority_prob_beta=1, + + manual_temperature_decay=False, + threshold_training_steps_for_final_temperature=int(2.5e4), + num_unroll_steps=num_unroll_steps, + update_per_collect=None, + replay_ratio=replay_ratio, + batch_size=batch_size, + optim_type='AdamW', + # target_model_update_option="hard", + target_update_freq=100, + + target_model_update_option="soft", + # target_update_theta=0.005, # TODO + # target_update_theta=0.01, + target_update_theta=0.05, + learning_rate=0.0001, + + # # learning_rate=0.0003, # TODO + # # latent_norm_clip_threshold=3, # 768dim + latent_norm_clip_threshold=30, # 768dim latent + # logit_clip_threshold=10, # value reward + # # policy_logit_clip_threshold=0.5, # policy + # policy_logit_clip_threshold=5, # policy + + + # latent_norm_clip_threshold=999, # 768dim latent + logit_clip_threshold=999, # value reward + policy_logit_clip_threshold=999, # policy + + # piecewise_decay_lr_scheduler=False, + # optim_type='AdamW_mix_lr', + # learning_rate=0.001, + + backbone_grad_clip_value=5, + # head_grad_clip_value=0.5, + head_grad_clip_value=5, # TODO + + # replay_buffer_size=int(1e6), + replay_buffer_size=int(5e5), # TODO + + num_simulations=50, # for reanalyze + collect_num_simulations=collect_num_simulations, + eval_num_simulations=eval_num_simulations, + num_segments=num_segments, + td_steps=5, + train_start_after_envsteps=0, + # train_start_after_envsteps=2000, # TODO + game_segment_length=game_segment_length, + grad_clip_value=5, + + + # eval_freq=int(5e3), + eval_freq=int(2e3), # TODO + + # eval_freq=int(1e4), # TODO + # eval_freq=int(2e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. 
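+        # [Editor's note] With buffer_reanalyze_freq set to 1/1e12 above, periodic buffer reanalysis is
+        # effectively disabled for this run; values such as 1/10 or 1/50 (see the commented alternatives)
+        # re-enable it at the corresponding rate.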
+ buffer_reanalyze_freq=buffer_reanalyze_freq, + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size=reanalyze_batch_size, + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. + reanalyze_partition=reanalyze_partition, + ), + ) + atari_unizero_config = EasyDict(atari_unizero_config) + main_config = atari_unizero_config + + atari_unizero_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), + ) + atari_unizero_create_config = EasyDict(atari_unizero_create_config) + create_config = atari_unizero_create_config + + # ============ use muzero_segment_collector instead of muzero_collector ============= + from lzero.entry import train_unizero_segment_from_buffer + main_config.exp_name = f'data_unizero_longrun_from_buffer_20250917/{env_id[:-14]}/{env_id[:-14]}_uz_orighead_label-smooth-valuereward01-policy-005_encoder-clip30_clear{game_segment_length}_spsi{game_segment_length}_envnum{collector_env_num}_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_from_buffer_20250917/{env_id[:-14]}/{env_id[:-14]}_uz_muzerohead_noclip_clear{game_segment_length}_spsi{game_segment_length}_envnum{collector_env_num}_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_from_buffer_20250917/{env_id[:-14]}/{env_id[:-14]}_uz_in-value-reward-head-ln2_per_lnlw1e-4_enc-LN_fix-init-recur_encoder-head-ln_clear{game_segment_length}_originlossweight_spsi{game_segment_length}_envnum{collector_env_num}_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + # main_config.exp_name = f'data_unizero_longrun_20250901_debug/{env_id[:-14]}/{env_id[:-14]}_uz_per_lnlw1e-4_enc-BN_fix-init-recur_encoder-head-ln_clear{game_segment_length}_originlossweight_spsi{game_segment_length}_envnum{collector_env_num}_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20250901/{env_id[:-14]}/{env_id[:-14]}_uz_per_grad-scale_pew005_fix-init-recur_encoder-head-ln_clear{game_segment_length}_originlossweight_spsi{game_segment_length}_envnum{collector_env_num}_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # 
main_config.exp_name = f'data_unizero_longrun_20250827/{env_id[:-14]}/{env_id[:-14]}_uz_per_lnlw0001_pew005_fix-init-recur_encoder-head-ln_clear{game_segment_length}_originlossweight_spsi{game_segment_length}_envnum{collector_env_num}_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20250827/{env_id[:-14]}/{env_id[:-14]}_uz_current-next-latent_fix-init-recur_encoder-head-simnorm_clear{game_segment_length}_originlossweight_spsi{game_segment_length}_envnum{collector_env_num}_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20250827/{env_id[:-14]}/{env_id[:-14]}_uz_current-next-latent_latent-norm-weight001_fix-init-recur_clear{game_segment_length}_originlossweight_spsi{game_segment_length}_envnum{collector_env_num}_encoder-head-ln_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c50_seed{seed}' + + # expert_buffer_path="data_muzero_20250910/MsPacman/MsPacman_mz_brf0.02-rbs160-rp0.75_numsegments-8_gsl20_rr0.1_Htrain10_bs256_csim25-esim50_rgb_seed0_250910_154414/game_buffers/muzero_game_buffer_iter_32.pth" + # expert_buffer_path="/mnt/nfs/zhangjinouwen/puyuan/LightZero/data_muzero_20250910_save_buffer/MsPacman/MsPacman_mz_brf0.02-rbs160-rp0.75_numsegments-8_gsl20_rr0.1_Htrain10_bs256_csim25-esim50_rgb_seed0/game_buffers/muzero_game_buffer_iter_100000.pth" + # expert_buffer_path="/mnt/nfs/zhangjinouwen/puyuan/LightZero/data_muzero_20250917_save_buffer/MsPacman/MsPacman_mz_brf1e-10-rbs160-rp0.75_numsegments-8_gsl20_rr0.1_Htrain10_bs256_csim25-esim50_rgb_seed0_250916_171727/game_buffers/muzero_game_buffer_iter_120.pth" + # expert_buffer_path="/mnt/nfs/zhangjinouwen/puyuan/LightZero/data_muzero_20250917_save_buffer/MsPacman/MsPacman_mz_brf1e-10-rbs160-rp0.75_numsegments-8_gsl20_rr0.1_Htrain10_bs256_csim50-esim50_rgb_seed0/game_buffers/muzero_game_buffer_iter_16.pth" + expert_buffer_path="/mnt/nfs/zhangjinouwen/puyuan/LightZero/data_muzero_20250917_save_buffer/MsPacman/MsPacman_mz_brf1e-10-rbs160-rp0.75_numsegments-8_gsl20_rr0.1_Htrain10_bs256_csim50-esim50_rgb_seed0_250916_193157/game_buffers/muzero_game_buffer_iter_6845.pth" + train_unizero_segment_from_buffer([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step, expert_buffer_path=expert_buffer_path) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='Process different environments and seeds.') + parser.add_argument('--env', type=str, help='The environment to use', default='PongNoFrameskip-v4') + parser.add_argument('--seed', type=int, help='The seed to use', default=0) + args = parser.parse_args() + + # args.env = 'PongNoFrameskip-v4' + + args.env = 'MsPacmanNoFrameskip-v4' + + # args.env = 'QbertNoFrameskip-v4' + # args.env = 'SeaquestNoFrameskip-v4' + + # args.env = 'SpaceInvadersNoFrameskip-v4' + # args.env = 
'BeamRiderNoFrameskip-v4' + # args.env = 'GravitarNoFrameskip-v4' + + # args.env = 'BreakoutNoFrameskip-v4' + + + args.seed = 0 + + + main(args.env, args.seed) + + """ + export CUDA_VISIBLE_DEVICES=3 + cd /mnt/nfs/zhangjinouwen/puyuan/LightZero + python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_from_buffer_config.py + """ diff --git a/zoo/atari/config/atari_unizero_segment_stack4_config.py b/zoo/atari/config/atari_unizero_segment_stack4_config.py new file mode 100644 index 000000000..73052dca6 --- /dev/null +++ b/zoo/atari/config/atari_unizero_segment_stack4_config.py @@ -0,0 +1,281 @@ +from easydict import EasyDict +from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map +# # 在main文件开始,通过全局变量来控制是否处于调试状态 +# global DEBUG_ENABLED;DEBUG_ENABLED = True + +def main(env_id, seed): + action_space_size = atari_env_action_space_map[env_id] + + # ============================================================== + # begin of the most frequently changed config specified by the user + # ============================================================== + collector_env_num = 8 + num_segments = 8 + evaluator_env_num = 3 + + # collector_env_num = 1 + # num_segments = 1 + # evaluator_env_num = 1 + + num_simulations = 50 + collect_num_simulations = 25 + # collect_num_simulations = 50 + eval_num_simulations = 50 + # max_env_step = int(5e5) + max_env_step = int(50e6) + batch_size = 256 + # batch_size = 64 # debug + # batch_size = 4 # debug + + num_layers = 2 + # replay_ratio = 0.25 + replay_ratio = 0.1 + + game_segment_length = 20 + num_unroll_steps = 10 + infer_context_length = 4 + + # game_segment_length = 40 + # num_unroll_steps = 20 + # infer_context_length = 8 + + # game_segment_length = 200 + # num_unroll_steps = 16 + # infer_context_length = 8 + + # num_unroll_steps = 4 # TODO + # infer_context_length = 2 + + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + # buffer_reanalyze_freq = 1/50 + # buffer_reanalyze_freq = 1/10 + buffer_reanalyze_freq = 1/1000000000000 + + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size = 160 + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
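+    # [Editor's note] Following the convention described above, the value of 0.75 set below means reanalysis
+    # samples are drawn from the first three quarters of the buffer; 1.0 would sample from the whole buffer.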
+ reanalyze_partition = 0.75 + + norm_type ="BN" + # norm_type ="LN" + + # ====== only for debug ===== + # collector_env_num = 2 + # num_segments = 2 + # evaluator_env_num = 2 + # num_simulations = 10 + # batch_size = 5 + # ============================================================== + # end of the most frequently changed config specified by the user + # ============================================================== + + atari_unizero_config = dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(12, 64, 64), + image_channel=3, + gray_scale=False, + frame_stack_num=4, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + # TODO: only for debug + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + ), + policy=dict( + # learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000 + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=100000, ), ), ), # 100k + # sample_type='episode', # NOTE: very important for memory env + model=dict( + # observation_shape=(3, 64, 64), + + observation_shape=(12, 64, 64), + image_channel=3, + gray_scale=False, + frame_stack_num=4, + + action_space_size=action_space_size, + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), + world_model_cfg=dict( + observation_shape=(12, 64, 64), + image_channel=3, + gray_scale=False, + frame_stack_num=4, + + game_segment_length=game_segment_length, + + norm_type=norm_type, + num_res_blocks=2, + num_channels=128, + # num_res_blocks=1, # TODO + # num_channels=64, + support_size=601, + policy_entropy_weight=5e-3, + # policy_entropy_weight=5e-2, # TODO(pu) + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=8, + embed_dim=768, + obs_type='image', + env_num=max(collector_env_num, evaluator_env_num), + num_simulations=num_simulations, + rotary_emb=False, + # rotary_emb=True, + # final_norm_option_in_encoder='LayerNorm_Tanh', + # final_norm_option_in_obs_head="LayerNorm", + # predict_latent_loss_type='mse', + + # final_norm_option_in_encoder='L2Norm', + # final_norm_option_in_obs_head="L2Norm", + # predict_latent_loss_type='mse', + + final_norm_option_in_encoder="LayerNorm", + final_norm_option_in_obs_head="LayerNorm", + predict_latent_loss_type='mse', + + # final_norm_option_in_encoder="SimNorm", + # final_norm_option_in_obs_head="SimNorm", + # predict_latent_loss_type='group_kl', + + # weight_decay=1e-2, + latent_norm_loss=True, + + # latent_norm_loss=False, + weight_decay=1e-4, # TODO + + use_priority=True, # TODO(pu): test + + ), + ), + # gradient_scale=True, #TODO + gradient_scale=False, #TODO + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. 
+ model_path=None, + use_augmentation=False, # TODO + + use_priority=True, # TODO(pu): test + priority_prob_alpha=1, + priority_prob_beta=1, + + manual_temperature_decay=False, + threshold_training_steps_for_final_temperature=int(2.5e4), + num_unroll_steps=num_unroll_steps, + update_per_collect=None, + replay_ratio=replay_ratio, + batch_size=batch_size, + optim_type='AdamW', + # target_model_update_option="hard", + target_update_freq=100, + + target_model_update_option="soft", + # target_update_theta=0.005, # TODO + # target_update_theta=0.01, + target_update_theta=0.05, + + learning_rate=0.0001, + # learning_rate=0.0003, # TODO + + num_simulations=50, # for reanalyze + collect_num_simulations=collect_num_simulations, + eval_num_simulations=eval_num_simulations, + num_segments=num_segments, + td_steps=5, + train_start_after_envsteps=0, + # train_start_after_envsteps=2000, # TODO + game_segment_length=game_segment_length, + grad_clip_value=5, + replay_buffer_size=int(1e6), + eval_freq=int(5e3), + # eval_freq=int(1e4), # TODO + # eval_freq=int(2e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq=buffer_reanalyze_freq, + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size=reanalyze_batch_size, + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. + reanalyze_partition=reanalyze_partition, + ), + ) + atari_unizero_config = EasyDict(atari_unizero_config) + main_config = atari_unizero_config + + atari_unizero_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), + ) + atari_unizero_create_config = EasyDict(atari_unizero_create_config) + create_config = atari_unizero_create_config + + # ============ use muzero_segment_collector instead of muzero_collector ============= + from lzero.entry import train_unizero_segment + main_config.exp_name = f'data_unizero_longrun_20250901/{env_id[:-14]}/{env_id[:-14]}_uz_in-value-reward-head-ln2_stack4_per_lnlw1e-4_enc-BN_fix-init-recur_encoder-head-ln_clear{game_segment_length}_originlossweight_spsi{game_segment_length}_envnum{collector_env_num}_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20250901/{env_id[:-14]}/{env_id[:-14]}_uz_per_grad-scale_pew005_fix-init-recur_encoder-head-ln_clear{game_segment_length}_originlossweight_spsi{game_segment_length}_envnum{collector_env_num}_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = 
f'data_unizero_longrun_20250827/{env_id[:-14]}/{env_id[:-14]}_uz_per_lnlw0001_pew005_fix-init-recur_encoder-head-ln_clear{game_segment_length}_originlossweight_spsi{game_segment_length}_envnum{collector_env_num}_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20250827/{env_id[:-14]}/{env_id[:-14]}_uz_current-next-latent_fix-init-recur_encoder-head-simnorm_clear{game_segment_length}_originlossweight_spsi{game_segment_length}_envnum{collector_env_num}_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20250827/{env_id[:-14]}/{env_id[:-14]}_uz_current-next-latent_latent-norm-weight001_fix-init-recur_clear{game_segment_length}_originlossweight_spsi{game_segment_length}_envnum{collector_env_num}_encoder-head-ln_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c50_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20250827/{env_id[:-14]}/{env_id[:-14]}_uz_current-next-latent_fix-init-recur_clear{game_segment_length}_originlossweight_spsi{game_segment_length}_envnum{collector_env_num}_encoder-head-ln_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c50_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20250827/{env_id[:-14]}/{env_id[:-14]}_uz_channel64_fix-init-recur_clear{game_segment_length}_originlossweight_spsi{game_segment_length}_envnum{collector_env_num}_encoder-head-ln_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c50_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20250827/{env_id[:-14]}/{env_id[:-14]}_uz_wd1e-2_fix-init-recur_clear{game_segment_length}_originlossweight_spsi{game_segment_length}_envnum{collector_env_num}_encoder-head-ln_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20250812/{env_id[:-14]}/{env_id[:-14]}_uz_LN-noaffine_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + # main_config.exp_name = 
f'data_unizero_longrun_20250812/{env_id[:-14]}/{env_id[:-14]}_uz_simnorm_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + train_unizero_segment([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='Process different environments and seeds.') + parser.add_argument('--env', type=str, help='The environment to use', default='PongNoFrameskip-v4') + parser.add_argument('--seed', type=int, help='The seed to use', default=0) + args = parser.parse_args() + + # args.env = 'PongNoFrameskip-v4' + + args.env = 'MsPacmanNoFrameskip-v4' + # args.env = 'QbertNoFrameskip-v4' + # args.env = 'SeaquestNoFrameskip-v4' + + # args.env = 'SpaceInvadersNoFrameskip-v4' + + # args.env = 'BeamRiderNoFrameskip-v4' + # args.env = 'GravitarNoFrameskip-v4' + + # args.env = 'BreakoutNoFrameskip-v4' + + + args.seed = 0 + + + main(args.env, args.seed) + + """ + export CUDA_VISIBLE_DEVICES=5 + cd /fs-computility/niuyazhe/puyuan/code/LightZero + python /fs-computility/niuyazhe/puyuan/code/LightZero/zoo/atari/config/atari_unizero_segment_config.py + """ diff --git a/zoo/atari/config/atari_unizero_segment_stack4_config_bkp.py b/zoo/atari/config/atari_unizero_segment_stack4_config_bkp.py new file mode 100644 index 000000000..8398a5bfe --- /dev/null +++ b/zoo/atari/config/atari_unizero_segment_stack4_config_bkp.py @@ -0,0 +1,245 @@ +from easydict import EasyDict +from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map +# # 在main文件开始,通过全局变量来控制是否处于调试状态 +# global DEBUG_ENABLED;DEBUG_ENABLED = True + +def main(env_id, seed): + action_space_size = atari_env_action_space_map[env_id] + + # ============================================================== + # begin of the most frequently changed config specified by the user + # ============================================================== + # collector_env_num = 8 + # num_segments = 8 + # evaluator_env_num = 3 + + collector_env_num = 1 + num_segments = 1 + evaluator_env_num = 1 + + game_segment_length = 20 + num_simulations = 50 + collect_num_simulations = 25 + eval_num_simulations = 50 + # max_env_step = int(5e5) + max_env_step = int(50e6) + batch_size = 256 + # batch_size = 64 + num_layers = 2 + # replay_ratio = 0.25 + replay_ratio = 0.1 + + num_unroll_steps = 10 + infer_context_length = 4 + + # num_unroll_steps = 4 # TODO + # infer_context_length = 2 + + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + # buffer_reanalyze_freq = 1/50 + # buffer_reanalyze_freq = 1/10 + buffer_reanalyze_freq = 1/1000000000000 + + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size = 160 + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
+ reanalyze_partition = 0.75 + + # norm_type ="BN" + norm_type ="LN" + + + # ====== only for debug ===== + # collector_env_num = 2 + # num_segments = 2 + # evaluator_env_num = 2 + # num_simulations = 10 + # batch_size = 5 + # ============================================================== + # end of the most frequently changed config specified by the user + # ============================================================== + + atari_unizero_config = dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + # observation_shape=(3, 64, 64), + # gray_scale=False, + + observation_shape=(12, 64, 64), + image_channel=3, + gray_scale=False, + frame_stack_num=4, + + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + # TODO: only for debug + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + ), + policy=dict( + # learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000 + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=100000, ), ), ), # 100k + model=dict( + # observation_shape=(3, 64, 64), + + observation_shape=(12, 64, 64), + image_channel=3, + gray_scale=False, + frame_stack_num=4, + + action_space_size=action_space_size, + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), + world_model_cfg=dict( + observation_shape=(12, 64, 64), + image_channel=3, + gray_scale=False, + frame_stack_num=4, + + norm_type=norm_type, + num_res_blocks=2, + num_channels=128, + support_size=601, + policy_entropy_weight=5e-3, + # policy_entropy_weight=5e-2, # TODO(pu) + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=8, + embed_dim=768, + obs_type='image', + env_num=max(collector_env_num, evaluator_env_num), + num_simulations=num_simulations, + rotary_emb=False, + # final_norm_option_in_encoder='LayerNorm_Tanh', + # final_norm_option_in_obs_head="LayerNorm", + # predict_latent_loss_type='mse', + + # final_norm_option_in_encoder='L2Norm', + # final_norm_option_in_obs_head="L2Norm", + # predict_latent_loss_type='mse', + + final_norm_option_in_encoder="LayerNorm", + final_norm_option_in_obs_head="LayerNorm", + predict_latent_loss_type='mse', + + # final_norm_option_in_encoder="SimNorm", + # final_norm_option_in_obs_head="SimNorm", + # predict_latent_loss_type='group_kl', + ), + ), + # gradient_scale=True, #TODO + gradient_scale=False, #TODO + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. 
+ model_path=None, + use_augmentation=False, # TODO + use_priority=False, + manual_temperature_decay=False, + threshold_training_steps_for_final_temperature=int(2.5e4), + num_unroll_steps=num_unroll_steps, + update_per_collect=None, + replay_ratio=replay_ratio, + batch_size=batch_size, + optim_type='AdamW', + # target_model_update_option="hard", + target_update_freq=100, + + target_model_update_option="soft", + # target_update_theta=0.005, # TODO + # target_update_theta=0.01, + target_update_theta=0.05, + + learning_rate=0.0001, + num_simulations=50, # for reanalyze + collect_num_simulations=collect_num_simulations, + eval_num_simulations=eval_num_simulations, + num_segments=num_segments, + td_steps=5, + # train_start_after_envsteps=0, + train_start_after_envsteps=2000, + game_segment_length=game_segment_length, + grad_clip_value=5, + replay_buffer_size=int(1e6), + # eval_freq=int(5e3), + # eval_freq=int(1e4), + eval_freq=int(2e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq=buffer_reanalyze_freq, + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size=reanalyze_batch_size, + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. + reanalyze_partition=reanalyze_partition, + ), + ) + atari_unizero_config = EasyDict(atari_unizero_config) + main_config = atari_unizero_config + + atari_unizero_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), + ) + atari_unizero_create_config = EasyDict(atari_unizero_create_config) + create_config = atari_unizero_create_config + + # ============ use muzero_segment_collector instead of muzero_collector ============= + from lzero.entry import train_unizero_segment + main_config.exp_name = f'data_unizero_longrun_20250819/{env_id[:-14]}/{env_id[:-14]}_uz_stack4_envnum1_matchvalue-none_fixreset_encoder-LN-head-LN_soft-target-005_encoder-LN_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20250812/{env_id[:-14]}/{env_id[:-14]}_uz_encoder-LN-head-LN_soft-target-005_fix-reset-v2_collect-forward-noreset_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20250812/{env_id[:-14]}/{env_id[:-14]}_uz_encoder-LN-head-LN_soft-target-005_encoder-LN_act-pos-maxnorm1-encoder-l2norm_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + # main_config.exp_name = 
f'data_unizero_longrun_20250812/{env_id[:-14]}/{env_id[:-14]}_uz_encoder-LN-head-LN_soft-target-005_encoder-LN_act-pos-maxnorm1_muzero-loss-weight__brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20250812/{env_id[:-14]}/{env_id[:-14]}_uz_encoder-LN-head-LN-gradscale_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + + # main_config.exp_name = f'data_unizero_longrun_20250812/{env_id[:-14]}/{env_id[:-14]}_uz_LN-noaffine_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + # main_config.exp_name = f'data_unizero_longrun_20250812/{env_id[:-14]}/{env_id[:-14]}_uz_simnorm_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}' + train_unizero_segment([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='Process different environments and seeds.') + parser.add_argument('--env', type=str, help='The environment to use', default='PongNoFrameskip-v4') + parser.add_argument('--seed', type=int, help='The seed to use', default=0) + args = parser.parse_args() + + args.env = 'MsPacmanNoFrameskip-v4' + # args.env = 'QbertNoFrameskip-v4' + + # args.env = 'SpaceInvadersNoFrameskip-v4' + # args.env = 'BeamRiderNoFrameskip-v4' + # args.env = 'GravitarNoFrameskip-v4' + + + # args.env = 'SeaquestNoFrameskip-v4' + # args.env = 'BreakoutNoFrameskip-v4' + + + args.seed = 0 + + + main(args.env, args.seed) + + """ + export CUDA_VISIBLE_DEVICES=7 + cd /fs-computility/niuyazhe/puyuan/code/LightZero + python /fs-computility/niuyazhe/puyuan/code/LightZero/zoo/atari/config/atari_unizero_segment_stack4_config.py + """ diff --git a/zoo/atari/envs/atari_lightzero_env.py b/zoo/atari/envs/atari_lightzero_env.py index 8bc491674..1445c1167 100644 --- a/zoo/atari/envs/atari_lightzero_env.py +++ b/zoo/atari/envs/atari_lightzero_env.py @@ -175,10 +175,13 @@ def step(self, action: int) -> BaseEnvTimestep: self.reward = np.array(reward).astype(np.float32) self._eval_episode_return += self.reward self._timestep += 1 - # logging.info(f'self._timestep: {self._timestep}') + if self._timestep %100==0: + logging.info(f'self._timestep: {self._timestep}') observation = self.observe() if done: logging.info(f'one episode done! total episode length is: {self._timestep}') + logging.info(f'one episode done! 
self._eval_episode_return is: {self._eval_episode_return}') + info['eval_episode_return'] = self._eval_episode_return return BaseEnvTimestep(observation, self.reward, done, info) @@ -254,8 +257,17 @@ def create_collector_env_cfg(cfg: dict) -> List[dict]: collector_env_num = cfg.pop('collector_env_num') cfg = copy.deepcopy(cfg) cfg.max_episode_steps = cfg.collect_max_episode_steps + cfg.episode_life = True cfg.clip_rewards = True + + # only for save buffer TODO ================== + # cfg.episode_life = False + # cfg.clip_rewards = False + + # cfg.episode_life = False + # cfg.clip_rewards = True + return [cfg for _ in range(collector_env_num)] @staticmethod diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py index f25eead6e..ae5153ebc 100644 --- a/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py +++ b/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py @@ -74,6 +74,7 @@ def main(env_id, seed): continuous_action_space=continuous_action_space, num_of_sampled_actions=K, model_type='mlp', + norm_type=norm_type, world_model_cfg=dict( policy_loss_type='kl', obs_type='vector',