diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py index f17126527..f3a72c49f 100644 --- a/lzero/entry/__init__.py +++ b/lzero/entry/__init__.py @@ -10,4 +10,5 @@ from .train_rezero import train_rezero from .train_unizero import train_unizero from .train_unizero_segment import train_unizero_segment +from .train_unizero_segment_with_reward_model import train_unizero_segment_with_reward_model from .utils import * diff --git a/lzero/entry/train_unizero_segment_with_reward_model.py b/lzero/entry/train_unizero_segment_with_reward_model.py new file mode 100644 index 000000000..99e0f720d --- /dev/null +++ b/lzero/entry/train_unizero_segment_with_reward_model.py @@ -0,0 +1,258 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional + +import torch +import wandb +from ding.config import compile_config +from ding.envs import create_env_manager +from ding.envs import get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import EasyTimer +from ding.utils import set_pkg_seed, get_rank, get_world_size +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter +from torch.utils.tensorboard import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.policy.random_policy import LightZeroRandomPolicy +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector +from lzero.reward_model.rnd_reward_model import RNDRewardModel +from .utils import random_collect, calculate_update_per_collect + +timer = EasyTimer() + +def train_unizero_segment_with_reward_model( + input_cfg: Tuple[dict, dict], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + The train entry for UniZero (with muzero_segment_collector and buffer reanalyze trick), proposed in our paper UniZero: Generalized and Efficient Planning with Scalable Latent World Models. + UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing the limitations found in MuZero-style algorithms, + particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667. + Arguments: + - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type. + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): The pretrained model path, which should + point to the ckpt file of the pretrained model, and an absolute path is recommended. + In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. + - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. + - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps. + Returns: + - policy (:obj:`Policy`): Converged policy. 
+ """ + + cfg, create_cfg = input_cfg + + # Ensure the specified policy type is supported + assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero_segment_with_reward_model entry currently only supports the following algorithms: 'unizero', 'sampled_unizero'" + assert cfg.policy.use_rnd_model, "cfg.policy.use_rnd_model must be True to use the RND reward model" + + # Import the correct GameBuffer class based on the policy type + game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'} + + GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]), + game_buffer_classes[create_cfg.policy.type]) + + # Set device based on CUDA availability + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f'cfg.policy.device: {cfg.policy.device}') + + # Compile the configuration + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + + # Create main components: env, policy + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available()) + + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # Load pretrained model if specified + if model_path is not None: + logging.info(f'Loading model from {model_path} begin...') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'Loading model from {model_path} end!') + + # Create worker components: learner, collector, evaluator, replay buffer, commander + tb_logger = None + if get_rank() == 0: + tb_logger = SummaryWriter(os.path.join(f'./{cfg.exp_name}/log/', 'serial')) + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # MCTS+RL algorithms related core code + policy_config = cfg.policy + replay_buffer = GameBuffer(policy_config) + collector = Collector(env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=cfg.exp_name, + policy_config=policy_config) + evaluator = Evaluator(eval_freq=cfg.policy.eval_freq, n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode, + tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=policy_config) + + + + # ============================================================== + # New: initialize the RND reward model. + # RNDRewardModel needs the representation network from the policy model (as the predictor) + # and a target representation network (as the fixed target). + # For UniZero, the tokenizer plays the role of the representation network. + # ============================================================== + reward_model = RNDRewardModel( + config=cfg.reward_model, + device=policy.collect_mode.get_attribute('device'), + tb_logger=tb_logger, + exp_name=cfg.exp_name, + representation_network=policy._learn_model.representation_network, + target_representation_network=policy._target_model_for_intrinsic_reward.representation_network, + use_momentum_representation_network=cfg.policy.use_momentum_representation_network, + bp_update_sync=cfg.policy.bp_update_sync, + multi_gpu=cfg.policy.multi_gpu, + ) + + + # Learner's before_run hook + learner.call_hook('before_run') + + if cfg.policy.use_wandb and get_rank() == 0: +
policy.set_train_iter_env_step(learner.train_iter, collector.envstep) + + # Collect random data before training + if cfg.policy.random_collect_data: + random_data = random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer) + try: + reward_model.warmup_with_random_segments(random_data) + except Exception as e: + logging.exception(f"Failed to warm up RND normalization using random data: {e}") + raise + batch_size = policy._cfg.batch_size + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + + if cfg.policy.multi_gpu: + # Get current world size and rank + world_size = get_world_size() + rank = get_rank() + else: + world_size = 1 + rank = 0 + + while True: + # Log buffer memory usage + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + + # Set temperature for visit count distributions + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # Default epsilon value + } + + # Configure epsilon for epsilon-greedy exploration + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + # Evaluate policy performance + if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep, reward_model=reward_model) + if stop: + break + + # Collect new data + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # Determine updates per collection + update_per_collect = calculate_update_per_collect(cfg, new_data, world_size) + + # Update replay buffer + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # Periodically reanalyze buffer + if cfg.policy.buffer_reanalyze_freq >= 1: + # Reanalyze buffer times in one train_epoch + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + # Reanalyze buffer each <1/buffer_reanalyze_freq> train_epoch + if train_epoch > 0 and train_epoch % int(1/cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition): + with timer: + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + logging.info(f'Buffer reanalyze time: {timer.value}') + + # Train the policy if sufficient data is available + if collector.envstep > cfg.policy.train_start_after_envsteps: + if cfg.policy.sample_type == 'episode': + data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size + else: + data_sufficient = replay_buffer.get_num_of_transitions() > batch_size + if not data_sufficient: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect now ....' 
+ ) + continue + + for i in range(update_per_collect): + if cfg.policy.buffer_reanalyze_freq >= 1: + # Reanalyze buffer times in one train_epoch + if i % reanalyze_interval == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition): + with timer: + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + logging.info(f'Buffer reanalyze time: {timer.value}') + + train_data = replay_buffer.sample(batch_size, policy) + if cfg.policy.use_wandb: + policy.set_train_iter_env_step(learner.train_iter, collector.envstep) + + train_data_augmented = reward_model.estimate(train_data) + train_data_augmented.append(learner.train_iter) + + log_vars = learner.train(train_data_augmented, collector.envstep) + reward_model.train_with_policy_batch(train_data) + logging.info(f'[{i}/{update_per_collect}]: learner and reward_model ended training step.') + + if cfg.policy.use_priority: + replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + # Check stopping criteria + if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: + break + + learner.call_hook('after_run') + if cfg.policy.use_wandb: + wandb.finish() + return policy diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index 3da03312b..2de74ce32 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -147,8 +147,8 @@ def random_collect( collector_env: 'BaseEnvManager', # noqa replay_buffer: 'IBuffer', # noqa postprocess_data_fn: Optional[Callable] = None -) -> None: # noqa - assert policy_cfg.random_collect_episode_num > 0 +) -> list: # noqa + assert policy_cfg.random_collect_data, "random_collect_data should be True." random_policy = RandomPolicy(cfg=policy_cfg, action_space=collector_env.env_ref.action_space) # set the policy to random policy @@ -159,7 +159,7 @@ def random_collect( collect_kwargs = {'temperature': 1, 'epsilon': 0.0} # Collect data by default config n_sample/n_episode. - new_data = collector.collect(n_episode=policy_cfg.random_collect_episode_num, train_iter=0, + new_data = collector.collect(train_iter=0, policy_kwargs=collect_kwargs) if postprocess_data_fn is not None: @@ -172,6 +172,7 @@ def random_collect( # restore the policy collector.reset_policy(policy.collect_mode) + return new_data def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None: diff --git a/lzero/mcts/buffer/game_buffer_unizero.py b/lzero/mcts/buffer/game_buffer_unizero.py index b8998acb9..1f18caf56 100644 --- a/lzero/mcts/buffer/game_buffer_unizero.py +++ b/lzero/mcts/buffer/game_buffer_unizero.py @@ -630,3 +630,34 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A batch_target_values = np.asarray(batch_target_values) return batch_rewards, batch_target_values + + def update_priority(self, train_data: List[np.ndarray], batch_priorities: np.ndarray) -> None: + """ + Overview: + Update the priority of training data. + Arguments: + - train_data (:obj:`List[np.ndarray]`): training data to be updated priority. + - batch_priorities (:obj:`np.ndarray`): priorities to update to. 
+ NOTE: + train_data = [current_batch, target_batch] + current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list] + """ + # TODO: NOTE: -4 is batch_index_list + indices = train_data[0][-4] + metas = {'make_time': train_data[0][-1], 'batch_priorities': batch_priorities} + # only update the priorities for data still in replay buffer + for i in range(len(indices)): + # ==================== START OF FINAL FIX ==================== + + # FIX 1: Handle ValueError by using the first timestamp of the segment for comparison. + first_transition_time = metas['make_time'][i][0] + + if first_transition_time > self.clear_time: + # FIX 2: Handle IndexError by converting the float index to an integer before use. + idx = int(indices[i]) + prio = metas['batch_priorities'][i] + + # Now, idx is a valid integer index. + self.game_pos_priorities[idx] = prio + + # ===================== END OF FINAL FIX ===================== \ No newline at end of file diff --git a/lzero/model/common.py b/lzero/model/common.py index 7b1bbeeae..481a4bb6f 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -641,11 +641,11 @@ def __init__( self.embedding_dim = embedding_dim if self.observation_shape[1] == 64: - self.last_linear = nn.Linear(64 * 8 * 8, self.embedding_dim, bias=False) + self.last_linear = nn.Linear(num_channels * 8 * 8, self.embedding_dim, bias=False) elif self.observation_shape[1] in [84, 96]: - self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False) - + self.last_linear = nn.Linear(num_channels * 6 * 6, self.embedding_dim, bias=False) + self.final_norm_option_in_encoder = final_norm_option_in_encoder if self.final_norm_option_in_encoder == 'LayerNorm': self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5) @@ -678,7 +678,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.view(-1, self.embedding_dim) - # NOTE: very important for training stability. x = self.final_norm(x) return x diff --git a/lzero/model/unizero_world_models/kv_cache_manager.py b/lzero/model/unizero_world_models/kv_cache_manager.py new file mode 100644 index 000000000..066586a6a --- /dev/null +++ b/lzero/model/unizero_world_models/kv_cache_manager.py @@ -0,0 +1,473 @@ +""" +KV Cache Manager for UniZero World Model +========================================= + +This module provides a unified, robust, and extensible KV cache management system +for the UniZero world model. It replaces the scattered cache logic with a clean, +well-tested abstraction. 
+ +Author: Claude Code +Date: 2025-10-24 +""" + +import logging +from typing import Dict, List, Optional, Tuple, Any, Callable +from dataclasses import dataclass, field +from enum import Enum +import torch +from collections import OrderedDict + +# Assuming kv_caching is in the same directory or accessible +from .kv_caching import KeysValues + + +logger = logging.getLogger(__name__) + + +class EvictionStrategy(Enum): + """Cache eviction strategies.""" + FIFO = "fifo" # First In First Out (循环覆盖) + LRU = "lru" # Least Recently Used + PRIORITY = "priority" # 基于优先级 + + +@dataclass +class CacheStats: + """Statistics for cache performance monitoring.""" + hits: int = 0 + misses: int = 0 + evictions: int = 0 + total_queries: int = 0 + + @property + def hit_rate(self) -> float: + """Calculate hit rate.""" + if self.total_queries == 0: + return 0.0 + return self.hits / self.total_queries + + @property + def miss_rate(self) -> float: + """Calculate miss rate.""" + return 1.0 - self.hit_rate + + def reset(self): + """Reset all statistics.""" + self.hits = 0 + self.misses = 0 + self.evictions = 0 + self.total_queries = 0 + + def __repr__(self) -> str: + return (f"CacheStats(hits={self.hits}, misses={self.misses}, " + f"evictions={self.evictions}, hit_rate={self.hit_rate:.2%})") + + +class KVCachePool: + """ + A fixed-size pool for storing KeysValues objects. + + This class manages a pre-allocated pool of KeysValues objects and provides + efficient storage and retrieval mechanisms with configurable eviction strategies. + + Args: + pool_size: Maximum number of KV caches to store + eviction_strategy: Strategy for cache eviction + enable_stats: Whether to collect statistics + name: Name for this cache pool (for logging) + """ + + def __init__( + self, + pool_size: int, + eviction_strategy: EvictionStrategy = EvictionStrategy.FIFO, + enable_stats: bool = True, + name: str = "default" + ): + if pool_size <= 0: + raise ValueError(f"pool_size must be positive, got {pool_size}") + + self.pool_size = pool_size + self.eviction_strategy = eviction_strategy + self.enable_stats = enable_stats + self.name = name + + # Core data structures + self._pool: List[Optional[KeysValues]] = [None] * pool_size + self._key_to_index: Dict[int, int] = {} # cache_key -> pool_index + self._index_to_key: List[Optional[int]] = [None] * pool_size # pool_index -> cache_key + + # Eviction strategy specific data + self._next_index: int = 0 # For FIFO + self._access_order: OrderedDict = OrderedDict() # For LRU + self._priorities: Dict[int, float] = {} # For PRIORITY + + # Statistics + self.stats = CacheStats() if enable_stats else None + + logger.info(f"Initialized KVCachePool '{name}' with size={pool_size}, " + f"strategy={eviction_strategy.value}") + + def get(self, cache_key: int) -> Optional[KeysValues]: + """ + Retrieve a cached KeysValues object. 
+ + Args: + cache_key: The hash key for the cache + + Returns: + The cached KeysValues object if found, None otherwise + """ + if self.enable_stats: + self.stats.total_queries += 1 + + pool_index = self._key_to_index.get(cache_key) + + if pool_index is not None: + # Cache hit + if self.enable_stats: + self.stats.hits += 1 + + # Update access order for LRU + if self.eviction_strategy == EvictionStrategy.LRU: + self._access_order.move_to_end(cache_key) + + logger.debug(f"[{self.name}] Cache HIT for key={cache_key}, index={pool_index}") + return self._pool[pool_index] + else: + # Cache miss + if self.enable_stats: + self.stats.misses += 1 + + logger.debug(f"[{self.name}] Cache MISS for key={cache_key}") + return None + + def set(self, cache_key: int, kv_cache: KeysValues) -> int: + """ + Store a KeysValues object in the cache. + + Args: + cache_key: The hash key for the cache + kv_cache: The KeysValues object to store + + Returns: + The pool index where the cache was stored + """ + # ==================== BUG FIX: Defensive Deep Copy ==================== + # CRITICAL: Always clone the input to prevent cache corruption. + # This provides an additional layer of protection in case the caller + # forgets to clone. The clone operation ensures that the stored cache + # is independent from the caller's object, preventing unintended mutations. + kv_cache_copy = kv_cache.clone() + # ======================================================================= + + # Check if key already exists + if cache_key in self._key_to_index: + # Update existing entry + pool_index = self._key_to_index[cache_key] + self._pool[pool_index] = kv_cache_copy # Store cloned copy + + if self.eviction_strategy == EvictionStrategy.LRU: + self._access_order.move_to_end(cache_key) + + logger.debug(f"[{self.name}] Updated cache for key={cache_key} at index={pool_index}") + return pool_index + + # Find a slot for new entry + pool_index = self._find_slot_for_new_entry(cache_key) + + # Evict old entry if necessary + old_key = self._index_to_key[pool_index] + if old_key is not None: + self._evict(old_key, pool_index) + + # Store new entry (already cloned above) + self._pool[pool_index] = kv_cache_copy + self._key_to_index[cache_key] = pool_index + self._index_to_key[pool_index] = cache_key + + # Update access tracking for LRU + if self.eviction_strategy == EvictionStrategy.LRU: + self._access_order[cache_key] = True + + logger.debug(f"[{self.name}] Stored cache for key={cache_key} at index={pool_index}") + return pool_index + + def _find_slot_for_new_entry(self, cache_key: int) -> int: + """Find an appropriate slot for a new cache entry based on eviction strategy.""" + if self.eviction_strategy == EvictionStrategy.FIFO: + # Simple circular buffer + pool_index = self._next_index + self._next_index = (self._next_index + 1) % self.pool_size + return pool_index + + elif self.eviction_strategy == EvictionStrategy.LRU: + # Find LRU slot + if len(self._key_to_index) < self.pool_size: + # Pool not full, find first empty slot + for i in range(self.pool_size): + if self._index_to_key[i] is None: + return i + + # Evict LRU (first item in OrderedDict) + lru_key = next(iter(self._access_order)) + return self._key_to_index[lru_key] + + elif self.eviction_strategy == EvictionStrategy.PRIORITY: + # Find lowest priority slot + if len(self._key_to_index) < self.pool_size: + # Pool not full + for i in range(self.pool_size): + if self._index_to_key[i] is None: + return i + + # Evict lowest priority + min_priority_key = min(self._priorities, 
key=self._priorities.get) + return self._key_to_index[min_priority_key] + + else: + raise ValueError(f"Unknown eviction strategy: {self.eviction_strategy}") + + def _evict(self, cache_key: int, pool_index: int): + """Evict a cache entry.""" + if self.enable_stats: + self.stats.evictions += 1 + + # Remove from tracking structures + del self._key_to_index[cache_key] + self._index_to_key[pool_index] = None + + if self.eviction_strategy == EvictionStrategy.LRU: + self._access_order.pop(cache_key, None) + + if self.eviction_strategy == EvictionStrategy.PRIORITY: + self._priorities.pop(cache_key, None) + + logger.debug(f"[{self.name}] Evicted key={cache_key} from index={pool_index}") + + def clear(self): + """Clear all cache entries.""" + self._pool = [None] * self.pool_size + self._key_to_index.clear() + self._index_to_key = [None] * self.pool_size + self._next_index = 0 + self._access_order.clear() + self._priorities.clear() + + if self.enable_stats: + # Don't reset stats on clear, user can call stats.reset() explicitly + pass + + # logger.info(f"[{self.name}] Cleared all cache entries") + + def __len__(self) -> int: + """Return the number of cached entries.""" + return len(self._key_to_index) + + def __repr__(self) -> str: + stats_str = f", {self.stats}" if self.enable_stats else "" + return (f"KVCachePool(name='{self.name}', size={len(self)}/{self.pool_size}, " + f"strategy={self.eviction_strategy.value}{stats_str})") + + +class KVCacheManager: + """ + Unified KV Cache Manager for World Model. + + This class manages multiple cache pools for different inference scenarios: + - Initial inference caches (per-environment) + - Recurrent inference caches (for MCTS) + - World model caches (temporary batch caches) + + Args: + config: World model configuration + env_num: Number of environments + enable_stats: Whether to enable statistics collection + clear_recur_log_freq: How often to log 'clear_recur_cache' calls. + clear_all_log_freq: How often to log 'clear_all' calls. 
+ """ + + def __init__( + self, + config, + env_num: int, + enable_stats: bool = True, + clear_recur_log_freq: int = 1000, # <--- RENAMED & MODIFIED + clear_all_log_freq: int = 100 # <--- NEW + ): + self.config = config + self.env_num = env_num + self.enable_stats = enable_stats + + # --- Throttling parameters and counters --- + self.clear_recur_log_freq = clear_recur_log_freq + self.clear_all_log_freq = clear_all_log_freq + self._clear_recur_counter = 0 + self._clear_all_counter = 0 # <--- NEW + + # Initialize cache pools + self._init_cache_pools() + + # These lists store KeysValues objects, not integers + # Used in world model's trim_and_pad_kv_cache for batch processing + self.keys_values_wm_list: List[KeysValues] = [] + self.keys_values_wm_size_list: List[int] = [] + + logger.info(f"Initialized KVCacheManager for {env_num} environments") + + def _init_cache_pools(self): + """Initialize all cache pools.""" + # Initial inference pools (one per environment) + init_pool_size = int(self.config.game_segment_length) + self.init_pools: List[KVCachePool] = [] + for env_id in range(self.env_num): + pool = KVCachePool( + pool_size=init_pool_size, + eviction_strategy=EvictionStrategy.FIFO, + enable_stats=self.enable_stats, + name=f"init_env{env_id}" + ) + self.init_pools.append(pool) + + # Recurrent inference pool (shared across all environments) + num_simulations = getattr(self.config, 'num_simulations', 50) + recur_pool_size = int(num_simulations * self.env_num) + self.recur_pool = KVCachePool( + pool_size=recur_pool_size, + eviction_strategy=EvictionStrategy.FIFO, + enable_stats=self.enable_stats, + name="recurrent" + ) + + # World model pool (temporary) + wm_pool_size = self.env_num + self.wm_pool = KVCachePool( + pool_size=wm_pool_size, + eviction_strategy=EvictionStrategy.FIFO, + enable_stats=self.enable_stats, + name="world_model" + ) + + def get_init_cache(self, env_id: int, cache_key: int) -> Optional[KeysValues]: + """Get cache from initial inference pool.""" + if env_id < 0 or env_id >= self.env_num: + raise ValueError(f"Invalid env_id: {env_id}, must be in [0, {self.env_num})") + return self.init_pools[env_id].get(cache_key) + + def set_init_cache(self, env_id: int, cache_key: int, kv_cache: KeysValues) -> int: + """Set cache in initial inference pool.""" + if env_id < 0 or env_id >= self.env_num: + raise ValueError(f"Invalid env_id: {env_id}, must be in [0, {self.env_num})") + return self.init_pools[env_id].set(cache_key, kv_cache) + + def get_recur_cache(self, cache_key: int) -> Optional[KeysValues]: + """Get cache from recurrent inference pool.""" + return self.recur_pool.get(cache_key) + + def set_recur_cache(self, cache_key: int, kv_cache: KeysValues) -> int: + """Set cache in recurrent inference pool.""" + return self.recur_pool.set(cache_key, kv_cache) + + def get_wm_cache(self, cache_key: int) -> Optional[KeysValues]: + """Get cache from world model pool.""" + return self.wm_pool.get(cache_key) + + def set_wm_cache(self, cache_key: int, kv_cache: KeysValues) -> int: + """Set cache in world model pool.""" + return self.wm_pool.set(cache_key, kv_cache) + + def hierarchical_get(self, env_id: int, cache_key: int) -> Optional[KeysValues]: + """ + Perform hierarchical cache lookup: init_pool -> recur_pool. + + This method encapsulates the two-level lookup strategy: + 1. First try to find in environment-specific init_infer cache + 2. 
If not found, fallback to global recurrent_infer cache + + Arguments: + - env_id (:obj:`int`): Environment ID for init cache lookup + - cache_key (:obj:`int`): Cache key to lookup + + Returns: + - kv_cache (:obj:`Optional[KeysValues]`): Found cache or None + """ + # Step 1: Try init_infer cache first (per-environment) + kv_cache = self.get_init_cache(env_id, cache_key) + if kv_cache is not None: + return kv_cache + + # Step 2: If not found, try recurrent_infer cache (global) + return self.get_recur_cache(cache_key) + + def clear_all(self): # <--- MODIFIED METHOD + """Clear all cache pools, with throttled logging.""" + # Core clearing actions always execute. + for pool in self.init_pools: + pool.clear() + self.recur_pool.clear() + self.wm_pool.clear() + self.keys_values_wm_list.clear() + self.keys_values_wm_size_list.clear() + + # --- Throttled Logging Logic --- + self._clear_all_counter += 1 + if self.clear_all_log_freq > 0 and self._clear_all_counter % self.clear_all_log_freq == 0: + logger.info( + f"Cleared all KV caches (this message appears every " + f"{self.clear_all_log_freq} calls, total calls: {self._clear_all_counter})" + ) + + def clear_init_caches(self): + """Clear only initial inference caches.""" + for pool in self.init_pools: + pool.clear() + logger.info("Cleared initial inference caches") + + def clear_recur_cache(self): + """Clear only recurrent inference cache, with throttled logging.""" + # The core cache clearing action always executes. + self.recur_pool.clear() + + # --- Throttled Logging Logic --- + self._clear_recur_counter += 1 + # Only log if frequency is positive and the counter is a multiple of the frequency. + if self.clear_recur_log_freq > 0 and self._clear_recur_counter % self.clear_recur_log_freq == 0: + logger.info( + f"Cleared recurrent inference cache (this message appears every " + f"{self.clear_recur_log_freq} calls, total calls: {self._clear_recur_counter})" + ) + + def get_stats_summary(self) -> Dict[str, Any]: + """Get statistics summary for all pools.""" + if not self.enable_stats: + return {"stats_enabled": False} + + summary = { + "stats_enabled": True, + "init_pools": {}, + "recur_pool": str(self.recur_pool.stats), + "wm_pool": str(self.wm_pool.stats), + } + + for env_id, pool in enumerate(self.init_pools): + summary["init_pools"][f"env_{env_id}"] = str(pool.stats) + + return summary + + def reset_stats(self): + """Reset statistics for all pools.""" + if not self.enable_stats: + return + + for pool in self.init_pools: + pool.stats.reset() + self.recur_pool.stats.reset() + self.wm_pool.stats.reset() + logger.info("Reset all cache statistics") + + def __repr__(self) -> str: + init_sizes = [len(pool) for pool in self.init_pools] + return (f"KVCacheManager(env_num={self.env_num}, " + f"init_caches={init_sizes}, " + f"recur_cache={len(self.recur_pool)}/{self.recur_pool.pool_size}, " + f"wm_cache={len(self.wm_pool)}/{self.wm_pool.pool_size})") \ No newline at end of file diff --git a/lzero/model/unizero_world_models/kv_caching.py b/lzero/model/unizero_world_models/kv_caching.py index 28b7b0ba2..f52f8871e 100644 --- a/lzero/model/unizero_world_models/kv_caching.py +++ b/lzero/model/unizero_world_models/kv_caching.py @@ -203,6 +203,51 @@ def prune(self, mask: np.ndarray) -> None: """ for kv_cache in self._keys_values: kv_cache.prune(mask) + + def clone(self) -> "KeysValues": + """ + Overview: + Creates a deep copy of this KeysValues object. + + This method is critical for preventing cache corruption. 
When a cached KeysValues object + is retrieved and used in transformer forward passes, the transformer modifies it in-place. + Without cloning, this would pollute the original cache, causing incorrect predictions. + + Returns: + - cloned_kv (:obj:`KeysValues`): A new KeysValues object with copied data. + """ + if not self._keys_values: + # Handle empty case + raise ValueError("Cannot clone an empty KeysValues object") + + # Get parameters from the first layer's cache + first_kv_cache = self._keys_values[0] + num_samples, num_heads, _, head_dim = first_kv_cache.shape + max_tokens = first_kv_cache._k_cache._max_tokens + embed_dim = num_heads * head_dim + num_layers = len(self._keys_values) + device = first_kv_cache._k_cache._device + + # Create a new KeysValues object with the same structure + cloned_kv = KeysValues( + num_samples=num_samples, + num_heads=num_heads, + max_tokens=max_tokens, + embed_dim=embed_dim, + num_layers=num_layers, + device=device + ) + + # Deep copy each layer's cache data + for src_layer, dst_layer in zip(self._keys_values, cloned_kv._keys_values): + # Copy the key and value cache tensors using torch.copy_() for efficient data transfer + dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) + dst_layer._v_cache._cache.copy_(src_layer._v_cache._cache) + # Copy the size information + dst_layer._k_cache._size = src_layer._k_cache._size + dst_layer._v_cache._size = src_layer._v_cache._size + + return cloned_kv class AssignWithoutInplaceCheck(torch.autograd.Function): diff --git a/lzero/model/unizero_world_models/utils.py b/lzero/model/unizero_world_models/utils.py index 99c841cbe..423df5a77 100644 --- a/lzero/model/unizero_world_models/utils.py +++ b/lzero/model/unizero_world_models/utils.py @@ -181,15 +181,29 @@ def calculate_cuda_memory_gb(past_keys_values_cache, num_layers: int): def hash_state(state): """ - Hash the state vector. + Overview: + Computes a fast and robust hash for a NumPy array state. + + Why this is optimal: + 1. Algorithm (`xxhash.xxh64`): Uses one of the fastest non-cryptographic hash + functions available, ideal for performance-critical applications like caching. + 2. Input Preparation (`state.tobytes()`): Ensures correctness by creating a + canonical byte representation of the array. This guarantees that two + logically identical arrays will produce the same hash, regardless of their + internal memory layout (e.g., C-contiguous, F-contiguous, or strided views). + 3. Output Format (`.intdigest()`): Directly produces an integer hash value, + which is the most efficient key type for Python dictionaries, avoiding the + overhead of string keys. Arguments: - state: The state vector to be hashed. + - state (np.ndarray): The state array to be hashed. Returns: - The hash value of the state vector. + - int: A 64-bit integer hash of the state. """ - # Use xxhash for faster hashing - return xxhash.xxh64(state).hexdigest() + # Ensure the array is contiguous in memory before converting to bytes, + # although .tobytes() handles this, being explicit can sometimes be clearer. + # For simplicity and since .tobytes() defaults to C-order, we can rely on it. + return xxhash.xxh64(state.tobytes()).intdigest() @dataclass class WorldModelOutput: @@ -201,7 +215,7 @@ class WorldModelOutput: logits_value: torch.FloatTensor -def init_weights(module, norm_type='BN'): +def init_weights(module, norm_type='BN', liner_weight_zero=False): """ Initialize the weights of the module based on the specified normalization type. 
@@ -209,9 +223,16 @@ def init_weights(module, norm_type='BN'): module (nn.Module): The module to initialize. norm_type (str): The type of normalization to use ('BN' for BatchNorm, 'LN' for LayerNorm). """ - if isinstance(module, (nn.Linear, nn.Embedding)): + if isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=0.02) - if isinstance(module, nn.Linear) and module.bias is not None: + elif isinstance(module, nn.Linear): + if norm_type == 'BN': + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + print("Init Linear using kaiming normal for BN") + elif norm_type == 'LN': + nn.init.xavier_uniform_(module.weight) + print("Init Linear using xavier uniform for LN") + if module.bias is not None: + module.bias.data.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): print(f"Init {module} using zero bias, 1 weight") @@ -228,13 +249,6 @@ def init_weights(module, norm_type='BN'): elif norm_type == 'LN': nn.init.xavier_uniform_(module.weight) print(f"Init nn.Conv2d using xavier uniform for LN") - elif isinstance(module, nn.Linear): - if norm_type == 'BN': - nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') - print("Init Linear using kaiming normal for BN") - elif norm_type == 'LN': - nn.init.xavier_uniform_(module.weight) - print("Init Linear using xavier uniform for LN") class LossWithIntermediateLosses: @@ -260,6 +274,7 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, continu if not continuous_action_space: # like EZV2, for atari and memory self.obs_loss_weight = 10 + # self.obs_loss_weight = 2 self.value_loss_weight = 0.5 self.reward_loss_weight = 1. self.policy_loss_weight = 1. @@ -267,6 +282,7 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, continu else: # like TD-MPC2 for DMC self.obs_loss_weight = 10 + # self.obs_loss_weight = 2 self.value_loss_weight = 0.1 self.reward_loss_weight = 0.1 self.policy_loss_weight = 0.1 diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py index 7f1a0f68e..1e0a1367a 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -16,6 +16,15 @@ from .transformer import Transformer, TransformerConfig from .utils import LossWithIntermediateLosses, init_weights, WorldModelOutput, hash_state +from collections import OrderedDict, defaultdict +import matplotlib.pyplot as plt +from sklearn.manifold import TSNE +from matplotlib.offsetbox import OffsetImage, AnnotationBbox +import os +import datetime + + logging.getLogger().setLevel(logging.DEBUG) @@ -115,9 +124,6 @@ def custom_init(module): self._initialize_last_layer() - # Cache structures - self._initialize_cache_structures() - # Projection input dimension self._initialize_projection_input_dim() @@ -129,19 +135,24 @@ def custom_init(module): self.latent_recon_loss = torch.tensor(0., device=self.device) self.perceptual_loss = torch.tensor(0., device=self.device) + + # Set this to game_segment_length for now so that every entry in self.shared_pool_init_infer remains a valid kv cache. + # TODO: very important, this should be kept consistent with segment_length. + self.shared_pool_size_init = int(self.config.game_segment_length) # NOTE: Will having too many cause incorrect retrieval of the kv cache? # TODO: check the size of the shared pool # for self.kv_cache_recurrent_infer # If needed, recurrent_infer should store the results of the one MCTS search.
self.num_simulations = getattr(self.config, 'num_simulations', 50) - self.shared_pool_size = int(self.num_simulations*self.env_num) - self.shared_pool_recur_infer = [None] * self.shared_pool_size + self.shared_pool_size_recur = int(self.num_simulations*self.env_num) + self.shared_pool_recur_infer = [None] * self.shared_pool_size_recur self.shared_pool_index = 0 + + # Cache structures + self._initialize_cache_structures() # for self.kv_cache_init_infer # In contrast, init_infer only needs to retain the results of the most recent step. - # self.shared_pool_size_init = int(2*self.env_num) - self.shared_pool_size_init = int(2) # NOTE: Will having too many cause incorrect retrieval of the kv cache? self.shared_pool_init_infer = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] self.shared_pool_index_init_envs = [0 for _ in range(self.env_num)] @@ -151,6 +162,117 @@ def custom_init(module): self.shared_pool_index_wm = 0 self.reanalyze_phase = False + + def _analyze_latent_representation( + self, + latent_states: torch.Tensor, + timesteps: torch.Tensor, + game_states: torch.Tensor, + predicted_values: torch.Tensor, + predicted_rewards: torch.Tensor, + step_counter: int + ): + """ + Analyze and log statistics of the latent states, together with a t-SNE visualization. + [New feature]: the t-SNE plot shows the corresponding game frames, annotated with the predicted Value and Reward. + [Modified]: if a file with the same name already exists at the save path, a timestamp is appended to the file name. + + Args: + latent_states (torch.Tensor): Encoder outputs, shape (B*L, 1, E) + timesteps (torch.Tensor): Corresponding timesteps, shape (B, L) + game_states (torch.Tensor): Raw game observations, shape (B, L, C, H, W) + predicted_values (torch.Tensor): Predicted scalar values, shape (B*L,) + predicted_rewards (torch.Tensor): Predicted scalar rewards, shape (B*L,) + step_counter (int): Global training step + """ + # ... (the statistics part is unchanged) ... + # (make sure latent_states and game_states have shape (N, ...)) + if latent_states.dim() > 2: + latent_states = latent_states.reshape(-1, latent_states.shape[-1]) + num_c, num_h, num_w = game_states.shape[-3:] + game_states = game_states.reshape(-1, num_c, num_h, num_w) + + with torch.no_grad(): + l2_norm = torch.norm(latent_states, p=2, dim=1).mean() + mean = latent_states.mean() + std = latent_states.std() + print(f"[Step {step_counter}] Latent Stats | L2 Norm: {l2_norm:.4f}, Mean: {mean:.4f}, Std: {std:.4f}") + + # t-SNE visualization with game images and V/R values + if step_counter >= 0: + # if step_counter > 0 and step_counter % 200 == 0: + + print(f"[Step {step_counter}] Performing t-SNE analysis with images, values, and rewards...") + + # Move the data to the CPU + latents_np = latent_states.detach().cpu().numpy() + images_np = game_states.detach().cpu().numpy() + values_np = predicted_values.detach().cpu().numpy() + rewards_np = predicted_rewards.detach().cpu().numpy() + + tsne = TSNE(n_components=2, perplexity=30, n_iter=300, random_state=42) + tsne_results = tsne.fit_transform(latents_np) + + # --- Draw the scatter plot with images and annotations --- + + # Limit the number of images to keep the plot readable + num_points_to_plot = min(len(latents_np), 70) # reduced to 70 points + indices = np.random.choice(len(latents_np), num_points_to_plot, replace=False) + + fig, ax = plt.subplots(figsize=(20, 18)) # enlarge the canvas + + # First draw all points as a background scatter plot + ax.scatter(tsne_results[:, 0], tsne_results[:, 1], c=values_np, cmap='viridis', alpha=0.3, s=10) + + for i in indices: + x, y = tsne_results[i] + img = images_np[i].transpose(1, 2, 0) + img = np.clip(img, 0, 1) + + # Place the image + im = OffsetImage(img, zoom=0.7) # slightly enlarge the image + ab = AnnotationBbox(im, (x, y), frameon=True, pad=0.0, bboxprops=dict(edgecolor='none')) + ax.add_artist(ab) + + # Add a text label below the image + text_label = f"V:{values_np[i]:.1f} R:{rewards_np[i]:.1f}" + ax.text(x, y - 1.0, text_label, ha='center', va='top', fontsize=8, color='red', +
bbox=dict(boxstyle='round,pad=0.2', fc='yellow', alpha=0.5)) + + ax.update_datalim(tsne_results) + ax.autoscale() + + ax.set_title(f't-SNE of Latent States (Value as Color) at Step {step_counter}', fontsize=16) + ax.set_xlabel('t-SNE dimension 1', fontsize=12) + ax.set_ylabel('t-SNE dimension 2', fontsize=12) + + # Add a colorbar to explain the colors of the background points + norm = plt.Normalize(values_np.min(), values_np.max()) + sm = plt.cm.ScalarMappable(cmap='viridis', norm=norm) + sm.set_array([]) + fig.colorbar(sm, ax=ax, label='Predicted Value') + + # --- Modified: check whether the file already exists; if so, append a timestamp --- + base_save_path = ( + f'/mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/unizero_mspacman_analyze/' + f'tsne_with_vr_{self.config.optim_type}_step_{step_counter}.png' + ) + + # 2. Check whether the file exists and determine the final save path + if os.path.exists(base_save_path): + # If the file already exists, generate a timestamp and append it to the file name + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + path_root, path_ext = os.path.splitext(base_save_path) + save_path = f"{path_root}_{timestamp}{path_ext}" + print(f"File '{base_save_path}' already exists. Saving to new path with timestamp.") + else: + # If the file does not exist, use the original path + save_path = base_save_path + + # 3. Save the figure + plt.savefig(save_path) + plt.close(fig) # explicitly close the figure object + print(f"t-SNE plot with V/R annotations saved to {save_path}") def _get_final_norm(self, norm_option: str) -> nn.Module: """ @@ -264,7 +386,7 @@ def custom_copy_kv_cache_to_shared_recur(self, src_kv: KeysValues) -> int: dst_layer._v_cache._size = src_layer._v_cache._size index = self.shared_pool_index - self.shared_pool_index = (self.shared_pool_index + 1) % self.shared_pool_size + self.shared_pool_index = (self.shared_pool_index + 1) % self.shared_pool_size_recur return index @@ -304,7 +426,9 @@ def _initialize_patterns(self) -> None: def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head: """Create head modules for the transformer.""" modules = [ + nn.LayerNorm(self.config.embed_dim), nn.Linear(self.config.embed_dim, self.config.embed_dim), + nn.LayerNorm(self.config.embed_dim), nn.GELU(approximate='tanh'), nn.Linear(self.config.embed_dim, output_dim) ] @@ -354,11 +478,55 @@ def _initialize_last_layer(self) -> None: def _initialize_cache_structures(self) -> None: """Initialize cache structures for past keys and values.""" from collections import defaultdict - self.past_kv_cache_recurrent_infer = defaultdict(dict) - self.past_kv_cache_init_infer_envs = [defaultdict(dict) for _ in range(self.env_num)] + # ==================== Phase 1: Parallel KV Cache Systems ==================== + # Check if we should use the new KV cache manager + self.use_new_cache_manager = getattr(self.config, 'use_new_cache_manager', False) + + if self.use_new_cache_manager: + # Use new unified KV cache manager + from .kv_cache_manager import KVCacheManager + self.kv_cache_manager = KVCacheManager( + config=self.config, + env_num=self.env_num, + enable_stats=True, + clear_recur_log_freq=1000, # log for MCTS recurrent-cache clearing, printed once every 1000 calls + clear_all_log_freq=100 # log for episode-reset cache clearing, printed once every 100 calls + ) + # Keep backward compatibility references + self.keys_values_wm_list = self.kv_cache_manager.keys_values_wm_list + self.keys_values_wm_size_list = self.kv_cache_manager.keys_values_wm_size_list + + # ==================== BUG FIX: Complete Refactoring ==================== + # DO NOT initialize old system attributes when using new cache manager. + # Any code that depends on these old attributes must be refactored to use + # kv_cache_manager instead.
+ # + # Old attributes that are NO LONGER available in new system: + # - self.past_kv_cache_recurrent_infer + # - self.pool_idx_to_key_map_recur_infer + # - self.past_kv_cache_init_infer_envs + # - self.pool_idx_to_key_map_init_envs + # + # Migration guide: + # - For accessing init cache: use kv_cache_manager.get_init_cache(env_id, key) + # - For accessing recur cache: use kv_cache_manager.get_recur_cache(key) + # - For hierarchical lookup: use kv_cache_manager.hierarchical_get(env_id, key) + # ====================================================================== + + logging.info("✓ Using NEW KVCacheManager for cache management") + else: + # Use old cache system (original implementation) + self.past_kv_cache_recurrent_infer = {} + self.pool_idx_to_key_map_recur_infer = [None] * self.shared_pool_size_recur + self.past_kv_cache_init_infer_envs = [{} for _ in range(self.env_num)] + # Auxiliary data structure for reverse lookup: pool_index -> key + self.pool_idx_to_key_map_init_envs = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] + + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + logging.info("Using OLD cache system (original implementation)") + # ============================================================================= - self.keys_values_wm_list = [] - self.keys_values_wm_size_list = [] def _initialize_projection_input_dim(self) -> None: """Initialize the projection input dimension based on the number of observation tokens.""" @@ -831,20 +999,33 @@ def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTens state_single_env = last_obs_embeddings[i] # Compute hash value using latent state for a single environment cache_key = hash_state(state_single_env.view(-1).cpu().numpy()) # last_obs_embeddings[i] is torch.Tensor - + # ==================== Phase 1.6: Storage Layer Integration ==================== + # Retrieve cached value - cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key) - if cache_index is not None: - matched_value = self.shared_pool_init_infer[i][cache_index] + if self.use_new_cache_manager: + # NEW SYSTEM: Use KVCacheManager + matched_value = self.kv_cache_manager.get_init_cache(env_id=i, cache_key=cache_key) else: - matched_value = None + # OLD SYSTEM: Use legacy cache dictionaries + cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[i][cache_index] + else: + matched_value = None self.root_total_query_cnt += 1 if matched_value is not None: # If a matching value is found, add it to the list self.root_hit_cnt += 1 - # NOTE: deepcopy is needed because forward modifies matched_value in place - self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + # ==================== BUG FIX: Cache Corruption Prevention ==================== + # Perform a deep copy because the transformer's forward pass modifies matched_value in-place.
+ if self.use_new_cache_manager: + # NEW SYSTEM: Use KeysValues.clone() for deep copy + cached_copy = matched_value.clone() + self.keys_values_wm_list.append(cached_copy) + else: + # OLD SYSTEM: Use custom_copy_kv_cache_to_shared_wm + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + # ============================================================================= self.keys_values_wm_size_list.append(matched_value.size) else: # Reset using zero values @@ -934,7 +1115,14 @@ def forward_initial_inference(self, obs_act_dict, start_pos: int = 0): """ # UniZero has context in the root node outputs_wm, latent_state = self.reset_for_initial_inference(obs_act_dict, start_pos) - self.past_kv_cache_recurrent_infer.clear() + # ==================== BUG FIX: Clear Cache Using Correct API ==================== + if self.use_new_cache_manager: + # NEW SYSTEM: Clear recurrent cache using KVCacheManager + self.kv_cache_manager.clear_recur_cache() + else: + # OLD SYSTEM: Clear using legacy attribute + self.past_kv_cache_recurrent_infer.clear() + # ============================================================================= return (outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value) @@ -1210,14 +1398,65 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 - if is_init_infer: - # Store the latest key-value cache for initial inference - cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) - self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + # ==================== Phase 1.5: Storage Layer Integration ==================== + if self.use_new_cache_manager: + # NEW SYSTEM: Use KVCacheManager for cache storage + # ==================== BUG FIX: Deep Copy Before Storage ==================== + # CRITICAL: Must clone before storing to prevent cache corruption. + # self.keys_values_wm_single_env is a shared object that gets modified. + # Without cloning, all cache entries would point to the same object, + # causing incorrect KV retrieval and training divergence. + kv_cache_to_store = self.keys_values_wm_single_env.clone() + # ============================================================================= + if is_init_infer: + # Store to per-environment init cache pool + # Note: KVCacheManager automatically handles eviction logic (FIFO/LRU) + self.kv_cache_manager.set_init_cache( + env_id=i, + cache_key=cache_key, + kv_cache=kv_cache_to_store # Store cloned copy, not reference + ) + else: + # Store to global recurrent cache pool + self.kv_cache_manager.set_recur_cache( + cache_key=cache_key, + kv_cache=kv_cache_to_store # Store cloned copy, not reference + ) else: - # Store the latest key-value cache for recurrent inference - cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) - self.past_kv_cache_recurrent_infer[cache_key] = cache_index + # OLD SYSTEM: Use legacy cache with manual eviction + if is_init_infer: + # ==================== PROACTIVE EVICTION FIX ==================== + # 1. Get the physical index that is about to be overwritten + index_to_write = self.shared_pool_index_init_envs[i] + # 2. Use the auxiliary list to look up the old key stored at that index + old_key_to_evict = self.pool_idx_to_key_map_init_envs[i][index_to_write] + # 3.
If an old key exists, delete it from the main cache map + if old_key_to_evict is not None: + # Make sure the key to delete actually exists, to avoid unexpected errors + if old_key_to_evict in self.past_kv_cache_init_infer_envs[i]: + del self.past_kv_cache_init_infer_envs[i][old_key_to_evict] + + # Now it is safe to write the new data + cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + + # 4. Update the new mapping in both the main cache map and the auxiliary list + self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + self.pool_idx_to_key_map_init_envs[i][index_to_write] = cache_key + else: + # ==================== RECURRENT INFER FIX ==================== + # 1. Get the physical index that is about to be overwritten + index_to_write = self.shared_pool_index + # 2. Use the auxiliary list to look up the old key stored at that index + old_key_to_evict = self.pool_idx_to_key_map_recur_infer[index_to_write] + # 3. If an old key exists, delete it from the main cache map + if old_key_to_evict is not None: + if old_key_to_evict in self.past_kv_cache_recurrent_infer: + del self.past_kv_cache_recurrent_infer[old_key_to_evict] + # 4. Now it is safe to write the new data + cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + # 5. Update the new mapping in both the main cache map and the auxiliary list + self.past_kv_cache_recurrent_infer[cache_key] = cache_index + self.pool_idx_to_key_map_recur_infer[index_to_write] = cache_key def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, @@ -1245,22 +1484,48 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, # TODO: check if this is correct matched_value = None else: - # Try to retrieve the cached value from past_kv_cache_init_infer_envs - cache_index = self.past_kv_cache_init_infer_envs[index].get(cache_key) - if cache_index is not None: - matched_value = self.shared_pool_init_infer[index][cache_index] + # ==================== Phase 1.6: Storage Layer Integration (Refactored) ==================== + if self.use_new_cache_manager: + # NEW SYSTEM: Use KVCacheManager's hierarchical_get for unified lookup + matched_value = self.kv_cache_manager.hierarchical_get(env_id=index, cache_key=cache_key) + + # Log cache miss (statistics are handled automatically by KVCacheManager) + if matched_value is None: + logging.debug(f"[NEW CACHE MISS] Not found for key={cache_key} in both init and recurrent cache.") else: - matched_value = None - - # If not found, try to retrieve from past_kv_cache_recurrent_infer - if matched_value is None: - matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)] + # OLD SYSTEM: Use legacy cache dictionaries and pools + # Try to retrieve the cached value from past_kv_cache_init_infer_envs + cache_index = self.past_kv_cache_init_infer_envs[index].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[index][cache_index] + else: + matched_value = None + + # Only if nothing was found in init_infer, fall back to the recurrent_infer cache + if matched_value is None: + # Safely get the index from the dict; it may return None + recur_cache_index = self.past_kv_cache_recurrent_infer.get(cache_key) + # Only use the index to retrieve a value from the physical pool when it is valid (not None) + if recur_cache_index is not None: + matched_value = self.shared_pool_recur_infer[recur_cache_index] + if recur_cache_index is None: + print(f"[CACHE MISS] Not found for key={cache_key} in recurrent infer.
Generating new cache.") if matched_value is not None: # If a matching cache is found, add it to the lists self.hit_count += 1 - # Perform a deep copy because the transformer's forward pass might modify matched_value in-place - self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + # ==================== BUG FIX: Cache Corruption Prevention ==================== + # Perform a deep copy because the transformer's forward pass modifies matched_value in-place. + # Without cloning, the original cache in init_pool or recur_pool would be polluted, + # causing incorrect predictions in subsequent queries. + if self.use_new_cache_manager: + # NEW SYSTEM: Use KeysValues.clone() for deep copy + cached_copy = matched_value.clone() + self.keys_values_wm_list.append(cached_copy) + else: + # OLD SYSTEM: Use custom_copy_kv_cache_to_shared_wm + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + # ============================================================================= self.keys_values_wm_size_list.append(matched_value.size) else: # If no matching cache is found, generate a new one using zero reset @@ -1311,7 +1576,12 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) # (32,5,3,64,64) -> (160,3,64,64) dormant_ratio_encoder = cal_dormant_ratio(self.tokenizer.representation_network, inputs.detach(), percentage=self.dormant_threshold) - self.past_kv_cache_recurrent_infer.clear() + # ==================== BUG FIX: Clear Cache Using Correct API ==================== + if self.use_new_cache_manager: + self.kv_cache_manager.clear_recur_cache() + else: + self.past_kv_cache_recurrent_infer.clear() + # ============================================================================= self.keys_values_wm_list.clear() torch.cuda.empty_cache() else: @@ -1329,6 +1599,56 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # Forward pass to obtain predictions for observations, rewards, and policies outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, start_pos=start_pos) + # [New] Take the intermediate tensor x from the model output and detach it from the computation graph + intermediate_tensor_x = outputs.output_sequence.detach() + + global_step = kwargs.get('global_step', 0) + if global_step > 0 and global_step % 100000000000 == 0: # 20k + + with torch.no_grad(): + # Convert the logits to scalar values + # Note: the outputs have shape (B, L, E), so we need to reshape + batch_size, seq_len = batch['actions'].shape[0], batch['actions'].shape[1] + + pred_val_logits = outputs.logits_value.view(batch_size * seq_len, -1) + pred_rew_logits = outputs.logits_rewards.view(batch_size * seq_len, -1) + + scalar_values = inverse_scalar_transform_handle(pred_val_logits).squeeze(-1) + scalar_rewards = inverse_scalar_transform_handle(pred_rew_logits).squeeze(-1) + + self._analyze_latent_representation( + latent_states=obs_embeddings, + timesteps=batch['timestep'], + game_states=batch['observations'], + predicted_values=scalar_values, # pass in the predicted values + predicted_rewards=scalar_rewards, # pass in the predicted rewards + step_counter=global_step + ) + + + if self.config.use_priority: + # ==================== START MODIFICATION 5 ==================== + # Calculate value_priority, similar to MuZero. + with torch.no_grad(): + # 1. Get the predicted value logits for the first step of the sequence (t=0). + # The shape is (B, support_size). + predicted_value_logits_step0 = outputs.logits_value[:, 0, :] + + # 2.
Convert the categorical prediction to a scalar value. + # The shape becomes (B, 1). + predicted_scalar_value_step0 = inverse_scalar_transform_handle(predicted_value_logits_step0) + + # 3. Get the target scalar value for the first step from the batch. + # The shape is (B, num_unroll_steps), so we take the first column. + target_scalar_value_step0 = batch['scalar_target_value'][:, 0] + + # 4. Calculate the L1 loss (absolute difference) between prediction and target. + # This is the priority. We use reduction='none' to get per-sample priorities. + value_priority = F.l1_loss(predicted_scalar_value_step0.squeeze(-1), target_scalar_value_step0, reduction='none') + # ===================== END MODIFICATION 5 ===================== + else: + value_priority = torch.tensor(0.) + if self.obs_type == 'image': # Reconstruct observations from latent state representations # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) @@ -1415,7 +1735,12 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar dormant_ratio_world_model = cal_dormant_ratio(self, { 'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())}, percentage=self.dormant_threshold) - self.past_kv_cache_recurrent_infer.clear() + # ==================== BUG FIX: Clear Cache Using Correct API ==================== + if self.use_new_cache_manager: + self.kv_cache_manager.clear_recur_cache() + else: + self.past_kv_cache_recurrent_infer.clear() + # ============================================================================= self.keys_values_wm_list.clear() torch.cuda.empty_cache() else: @@ -1468,6 +1793,15 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # assert not torch.isinf(loss_obs).any(), "loss_obs contains Inf values" # for name, param in self.tokenizer.encoder.named_parameters(): # print('name, param.mean(), param.std():', name, param.mean(), param.std()) + elif self.predict_latent_loss_type == 'cos_sim': + # --- 修复后的代码 (推荐方案) --- + # 使用余弦相似度损失 (Cosine Similarity Loss) + # F.cosine_similarity 计算的是相似度,范围是 [-1, 1]。我们希望最大化它, + # 所以最小化 1 - similarity。 + # reduction='none' 使得我们可以像原来一样处理mask + print("predict_latent_loss_type == 'cos_sim'") + cosine_sim_loss = 1 - F.cosine_similarity(logits_observations, labels_observations, dim=-1) + loss_obs = cosine_sim_loss # Apply mask to loss_obs mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) @@ -1552,6 +1886,10 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar discounted_orig_policy_loss = (orig_policy_loss.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() discounted_policy_entropy = (policy_entropy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + # 为了让外部的训练循环能够获取encoder的输出,我们将其加入返回字典 + # 使用 .detach() 是因为这个张量仅用于后续的clip操作,不应影响梯度计算 + detached_obs_embeddings = obs_embeddings.detach() + if self.continuous_action_space: return LossWithIntermediateLosses( latent_recon_loss_weight=self.latent_recon_loss_weight, @@ -1574,6 +1912,21 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar policy_mu=mu, policy_sigma=sigma, target_sampled_actions=target_sampled_actions, + + value_priority=value_priority, + intermediate_tensor_x=intermediate_tensor_x, + obs_embeddings=detached_obs_embeddings, + + # logits_value_mean=outputs.logits_value.mean(), + # logits_value_max=outputs.logits_value.max(), + # logits_value_min=outputs.logits_value.min(), + # 
logits_policy_mean=outputs.logits_policy.mean(), + # logits_policy_max=outputs.logits_policy.max(), + # logits_policy_min=outputs.logits_policy.min(), + logits_value=outputs.logits_value.detach(), # 使用detach(),因为它仅用于分析和裁剪,不参与梯度计算 + logits_reward=outputs.logits_rewards.detach(), + logits_policy=outputs.logits_policy.detach(), + ) else: return LossWithIntermediateLosses( @@ -1594,6 +1947,20 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar dormant_ratio_encoder=dormant_ratio_encoder, dormant_ratio_world_model=dormant_ratio_world_model, latent_state_l2_norms=latent_state_l2_norms, + + value_priority=value_priority, + intermediate_tensor_x=intermediate_tensor_x, + obs_embeddings=detached_obs_embeddings, + + # logits_value_mean=outputs.logits_value.mean(), + # logits_value_max=outputs.logits_value.max(), + # logits_value_min=outputs.logits_value.min(), + # logits_policy_mean=outputs.logits_policy.mean(), + # logits_policy_max=outputs.logits_policy.max(), + # logits_policy_min=outputs.logits_policy.min(), + logits_value=outputs.logits_value.detach(), # 使用detach(),因为它仅用于分析和裁剪,不参与梯度计算 + logits_reward=outputs.logits_rewards.detach(), + logits_policy=outputs.logits_policy.detach(), ) @@ -1817,11 +2184,23 @@ def clear_caches(self): """ Clears the caches of the world model. """ - for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: - kv_cache_dict_env.clear() - self.past_kv_cache_recurrent_infer.clear() - self.keys_values_wm_list.clear() - print(f'Cleared {self.__class__.__name__} past_kv_cache.') + if self.use_new_cache_manager: + # Use new KV cache manager's clear method + self.kv_cache_manager.clear_all() + print(f'Cleared {self.__class__.__name__} KV caches (NEW system).') + + # Optionally print stats before clearing + if hasattr(self.kv_cache_manager, 'get_stats_summary'): + stats = self.kv_cache_manager.get_stats_summary() + if stats.get('stats_enabled'): + logging.debug(f'Cache stats before clear: {stats}') + else: + # Use old cache clearing logic + for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + print(f'Cleared {self.__class__.__name__} past_kv_cache (OLD system).') def __repr__(self) -> str: return "transformer-based latent world_model of UniZero" diff --git a/lzero/policy/random_policy.py b/lzero/policy/random_policy.py index c84806b76..3ce7e00c1 100644 --- a/lzero/policy/random_policy.py +++ b/lzero/policy/random_policy.py @@ -5,6 +5,7 @@ from ding.policy.base_policy import Policy from ding.utils import POLICY_REGISTRY +from lzero.entry.utils import initialize_zeros_batch, initialize_pad_batch from lzero.policy import DiscreteSupport, InverseScalarTransform, select_action, ez_network_output_unpack, mz_network_output_unpack @@ -31,10 +32,14 @@ def __init__( elif cfg.type == 'sampled_efficientzero': from lzero.mcts import SampledEfficientZeroMCTSCtree as MCTSCtree from lzero.mcts import SampledEfficientZeroMCTSPtree as MCTSPtree + elif cfg.type == 'unizero': + from lzero.mcts import UniZeroMCTSCtree as MCTSCtree else: raise NotImplementedError("need to implement pipeline: {}".format(cfg.type)) - self.MCTSCtree = MCTSCtree - self.MCTSPtree = MCTSPtree + if cfg.mcts_ctree: + self.MCTSCtree = MCTSCtree + else: + self.MCTSPtree = MCTSPtree self.action_space = action_space super().__init__(cfg, model, enable_field) @@ -57,6 +62,8 @@ def default_model(self) -> Tuple[str, List[str]]: return 'MuZeroModel', ['lzero.model.muzero_model'] 
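
Note on the `clear_caches()` hunk above: the old and new cache systems are cleared through two different code paths. The sketch below restates that guard pattern on its own, assuming only the attribute names that appear in the diff (`use_new_cache_manager`, `kv_cache_manager.clear_all()`, `get_stats_summary()`); the wrapper function itself is illustrative scaffolding, not the real world model.

```python
import logging

def clear_world_model_caches(world_model) -> None:
    """Sketch of the dual-path clearing logic mirrored from clear_caches() in the diff."""
    if getattr(world_model, 'use_new_cache_manager', False):
        # NEW system: the manager owns both the per-env init pools and the recurrent pool.
        world_model.kv_cache_manager.clear_all()
        if hasattr(world_model.kv_cache_manager, 'get_stats_summary'):
            logging.debug(world_model.kv_cache_manager.get_stats_summary())
    else:
        # OLD system: clear each legacy dict explicitly.
        for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs:
            kv_cache_dict_env.clear()
        world_model.past_kv_cache_recurrent_infer.clear()
    world_model.keys_values_wm_list.clear()
```
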
elif self._cfg.type == 'sampled_efficientzero': return 'SampledEfficientZeroModel', ['lzero.model.sampled_efficientzero_model'] + elif self._cfg.type == 'unizero': + return 'UniZeroModel', ['lzero.model.unizero_model'] else: raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type)) elif self._cfg.model.model_type == "mlp": @@ -66,6 +73,8 @@ def default_model(self) -> Tuple[str, List[str]]: return 'MuZeroModelMLP', ['lzero.model.muzero_model_mlp'] elif self._cfg.type == 'sampled_efficientzero': return 'SampledEfficientZeroModelMLP', ['lzero.model.sampled_efficientzero_model_mlp'] + elif self._cfg.type == 'unizero': + return 'UniZeroModel', ['lzero.model.unizero_model'] else: raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type)) @@ -85,7 +94,17 @@ def _init_collect(self) -> None: self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) - + if self._cfg.type == 'unizero': + self.collector_env_num = self._cfg.collector_env_num + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] + + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs = torch.full( + [self.collector_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id, + ).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] def _forward_collect( self, data: torch.Tensor, @@ -94,6 +113,7 @@ def _forward_collect( to_play: List = [-1], epsilon: float = 0.25, ready_env_id: np.array = None, + timestep: List = [0] ) -> Dict: """ Overview: @@ -105,30 +125,30 @@ def _forward_collect( - temperature (:obj:`float`): The temperature of the policy. - to_play (:obj:`int`): The player to play. - ready_env_id (:obj:`list`): The id of the env that is ready to collect. - Shape: - - data (:obj:`torch.Tensor`): - - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ - S is the number of stacked frames, H is the height of the image, W is the width of the image. - - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. - - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. - - temperature: :math:`(1, )`. - - to_play: :math:`(N, 1)`, where N is the number of collect_env. - - ready_env_id: None - Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ - ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + - timestep (:obj:`list`): The step index of the env in one episode. """ self._collect_model.eval() self._collect_mcts_temperature = temperature + self._collect_epsilon = epsilon active_collect_env_num = data.shape[0] + + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): - # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)} - network_output = self._collect_model.initial_inference(data) if self._cfg.type in ['efficientzero', 'sampled_efficientzero']: + network_output = self._collect_model.initial_inference(data) latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( network_output ) elif self._cfg.type == 'muzero': + network_output = self._collect_model.initial_inference(data) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + elif self._cfg.type == 'unizero': + network_output = self._collect_model.initial_inference( + self.last_batch_obs, self.last_batch_action, data, timestep + ) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) else: raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type)) @@ -143,8 +163,6 @@ def _forward_collect( policy_logits = policy_logits.detach().cpu().numpy().tolist() if self._cfg.model.continuous_action_space: - # when the action space of the environment is continuous, action_mask[:] is None. - # NOTE: in continuous action space env: we set all legal_actions as -1 legal_actions = [ [-1 for _ in range(self._cfg.model.num_of_sampled_actions)] for _ in range(active_collect_env_num) ] @@ -153,20 +171,22 @@ def _forward_collect( [i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num) ] - # the only difference between collect and eval is the dirichlet noise. if self._cfg.type in ['sampled_efficientzero']: noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(self._cfg.model.num_of_sampled_actions) - ).astype(np.float32).tolist() for j in range(active_collect_env_num) + np.random.dirichlet( + [self._cfg.root_dirichlet_alpha] * int(self._cfg.model.num_of_sampled_actions) + ).astype(np.float32).tolist() + for _ in range(active_collect_env_num) ] else: noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) - ).astype(np.float32).tolist() for j in range(active_collect_env_num) + np.random.dirichlet( + [self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() + for j in range(active_collect_env_num) ] if self._cfg.mcts_ctree: - # cpp mcts_tree if self._cfg.type in ['sampled_efficientzero']: roots = self.MCTSCtree.roots( active_collect_env_num, legal_actions, self._cfg.model.action_space_size, @@ -175,7 +195,6 @@ def _forward_collect( else: roots = self.MCTSCtree.roots(active_collect_env_num, legal_actions) else: - # python mcts_tree if self._cfg.type in ['sampled_efficientzero']: roots = self.MCTSPtree.roots( active_collect_env_num, legal_actions, self._cfg.model.action_space_size, @@ -192,6 +211,9 @@ def _forward_collect( elif self._cfg.type == 'muzero': roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play) + elif self._cfg.type == 'unizero': + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, timestep) else: raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type)) @@ -200,10 +222,6 @@ def _forward_collect( if self._cfg.type in ['sampled_efficientzero']: roots_sampled_actions = roots.get_sampled_actions() - if ready_env_id is None: - ready_env_id = np.arange(active_collect_env_num) - 
output = {i: None for i in ready_env_id}
-
         for i, env_id in enumerate(ready_env_id):
             distributions, value = roots_visit_count_distributions[i], roots_values[i]
@@ -249,8 +267,64 @@ def _forward_collect(
                 'predicted_policy_logits': policy_logits[i],
             }
 
+        if self._cfg.type == 'unizero':
+            batch_action = [output[env_id]['action'] for env_id in ready_env_id]
+            self.last_batch_obs = data
+            self.last_batch_action = batch_action
+
         return output
 
+    def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None:
+        """
+        Overview:
+            This method resets the collection process for a specific environment. It clears caches and memory
+            when certain conditions are met, ensuring optimal performance. If reset_init_data is True, the initial data
+            will be reset.
+        Arguments:
+            - env_id (:obj:`int`, optional): The ID of the environment to reset. If None or list, the function returns immediately.
+            - current_steps (:obj:`int`, optional): The current step count in the environment. Used to determine
+              whether to clear caches.
+            - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset.
+        """
+        if self._cfg.type != 'unizero':
+            return
+        if reset_init_data:
+            if self._cfg.model.model_type == 'conv':
+                pad_token_id = -1
+            else:
+                encoder_tokenizer = getattr(self._model.tokenizer.encoder, 'tokenizer', None)
+                pad_token_id = encoder_tokenizer.pad_token_id if encoder_tokenizer is not None else 0
+            self.last_batch_obs = initialize_pad_batch(
+                self._cfg.model.observation_shape,
+                self._cfg.collector_env_num,
+                self._cfg.device,
+                pad_token_id=pad_token_id
+            )
+            self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)]
+
+        # Return immediately if env_id is None or a list
+        if env_id is None or isinstance(env_id, list):
+            return
+
+        # Determine the clear interval based on the environment's sample type
+        clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200
+
+        # Clear caches if the current steps are a multiple of the clear interval
+        if current_steps % clear_interval == 0:
+            print(f'clear_interval: {clear_interval}')
+
+            # Clear various caches in the collect model's world model
+            world_model = self._collect_model.world_model
+            for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs:
+                kv_cache_dict_env.clear()
+            world_model.past_kv_cache_recurrent_infer.clear()
+            world_model.keys_values_wm_list.clear()
+
+            # Free up GPU memory
+            torch.cuda.empty_cache()
+
+
     def _init_eval(self) -> None:
         """
         Overview:
diff --git a/lzero/policy/scaling_transform.py b/lzero/policy/scaling_transform.py
index 19a852f56..ce36fc459 100644
--- a/lzero/policy/scaling_transform.py
+++ b/lzero/policy/scaling_transform.py
@@ -110,6 +110,7 @@ def visit_count_temperature(
 def phi_transform(
     discrete_support: DiscreteSupport,
     x: torch.Tensor,
+    label_smoothing_eps: float = 0.
 ) -> torch.Tensor:
     """
     Overview:
@@ -163,7 +164,14 @@ def phi_transform(
                          dtype=x.dtype, device=x.device)
     target.scatter_add_(-1, idx, prob)
-    return target
+
+    # --- 5.
应用标签平滑 --- + if label_smoothing_eps > 0: + # 将原始的 two-hot 目标与一个均匀分布混合 + smooth_target = (1.0 - label_smoothing_eps) * target + (label_smoothing_eps / size) + return smooth_target + else: + return target def cross_entropy_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py index e42e9acf4..043a71a5b 100644 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -1,9 +1,11 @@ import copy from collections import defaultdict from typing import List, Dict, Any, Tuple, Union - +import logging import numpy as np import torch +from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters +import torch.nn.functional as F import wandb from ding.model import model_wrap from ding.utils import POLICY_REGISTRY @@ -17,6 +19,65 @@ from lzero.policy.muzero import MuZeroPolicy from .utils import configure_optimizers_nanogpt +def scale_module_weights_vectorized(module: torch.nn.Module, scale_factor: float): + """ + 使用向量化操作高效地缩放一个模块的所有权重。 + """ + if not (0.0 < scale_factor < 1.0): + return # 如果缩放因子无效,则不执行任何操作 + + # 1. 将模块的所有参数展平成一个单一向量 + params_vec = parameters_to_vector(module.parameters()) + + # 2. 在这个向量上执行一次乘法操作 + params_vec.data.mul_(scale_factor) + + # 3. 将缩放后的向量复制回模块的各个参数 + vector_to_parameters(params_vec, module.parameters()) + +def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type, betas): + """ + 为UniZero模型配置带有差异化学习率的优化器。 + """ + # 1. 定义需要特殊处理的参数 + param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} + + # 2. 将参数分为三组:Transformer主干、Tokenizer、Heads + transformer_params = {pn: p for pn, p in param_dict.items() if 'transformer' in pn} + tokenizer_params = {pn: p for pn, p in param_dict.items() if 'tokenizer' in pn} + + # Heads的参数是那些既不属于transformer也不属于tokenizer的 + head_params = { + pn: p for pn, p in param_dict.items() + if 'transformer' not in pn and 'tokenizer' not in pn + } + # 3. 为每组设置不同的优化器参数(特别是学习率) + # 这里我们仍然使用AdamW,但学习率设置更合理 + optim_groups = [ + { + 'params': list(transformer_params.values()), + 'lr': learning_rate, + 'weight_decay': weight_decay + }, + { + 'params': list(tokenizer_params.values()), + 'lr': learning_rate, + 'weight_decay': weight_decay + + }, + { + 'params': list(head_params.values()), + 'lr': learning_rate, + 'weight_decay': weight_decay + } + ] + print("--- Optimizer Groups ---") + print(f"Transformer LR: {learning_rate}") + print(f"Tokenizer/Heads LR: {learning_rate}") + + optimizer = torch.optim.AdamW(optim_groups, betas=betas) + return optimizer + @POLICY_REGISTRY.register('unizero') class UniZeroPolicy(MuZeroPolicy): @@ -146,6 +207,22 @@ class UniZeroPolicy(MuZeroPolicy): ), ), # ****** common ****** + # (bool) 是否启用自适应策略熵权重 (alpha) + use_adaptive_entropy_weight=True, + # (float) 自适应alpha优化器的学习率 + adaptive_entropy_alpha_lr=1e-4, + # ==================== START: Encoder-Clip Annealing Config ==================== + # (bool) 是否启用 encoder-clip 值的退火。 + use_encoder_clip_annealing=True, + # (str) 退火类型。可选 'linear' 或 'cosine'。 + encoder_clip_anneal_type='cosine', + # (float) 退火的起始 clip 值 (训练初期,较宽松)。 + encoder_clip_start_value=30.0, + # (float) 退火的结束 clip 值 (训练后期,较严格)。 + encoder_clip_end_value=10.0, + # (int) 完成从起始值到结束值的退火所需的训练迭代步数。 + encoder_clip_anneal_steps=100000, # 例如,在200k次迭代后达到最终值 + # ===================== END: Encoder-Clip Annealing Config ===================== # (bool) whether to use rnd model. use_rnd_model=False, # (bool) Whether to use multi-gpu training. 
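
The encoder-clip annealing options added above only store the schedule's endpoints; the value applied at each iteration is computed later in `_forward_learn`. A small standalone sketch of that schedule (same linear/cosine interpolation; the default numbers are taken from the config block above) for reference:

```python
# Sketch of the encoder-clip annealing schedule configured above; the formula matches the
# one used later in `_forward_learn` (linear interpolation, or cosine easing from start to end).
import numpy as np

def anneal_encoder_clip(train_iter: int,
                        start_value: float = 30.0,
                        end_value: float = 10.0,
                        anneal_steps: int = 100000,
                        anneal_type: str = 'cosine') -> float:
    progress = min(1.0, train_iter / anneal_steps)
    if anneal_type == 'cosine':
        # Cosine easing: multiplier goes smoothly from 1 (iter 0) to 0 (iter >= anneal_steps).
        cosine_progress = 0.5 * (1.0 + np.cos(np.pi * progress))
        return end_value + (start_value - end_value) * cosine_progress
    # Linear schedule.
    return start_value * (1.0 - progress) + end_value * progress

# e.g. with the defaults above: iter 0 -> 30.0, iter 50000 -> 20.0, iter >= 100000 -> 10.0
```
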
@@ -211,6 +288,10 @@ class UniZeroPolicy(MuZeroPolicy): optim_type='AdamW', # (float) Learning rate for training policy network. Initial lr for manually decay schedule. learning_rate=0.0001, + # ==================== [新增] 范数监控频率 ==================== + # 每隔多少个训练迭代步数,监控一次模型参数的范数。设置为0则禁用。 + monitor_norm_freq=5000, + # ============================================================ # (int) Frequency of hard target network update. target_update_freq=100, # (int) Frequency of soft target network update. @@ -227,8 +308,12 @@ class UniZeroPolicy(MuZeroPolicy): n_episode=8, # (int) The number of num_segments in each collecting stage when use muzero_segment_collector. num_segments=8, - # (int) the number of simulations in MCTS. + # (int) the number of simulations in MCTS for renalyze. num_simulations=50, + # (int) The number of simulations in MCTS for the collect phase. + collect_num_simulations=25, + # (int) The number of simulations in MCTS for the eval phase. + eval_num_simulations=50, # (float) Discount factor (gamma) for returns. discount_factor=0.997, # (int) The number of steps for calculating target q_value. @@ -313,24 +398,140 @@ def default_model(self) -> Tuple[str, List[str]]: """ return 'UniZeroModel', ['lzero.model.unizero_model'] + # ==================== [新增] 模型范数监控函数 ==================== + def _monitor_model_norms(self) -> Dict[str, float]: + """ + Overview: + 计算并返回模型关键组件(Encoder, Transformer, Heads)的参数矩阵范数。 + 此函数应在 torch.no_grad() 环境下调用,以提高效率。 + Returns: + - norm_metrics (:obj:`Dict[str, float]`): 包含所有范数指标的字典,用于日志记录。 + """ + world_model = self._learn_model.world_model + norm_metrics = {} + + # 定义要监控的模块组 + module_groups = { + 'encoder': world_model.tokenizer.encoder, + 'transformer': world_model.transformer, + 'head_value': world_model.head_value, + 'head_reward': world_model.head_rewards, + 'head_policy': world_model.head_policy, + } + + for group_name, group_module in module_groups.items(): + total_norm_sq = 0.0 + for param_name, param in group_module.named_parameters(): + if param.requires_grad: + # 计算单层参数的L2范数 + param_norm = param.data.norm(2).item() + # 替换点号,使其在TensorBoard中正确显示为层级 + log_name = f'norm/{group_name}/{param_name.replace(".", "/")}' + norm_metrics[log_name] = param_norm + total_norm_sq += param_norm ** 2 + + # 计算整个模块的总范数 + total_group_norm = np.sqrt(total_norm_sq) + norm_metrics[f'norm/{group_name}/_total_norm'] = total_group_norm + + return norm_metrics + # ================================================================= + + def _monitor_gradient_norms(self) -> Dict[str, float]: + """ + Overview: + 计算并返回模型关键组件的梯度范数。 + 此函数应在梯度计算完成后、参数更新之前调用。 + Returns: + - grad_metrics (:obj:`Dict[str, float]`): 包含所有梯度范数指标的字典,用于日志记录。 + """ + world_model = self._learn_model.world_model + grad_metrics = {} + + # 定义要监控的模块组 + module_groups = { + 'encoder': world_model.tokenizer.encoder, + 'transformer': world_model.transformer, + 'head_value': world_model.head_value, + 'head_reward': world_model.head_rewards, + 'head_policy': world_model.head_policy, + } + + for group_name, group_module in module_groups.items(): + total_grad_norm_sq = 0.0 + num_params_with_grad = 0 + + for param_name, param in group_module.named_parameters(): + if param.requires_grad and param.grad is not None: + # 计算单层参数的梯度L2范数 + grad_norm = param.grad.data.norm(2).item() + # 替换点号,使其在TensorBoard中正确显示为层级 + log_name = f'grad/{group_name}/{param_name.replace(".", "/")}' + grad_metrics[log_name] = grad_norm + total_grad_norm_sq += grad_norm ** 2 + num_params_with_grad += 1 + + # 计算整个模块的总梯度范数 + if num_params_with_grad > 0: + 
total_group_grad_norm = np.sqrt(total_grad_norm_sq) + grad_metrics[f'grad/{group_name}/_total_norm'] = total_group_grad_norm + else: + grad_metrics[f'grad/{group_name}/_total_norm'] = 0.0 + + return grad_metrics + # ================================================================= + def _init_learn(self) -> None: """ Overview: Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. """ - # NOTE: nanoGPT optimizer - self._optimizer_world_model = configure_optimizers_nanogpt( - model=self._model.world_model, - learning_rate=self._cfg.learning_rate, - weight_decay=self._cfg.weight_decay, - device_type=self._cfg.device, - betas=(0.9, 0.95), - ) + if self._cfg.optim_type == 'SGD': + # --- 改为SGD优化器 --- + self._optimizer_world_model = torch.optim.SGD( + self._model.world_model.parameters(), + lr=self._cfg.learning_rate, # 初始学习率,在配置中设为 0.2 + momentum=self._cfg.momentum, # 在配置中设为 0.9 + weight_decay=self._cfg.weight_decay # 在配置中设为 1e-4 + ) + elif self._cfg.optim_type == 'AdamW': + # NOTE: nanoGPT optimizer + self._optimizer_world_model = configure_optimizers_nanogpt( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + elif self._cfg.optim_type == 'AdamW_mix_lr_wdecay': + self._optimizer_world_model = configure_optimizer_unizero( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, # 使用一个合理的AdamW基础学习率 + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + if self._cfg.cos_lr_scheduler: from torch.optim.lr_scheduler import CosineAnnealingLR # TODO: check the total training steps - self.lr_scheduler = CosineAnnealingLR(self._optimizer_world_model, 1e5, eta_min=0, last_epoch=-1) + total_iters = self._cfg.get('total_iterations', 500000) # 500k iter + final_lr = self._cfg.get('final_learning_rate', 1e-6) + + self.lr_scheduler = CosineAnnealingLR( + self._optimizer_world_model, + T_max=total_iters, + eta_min=final_lr + ) + print(f"CosineAnnealingLR enabled: T_max={total_iters}, eta_min={final_lr}") + + if self._cfg.piecewise_decay_lr_scheduler: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. 
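
For reference, a worked example of the piecewise scheduler set up just below: `LambdaLR` multiplies the base learning rate by the returned factor, so the decay happens in three plateaus. The `base_lr` and `max_step` values here are illustrative stand-ins, not the repository defaults.

```python
# Worked example of the piecewise multiplier: the base lr is scaled by 1.0 for the first
# half of max_step, then by 0.1, then by 0.01.
import torch
from torch.optim.lr_scheduler import LambdaLR

base_lr = 1e-4     # illustrative; the real value comes from cfg.policy.learning_rate
max_step = 100000  # illustrative stand-in for threshold_training_steps_for_final_lr

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=base_lr)
lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01)
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

# step 0..49999      -> lr = 1e-4
# step 50000..99999  -> lr = 1e-5
# step >= 100000     -> lr = 1e-6
```
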
+ lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer_world_model, lr_lambda=lr_lambda) # use model_wrapper for specialized demands of different modes self._target_model = copy.deepcopy(self._model) @@ -357,6 +558,23 @@ def _init_learn(self) -> None: self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) + + if self._cfg.use_rnd_model: + if self._cfg.target_model_for_intrinsic_reward_update_type == 'assign': + self._target_model_for_intrinsic_reward = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq_for_intrinsic_reward} + ) + elif self._cfg.target_model_for_intrinsic_reward_update_type == 'momentum': + self._target_model_for_intrinsic_reward = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta_for_intrinsic_reward} + ) + self.intermediate_losses = defaultdict(float) self.l2_norm_before = 0. self.l2_norm_after = 0. @@ -365,17 +583,75 @@ def _init_learn(self) -> None: if self._cfg.model.model_type == 'conv': self.pad_token_id = -1 - else: + else: encoder_tokenizer = getattr(self._model.tokenizer.encoder, 'tokenizer', None) self.pad_token_id = encoder_tokenizer.pad_token_id if encoder_tokenizer is not None else 0 - if self._cfg.use_wandb: # TODO: add the model to wandb wandb.watch(self._learn_model.representation_network, log="all") self.accumulation_steps = self._cfg.accumulation_steps + + # ==================== START: 目标熵正则化初始化 ==================== + # 从配置中读取是否启用自适应alpha,并提供一个默认值 + self.use_adaptive_entropy_weight = self._cfg.get('use_adaptive_entropy_weight', True) + + # 在 _init_learn 中增加配置 + self.target_entropy_start_ratio = self._cfg.get('target_entropy_start_ratio', 0.98) + self.target_entropy_end_ratio = self._cfg.get('target_entropy_end_ratio', 0.7) + self.target_entropy_decay_steps = self._cfg.get('target_entropy_decay_steps', 200000) # 例如,在200k步内完成退火 2M envsteps + + if self.use_adaptive_entropy_weight: + # 1. 设置目标熵。对于离散动作空间,一个常见的启发式设置是动作空间维度的负对数乘以一个系数。 + # 这个系数(例如0.98)可以作为一个超参数。 + action_space_size = self._cfg.model.action_space_size + self.target_entropy = -np.log(1.0 / action_space_size) * 0.98 + + # 2. 初始化一个可学习的 log_alpha 参数。 + # 初始化为0,意味着初始的 alpha = exp(0) = 1.0。 + self.log_alpha = torch.nn.Parameter(torch.zeros(1, device=self._cfg.device), requires_grad=True) + + # 3. 
为 log_alpha 创建一个专属的优化器。 + # 使用与主优化器不同的、较小的学习率(例如1e-4)通常更稳定。 + alpha_lr = self._cfg.get('adaptive_entropy_alpha_lr', 1e-4) + self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr) + + print("="*20) + print(">>> 目标熵正则化 (自适应Alpha) 已启用 <<<") + print(f" 目标熵 (Target Entropy): {self.target_entropy:.4f}") + print(f" Alpha 优化器学习率: {alpha_lr:.2e}") + print("="*20) + + # ===================== END: 目标熵正则化初始化 ===================== + + # ==================== START: 初始化 Encoder-Clip Annealing 参数 ==================== + self.use_encoder_clip_annealing = self._cfg.get('use_encoder_clip_annealing', False) + self.latent_norm_clip_threshold = self._cfg.get('latent_norm_clip_threshold', 20.0) + if self.use_encoder_clip_annealing: + self.encoder_clip_anneal_type = self._cfg.get('encoder_clip_anneal_type', 'cosine') + self.encoder_clip_start = self._cfg.get('encoder_clip_start_value', 30.0) + self.encoder_clip_end = self._cfg.get('encoder_clip_end_value', 10.0) + self.encoder_clip_anneal_steps = self._cfg.get('encoder_clip_anneal_steps', 200000) + + print("="*20) + print(">>> Encoder-Clip 退火已启用 <<<") + print(f" 类型: {self.encoder_clip_anneal_type}") + print(f" 范围: {self.encoder_clip_start} -> {self.encoder_clip_end}") + print(f" 步数: {self.encoder_clip_anneal_steps}") + print("="*20) + else: + # 如果不启用退火,则使用固定的 clip 阈值 + self.latent_norm_clip_threshold = self._cfg.get('latent_norm_clip_threshold', 20.0) + # ===================== END: 初始化 Encoder-Clip Annealing 参数 ===================== + + # --- NEW: Policy Label Smoothing Parameters --- + self.policy_ls_eps_start = self._cfg.get('policy_ls_eps_start', 0.05) # policy_label_smoothing_eps_start 越大的action space需要越大的eps + self.policy_ls_eps_end = self._cfg.get('policy_label_smoothing_eps_end ', 0.01) # policy_label_smoothing_eps_start + self.policy_ls_eps_decay_steps = self._cfg.get('policy_ls_eps_decay_steps ', 50000) + print(f"self.policy_ls_eps_start:{self.policy_ls_eps_start}") + # @profile def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: """ @@ -392,10 +668,20 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in """ self._learn_model.train() self._target_model.train() + if self._cfg.use_rnd_model: + self._target_model_for_intrinsic_reward.train() current_batch, target_batch, train_iter = data obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch target_reward, target_value, target_policy = target_batch + + # --- NEW: Calculate current epsilon for policy --- + if self.policy_ls_eps_start > 0: + progress = min(1.0, train_iter / self.policy_ls_eps_decay_steps) + current_policy_label_eps = self.policy_ls_eps_start * (1 - progress) + self.policy_ls_eps_end * progress + else: + current_policy_label_eps = 0.0 + # Prepare observations based on frame stack number if self._cfg.model.frame_stack_num > 1: @@ -425,8 +711,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in transformed_target_value = scalar_transform(target_value) # Convert to categorical distributions - target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) - target_value_categorical = phi_transform(self.value_support, transformed_target_value) + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward, label_smoothing_eps= self._cfg.label_smoothing_eps) + target_value_categorical = phi_transform(self.value_support, transformed_target_value, label_smoothing_eps= 
self._cfg.label_smoothing_eps) # Prepare batch for GPT model batch_for_gpt = {} @@ -448,6 +734,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in device=self._cfg.device) batch_for_gpt['target_value'] = target_value_categorical[:, :-1] batch_for_gpt['target_policy'] = target_policy[:, :-1] + batch_for_gpt['scalar_target_value'] = target_value # Extract valid target policy data and compute entropy valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] @@ -456,10 +743,75 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # Update world model losses = self._learn_model.world_model.compute_loss( - batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle + batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, global_step=train_iter, current_policy_label_eps=current_policy_label_eps, ) # NOTE : compute_loss third argument is now a dead argument. If this changes, it could need adaptation between value_inverse and reward_inverse. - weighted_total_loss = losses.loss_total + # ==================== [修改] 集成范数监控逻辑 ==================== + norm_log_dict = {} + # 检查是否达到监控频率 + if self._cfg.monitor_norm_freq > 0 and (train_iter == 0 or (train_iter % self._cfg.monitor_norm_freq == 0)): + with torch.no_grad(): + # 1. 监控模型参数范数 + param_norm_metrics = self._monitor_model_norms() + norm_log_dict.update(param_norm_metrics) + + # 2. 监控中间张量 x (Transformer的输出) + intermediate_x = losses.intermediate_losses.get('intermediate_tensor_x') + if intermediate_x is not None: + # x 的形状为 (B, T, E) + # 计算每个 token 的 L2 范数 + token_norms = intermediate_x.norm(p=2, dim=-1) + + # 记录这些范数的统计数据 + norm_log_dict['norm/x_token/mean'] = token_norms.mean().item() + norm_log_dict['norm/x_token/std'] = token_norms.std().item() + norm_log_dict['norm/x_token/max'] = token_norms.max().item() + norm_log_dict['norm/x_token/min'] = token_norms.min().item() + + # 3. 监控 logits 的详细统计 (Value, Policy, Reward) + logits_value = losses.intermediate_losses.get('logits_value') + if logits_value is not None: + norm_log_dict['logits/value/mean'] = logits_value.mean().item() + norm_log_dict['logits/value/std'] = logits_value.std().item() + norm_log_dict['logits/value/max'] = logits_value.max().item() + norm_log_dict['logits/value/min'] = logits_value.min().item() + norm_log_dict['logits/value/abs_max'] = logits_value.abs().max().item() + + logits_policy = losses.intermediate_losses.get('logits_policy') + if logits_policy is not None: + norm_log_dict['logits/policy/mean'] = logits_policy.mean().item() + norm_log_dict['logits/policy/std'] = logits_policy.std().item() + norm_log_dict['logits/policy/max'] = logits_policy.max().item() + norm_log_dict['logits/policy/min'] = logits_policy.min().item() + norm_log_dict['logits/policy/abs_max'] = logits_policy.abs().max().item() + + logits_reward = losses.intermediate_losses.get('logits_reward') + if logits_reward is not None: + norm_log_dict['logits/reward/mean'] = logits_reward.mean().item() + norm_log_dict['logits/reward/std'] = logits_reward.std().item() + norm_log_dict['logits/reward/max'] = logits_reward.max().item() + norm_log_dict['logits/reward/min'] = logits_reward.min().item() + norm_log_dict['logits/reward/abs_max'] = logits_reward.abs().max().item() + + # 4. 
监控 obs_embeddings (Encoder输出) 的统计 + obs_embeddings = losses.intermediate_losses.get('obs_embeddings') + if obs_embeddings is not None: + # 计算每个 embedding 的 L2 范数 + emb_norms = obs_embeddings.norm(p=2, dim=-1) + norm_log_dict['embeddings/obs/norm_mean'] = emb_norms.mean().item() + norm_log_dict['embeddings/obs/norm_std'] = emb_norms.std().item() + norm_log_dict['embeddings/obs/norm_max'] = emb_norms.max().item() + norm_log_dict['embeddings/obs/norm_min'] = emb_norms.min().item() + # ================================================================= + + # ==================== START MODIFICATION 2 ==================== + # Extract the calculated value_priority from the returned losses. + value_priority_tensor = losses.intermediate_losses['value_priority'] + # Convert to numpy array for the replay buffer, adding a small epsilon. + value_priority_np = value_priority_tensor.detach().cpu().numpy() + 1e-6 + # ===================== END MODIFICATION 2 ===================== + weighted_total_loss = (weights * losses.loss_total).mean() + for loss_name, loss_value in losses.intermediate_losses.items(): self.intermediate_losses[f"{loss_name}"] = loss_value @@ -477,6 +829,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in dormant_ratio_encoder = self.intermediate_losses['dormant_ratio_encoder'] dormant_ratio_world_model = self.intermediate_losses['dormant_ratio_world_model'] latent_state_l2_norms = self.intermediate_losses['latent_state_l2_norms'] + + latent_action_l2_norms = self.intermediate_losses['latent_action_l2_norms'] + assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values" assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values" @@ -485,13 +840,94 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # Reset gradients at the start of each accumulation cycle if (train_iter % self.accumulation_steps) == 0: self._optimizer_world_model.zero_grad() + + # ==================== START: 目标熵正则化更新逻辑 ==================== + alpha_loss = None + current_alpha = self._cfg.model.world_model_cfg.policy_entropy_weight # 默认使用固定值 + if self.use_adaptive_entropy_weight: + # --- 动态计算目标熵 (这部分逻辑是正确的,予以保留) --- + progress = min(1.0, train_iter / self.target_entropy_decay_steps) + current_ratio = self.target_entropy_start_ratio * (1 - progress) + self.target_entropy_end_ratio * progress + action_space_size = self._cfg.model.action_space_size + # 注意:我们将 target_entropy 定义为正数,更符合直觉 + current_target_entropy = -np.log(1.0 / action_space_size) * current_ratio + + # --- 计算 alpha_loss (已修正符号) --- + # 这是核心修正点:去掉了最前面的负号 + # detach() 仍然是关键,确保 alpha_loss 的梯度只流向 log_alpha + alpha_loss = (self.log_alpha * (policy_entropy.detach() - current_target_entropy)).mean() + + # # --- 更新 log_alpha --- + self.alpha_optimizer.zero_grad() + alpha_loss.backward() + self.alpha_optimizer.step() + # --- [优化建议] 增加 log_alpha 裁剪作为安全措施 --- + with torch.no_grad(): + self.log_alpha.clamp_(np.log(5e-3), np.log(10.0)) + + # --- 使用当前更新后的 alpha (截断梯度流) --- + current_alpha = self.log_alpha.exp().detach() + + # 重新计算加权的策略损失和总损失 + # 注意:这里的 policy_entropy 已经是一个batch的平均值 + weighted_policy_loss = orig_policy_loss - current_alpha * policy_entropy + # 重新构建总损失 (不使用 losses.loss_total) + # 确保这里的权重与 LossWithIntermediateLosses 类中的计算方式一致 + total_loss = ( + losses.reward_loss_weight * reward_loss + + losses.value_loss_weight * value_loss + + losses.policy_loss_weight * weighted_policy_loss + + losses.obs_loss_weight * obs_loss + + losses.latent_recon_loss_weight * 
latent_recon_loss + + losses.perceptual_loss_weight * perceptual_loss + ) + weighted_total_loss = (weights * total_loss).mean() + # ===================== END: 目标熵正则化更新逻辑 ===================== # Scale the loss by the number of accumulation steps weighted_total_loss = weighted_total_loss / self.accumulation_steps weighted_total_loss.backward() + + # ----------------------------------------------------------------- + # 仍然在 torch.no_grad() 环境下执行 + # ================================================================= + with torch.no_grad(): + # 1. Encoder-Clip + # ==================== START: 动态计算当前 Clip 阈值 ==================== + current_clip_value = self.latent_norm_clip_threshold # 默认使用固定值 + if self.use_encoder_clip_annealing: + progress = min(1.0, train_iter / self.encoder_clip_anneal_steps) + + if self.encoder_clip_anneal_type == 'cosine': + # 余弦调度: 从1平滑过渡到0 + cosine_progress = 0.5 * (1.0 + np.cos(np.pi * progress)) + current_clip_value = self.encoder_clip_end + \ + (self.encoder_clip_start - self.encoder_clip_end) * cosine_progress + else: # 默认为线性调度 + current_clip_value = self.encoder_clip_start * (1 - progress) + \ + self.encoder_clip_end * progress + # ===================== END: 动态计算当前 Clip 阈值 ===================== + + # 1. Encoder-Clip (使用动态计算出的 current_clip_value) + if current_clip_value > 0 and 'obs_embeddings' in losses.intermediate_losses: + obs_embeddings = losses.intermediate_losses['obs_embeddings'] + if obs_embeddings is not None: + max_latent_norm = obs_embeddings.norm(p=2, dim=-1).max() + if max_latent_norm > current_clip_value: + scale_factor = current_clip_value / max_latent_norm.item() + # 不再频繁打印,或者可以改为每隔N步打印一次 + if train_iter % 1000 == 0: + print(f"[Encoder-Clip Annealing] Iter {train_iter}: Max latent norm {max_latent_norm.item():.2f} > {current_clip_value:.2f}. Scaling by {scale_factor:.4f}.") + scale_module_weights_vectorized(self._model.world_model.tokenizer.encoder, scale_factor) # Check if the current iteration completes an accumulation cycle if (train_iter + 1) % self.accumulation_steps == 0: + # ==================== [新增] 监控梯度范数 ==================== + # 在梯度裁剪之前监控梯度范数,用于诊断梯度爆炸/消失问题 + if self._cfg.monitor_norm_freq > 0 and (train_iter == 0 or (train_iter % self._cfg.monitor_norm_freq == 0)): + grad_norm_metrics = self._monitor_gradient_norms() + norm_log_dict.update(grad_norm_metrics) + # ================================================================= # Analyze gradient norms if simulation normalization analysis is enabled if self._cfg.analysis_sim_norm: # Clear previous analysis results to prevent memory overflow @@ -523,6 +959,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # Update the target model with the current model's parameters self._target_model.update(self._learn_model.state_dict()) + if self._cfg.use_rnd_model: + self._target_model_for_intrinsic_reward.update(self._learn_model.state_dict()) if torch.cuda.is_available(): torch.cuda.synchronize() @@ -565,7 +1003,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in 'target_policy_entropy': average_target_policy_entropy.item(), 'reward_loss': reward_loss.item(), 'value_loss': value_loss.item(), - # 'value_priority_orig': np.zeros(self._cfg.batch_size), # TODO + # Add value_priority to the log dictionary. 
+ 'value_priority': value_priority_np.mean().item(), + 'value_priority_orig': value_priority_np, 'target_reward': target_reward.mean().item(), 'target_value': target_value.mean().item(), 'transformed_target_reward': transformed_target_reward.mean().item(), @@ -574,15 +1014,33 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in 'analysis/dormant_ratio_encoder': dormant_ratio_encoder.item(), 'analysis/dormant_ratio_world_model': dormant_ratio_world_model.item(), 'analysis/latent_state_l2_norms': latent_state_l2_norms.item(), + 'analysis/latent_action_l2_norms': latent_action_l2_norms, 'analysis/l2_norm_before': self.l2_norm_before, 'analysis/l2_norm_after': self.l2_norm_after, 'analysis/grad_norm_before': self.grad_norm_before, 'analysis/grad_norm_after': self.grad_norm_after, + + "current_policy_label_eps":current_policy_label_eps, } + # ==================== [修改] 将范数监控结果合并到日志中 ==================== + if norm_log_dict: + return_log_dict.update(norm_log_dict) + # ======================================================================= + # ==================== START: 添加新日志项 ==================== + if self.use_adaptive_entropy_weight: + return_log_dict['adaptive_alpha'] = current_alpha.item() + return_log_dict['adaptive_target_entropy_ratio'] = current_ratio + return_log_dict['alpha_loss'] = alpha_loss.item() + # ==================== START: 添加新日志项 ==================== + # ==================== START: 添加新日志项 ==================== + if self.use_encoder_clip_annealing: + return_log_dict['current_encoder_clip_value'] = current_clip_value + # ===================== END: 添加新日志项 ===================== - if self._cfg.use_wandb: - wandb.log({'learner_step/' + k: v for k, v in return_log_dict.items()}, step=self.env_step) - wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step) + if self._cfg.use_wandb and self._rank == 0: + for k, v in return_log_dict.items(): + wandb.log({'learner_step/' + k: v , "env_step": self.env_step}) + wandb.log({'learner_iter/' + k: v , "train_iter": self.train_iter}) return return_log_dict @@ -601,11 +1059,13 @@ def _init_collect(self) -> None: Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. """ self._collect_model = self._model - + # 为 collect MCTS 创建一个配置副本,并设置特定的模拟次数 + mcts_collect_cfg = copy.deepcopy(self._cfg) + mcts_collect_cfg.num_simulations = self._cfg.collect_num_simulations if self._cfg.mcts_ctree: - self._mcts_collect = MCTSCtree(self._cfg) + self._mcts_collect = MCTSCtree(mcts_collect_cfg) else: - self._mcts_collect = MCTSPtree(self._cfg) + self._mcts_collect = MCTSPtree(mcts_collect_cfg) self._collect_mcts_temperature = 1. self._collect_epsilon = 0.0 self.collector_env_num = self._cfg.collector_env_num @@ -761,10 +1221,13 @@ def _init_eval(self) -> None: Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. 
""" self._eval_model = self._model + # 为 eval MCTS 创建一个配置副本,并设置特定的模拟次数 + mcts_eval_cfg = copy.deepcopy(self._cfg) + mcts_eval_cfg.num_simulations = self._cfg.eval_num_simulations if self._cfg.mcts_ctree: - self._mcts_eval = MCTSCtree(self._cfg) + self._mcts_eval = MCTSCtree(mcts_eval_cfg) else: - self._mcts_eval = MCTSPtree(self._cfg) + self._mcts_eval = MCTSPtree(mcts_eval_cfg) self.evaluator_env_num = self._cfg.evaluator_env_num if self._cfg.model.model_type == 'conv': @@ -894,29 +1357,46 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in ) self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] - # Return immediately if env_id is None or a list - if env_id is None or isinstance(env_id, list): - return + # We must handle both single int and list of ints for env_id. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id + + # The key condition: `current_steps` is None only on the end-of-episode reset call from the collector. + if current_steps is None: + world_model = self._collect_model.world_model + for eid in env_ids_to_reset: + # Clear the specific environment's initial inference cache. + if hasattr(world_model, 'use_new_cache_manager') and world_model.use_new_cache_manager: + # NEW SYSTEM: Use KVCacheManager to clear per-environment cache + if eid < world_model.env_num: + world_model.kv_cache_manager.init_pools[eid].clear() + print(f'>>> [Collector] Cleared KV cache for env_id: {eid} at episode end (NEW system).') + else: + # OLD SYSTEM: Use legacy cache dictionary + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + print(f'>>> [Collector] Cleared KV cache for env_id: {eid} at episode end (OLD system).') # Determine the clear interval based on the environment's sample type - clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length # Clear caches if the current steps are a multiple of the clear interval - if current_steps % clear_interval == 0: + if current_steps is not None and current_steps % clear_interval == 0: print(f'clear_interval: {clear_interval}') # Clear various caches in the collect model's world model world_model = self._collect_model.world_model - for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: - kv_cache_dict_env.clear() - world_model.past_kv_cache_recurrent_infer.clear() - world_model.keys_values_wm_list.clear() + # ==================== Phase 1.5: Use unified clear_caches() method ==================== + # This automatically handles both old and new cache systems + world_model.clear_caches() + # ====================================================================================== # Free up GPU memory torch.cuda.empty_cache() - - print('collector: collect_model clear()') - print(f'eps_steps_lst[{env_id}]: {current_steps}') + print(f'eps_steps_lst[{env_id}]: {current_steps}, collector: collect_model clear()') def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None: """ @@ -939,23 +1419,57 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_ ) self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] - # Return immediately if env_id is None or a list - if env_id is None or isinstance(env_id, list): - return + # --- 
BEGIN ROBUST FIX --- + # This logic handles the crucial end-of-episode cache clearing for evaluation. + # The evaluator calls `_policy.reset([env_id])` when an episode is done. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id + # The key condition: `current_steps` is None only on the end-of-episode reset call from the evaluator. + if current_steps is None: + world_model = self._eval_model.world_model + for eid in env_ids_to_reset: + # Clear the specific environment's initial inference cache. + if hasattr(world_model, 'use_new_cache_manager') and world_model.use_new_cache_manager: + # NEW SYSTEM: Use KVCacheManager to clear per-environment cache + if eid < world_model.env_num: + world_model.kv_cache_manager.init_pools[eid].clear() + print(f'>>> [Evaluator] Cleared KV cache for env_id: {eid} at episode end (NEW system).') + else: + # OLD SYSTEM: Use legacy cache dictionary + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + print(f'>>> [Evaluator] Cleared KV cache for env_id: {eid} at episode end (OLD system).') + # ============================================================================= + + + # The recurrent cache is global. + # ==================== Phase 1.5: Use unified clear_caches() method ==================== + # This automatically handles both old and new cache systems + world_model.clear_caches() + # ====================================================================================== + + if hasattr(world_model, 'keys_values_wm_list'): + world_model.keys_values_wm_list.clear() - # Determine the clear interval based on the environment's sample type - clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + torch.cuda.empty_cache() + return + # --- END ROBUST FIX --- + # Determine the clear interval based on the environment's sample type + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length # Clear caches if the current steps are a multiple of the clear interval - if current_steps % clear_interval == 0: + if current_steps is not None and current_steps % clear_interval == 0: print(f'clear_interval: {clear_interval}') # Clear various caches in the eval model's world model world_model = self._eval_model.world_model - for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: - kv_cache_dict_env.clear() - world_model.past_kv_cache_recurrent_infer.clear() - world_model.keys_values_wm_list.clear() + # ==================== Phase 1.5: Use unified clear_caches() method ==================== + # This automatically handles both old and new cache systems + world_model.clear_caches() + # ====================================================================================== # Free up GPU memory torch.cuda.empty_cache() @@ -969,15 +1483,17 @@ def _monitor_vars_learn(self) -> List[str]: Register the variables to be monitored in learn mode. The registered variables will be logged in tensorboard according to the return value ``_forward_learn``. 
""" - return [ + base_vars = [ 'analysis/dormant_ratio_encoder', 'analysis/dormant_ratio_world_model', 'analysis/latent_state_l2_norms', + 'analysis/latent_action_l2_norms', 'analysis/l2_norm_before', 'analysis/l2_norm_after', 'analysis/grad_norm_before', 'analysis/grad_norm_after', + # ==================== Step-wise Loss Analysis ==================== 'analysis/first_step_loss_value', 'analysis/first_step_loss_policy', 'analysis/first_step_loss_rewards', @@ -992,34 +1508,98 @@ def _monitor_vars_learn(self) -> List[str]: 'analysis/last_step_loss_policy', 'analysis/last_step_loss_rewards', 'analysis/last_step_loss_obs', - + + # ==================== System Metrics ==================== 'Current_GPU', 'Max_GPU', 'collect_epsilon', 'collect_mcts_temperature', 'cur_lr_world_model', - 'cur_lr_tokenizer', + # ==================== Core Losses ==================== 'weighted_total_loss', 'obs_loss', 'policy_loss', 'orig_policy_loss', 'policy_entropy', 'latent_recon_loss', + 'perceptual_loss', 'target_policy_entropy', 'reward_loss', 'value_loss', - 'consistency_loss', 'value_priority', 'target_reward', 'target_value', + 'transformed_target_reward', + 'transformed_target_value', + + # ==================== Gradient Norms ==================== 'total_grad_norm_before_clip_wm', - # tokenizer - 'commitment_loss', - 'reconstruction_loss', - 'perceptual_loss', + + # ==================== Temperature Parameters ==================== + 'temperature_value', + 'temperature_reward', + 'temperature_policy', + + # ==================== Training Configuration ==================== + 'current_policy_label_eps', + 'adaptive_alpha', + 'adaptive_target_entropy_ratio', + 'alpha_loss', + "current_encoder_clip_value", ] + # ==================== [新增] 范数和中间张量监控变量 ==================== + norm_vars = [ + # 模块总范数 (参数范数) + 'norm/encoder/_total_norm', + 'norm/transformer/_total_norm', + 'norm/head_value/_total_norm', + 'norm/head_reward/_total_norm', + 'norm/head_policy/_total_norm', + # 模块总范数 (梯度范数) + 'grad/encoder/_total_norm', + 'grad/transformer/_total_norm', + 'grad/head_value/_total_norm', + 'grad/head_reward/_total_norm', + 'grad/head_policy/_total_norm', + + # 中间张量 x (Transformer输出) 的统计信息 + 'norm/x_token/mean', + 'norm/x_token/std', + 'norm/x_token/max', + 'norm/x_token/min', + + # Logits 的详细统计 (Value) + 'logits/value/mean', + 'logits/value/std', + 'logits/value/max', + 'logits/value/min', + 'logits/value/abs_max', + + # Logits 的详细统计 (Policy) + 'logits/policy/mean', + 'logits/policy/std', + 'logits/policy/max', + 'logits/policy/min', + 'logits/policy/abs_max', + + # Logits 的详细统计 (Reward) + 'logits/reward/mean', + 'logits/reward/std', + 'logits/reward/max', + 'logits/reward/min', + 'logits/reward/abs_max', + + # Embeddings 的统计信息 + 'embeddings/obs/norm_mean', + 'embeddings/obs/norm_std', + 'embeddings/obs/norm_max', + 'embeddings/obs/norm_min', + ] + return base_vars + norm_vars + + def _state_dict_learn(self) -> Dict[str, Any]: """ Overview: @@ -1027,11 +1607,16 @@ def _state_dict_learn(self) -> Dict[str, Any]: Returns: - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. 
""" - return { + state_dict = { 'model': self._learn_model.state_dict(), 'target_model': self._target_model.state_dict(), 'optimizer_world_model': self._optimizer_world_model.state_dict(), } + # ==================== START: 保存Alpha优化器状态 ==================== + if self.use_adaptive_entropy_weight: + state_dict['alpha_optimizer'] = self.alpha_optimizer.state_dict() + # ===================== END: 保存Alpha优化器状态 ===================== + return state_dict def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: """ @@ -1042,7 +1627,12 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: """ self._learn_model.load_state_dict(state_dict['model']) self._target_model.load_state_dict(state_dict['target_model']) - self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + # self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + + # ==================== START: 加载Alpha优化器状态 ==================== + # if self.use_adaptive_entropy_weight and 'alpha_optimizer' in state_dict: + # self.alpha_optimizer.load_state_dict(state_dict['alpha_optimizer']) + # ===================== END: 加载Alpha优化器状态 ===================== def recompute_pos_emb_diff_and_clear_cache(self) -> None: """ diff --git a/lzero/reward_model/rnd_reward_model.py b/lzero/reward_model/rnd_reward_model.py index 453e63759..ebb417cf4 100644 --- a/lzero/reward_model/rnd_reward_model.py +++ b/lzero/reward_model/rnd_reward_model.py @@ -1,6 +1,9 @@ +import logging import copy import random -from typing import Union, Tuple, List, Dict +from collections import defaultdict +from typing import Union, Tuple, List, Dict, Optional +import wandb import numpy as np import torch @@ -12,26 +15,63 @@ from ding.utils import RunningMeanStd from ding.utils import SequenceType, REWARD_MODEL_REGISTRY from easydict import EasyDict +from ding.utils import get_rank, get_world_size, build_logger, allreduce, synchronize +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +from copy import deepcopy class RNDNetwork(nn.Module): - def __init__(self, obs_shape: Union[int, SequenceType], hidden_size_list: SequenceType) -> None: + def __init__(self, obs_shape: Union[int, SequenceType], hidden_size_list: SequenceType, output_dim: int = 512,activation_type: str = "ReLU", kernel_size_list=[8,4,3], stride_size_list=[4,2,1]) -> None: super(RNDNetwork, self).__init__() + assert len(hidden_size_list) >= 1, "hidden_size_list must contain at least one element." 
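
Before the layer-by-layer construction that follows, a minimal sketch of what `RNDNetwork` ultimately computes: a frozen, randomly initialized target network and a trainable predictor, with the per-sample prediction error acting as the novelty signal that becomes the intrinsic reward. All sizes below are illustrative, not the configured ones.

```python
# Minimal RND sketch, assuming toy shapes: frozen target + trained predictor,
# intrinsic reward = per-sample prediction error.
import torch
import torch.nn as nn

obs_dim, feat_dim = 8, 16
target = nn.Sequential(nn.Linear(obs_dim, 32), nn.ReLU(), nn.Linear(32, feat_dim))
predictor = nn.Sequential(nn.Linear(obs_dim, 32), nn.ReLU(), nn.Linear(32, feat_dim))
for p in target.parameters():
    p.requires_grad = False  # the target stays fixed, exactly as in RNDNetwork above

obs = torch.randn(4, obs_dim)
with torch.no_grad():
    target_feature = target(obs)
predict_feature = predictor(obs)

# Per-sample MSE: large for unfamiliar observations, shrinks as the predictor is trained on them.
intrinsic_reward = (predict_feature - target_feature).pow(2).mean(dim=-1)
loss = intrinsic_reward.mean()   # the predictor is optimized on this quantity
```
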
+ feature_dim = hidden_size_list[-1] + if activation_type == "ReLU": + self.activation = nn.ReLU() + elif activation_type == "LeakyReLU": + self.activation = nn.LeakyReLU() + else: + raise KeyError("not support activation_type for RND model: {}, please customize your own RND model".format(activation_type)) + if isinstance(obs_shape, int) or len(obs_shape) == 1: - self.target = FCEncoder(obs_shape, hidden_size_list) - self.predictor = FCEncoder(obs_shape, hidden_size_list) + target_backbone = FCEncoder(obs_shape, hidden_size_list, activation=self.activation) + predictor_backbone = FCEncoder(obs_shape, hidden_size_list, activation=self.activation) elif len(obs_shape) == 3: - self.target = ConvEncoder(obs_shape, hidden_size_list) - self.predictor = ConvEncoder(obs_shape, hidden_size_list) + target_backbone = [] + predictor_backbone = [] + input_size = obs_shape[0] + for i in range(len(hidden_size_list)): + target_backbone.append(nn.Conv2d(input_size , hidden_size_list[i], kernel_size_list[i], stride_size_list[i])) + target_backbone.append(self.activation) + + predictor_backbone.append(nn.Conv2d(input_size , hidden_size_list[i], kernel_size_list[i], stride_size_list[i])) + predictor_backbone.append(self.activation) + input_size = hidden_size_list[i] + target_backbone.append(nn.Flatten()) + predictor_backbone.append(nn.Flatten()) + + self.target = nn.Sequential( + *target_backbone, + nn.LazyLinear(output_dim) + ) + self.predictor = nn.Sequential( + *predictor_backbone, + nn.LazyLinear(512), + nn.ReLU(), + nn.LazyLinear(512), + nn.ReLU(), + nn.LazyLinear(output_dim) + ) else: raise KeyError( "not support obs_shape for pre-defined encoder: {}, please customize your own RND model". format(obs_shape) - ) + ) for param in self.target.parameters(): param.requires_grad = False - + def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: predict_feature = self.predictor(obs) with torch.no_grad(): @@ -49,17 +89,38 @@ def __init__(self, obs_shape: Union[int, SequenceType], latent_shape: Union[int, representation_network) -> None: super(RNDNetworkRepr, self).__init__() self.representation_network = representation_network + assert len(hidden_size_list) >= 1, "hidden_size_list must contain at least one element." + feature_dim = hidden_size_list[-1] + activation = nn.ReLU() + if isinstance(obs_shape, int) or len(obs_shape) == 1: - self.target = FCEncoder(obs_shape, hidden_size_list) - self.predictor = FCEncoder(latent_shape, hidden_size_list) + target_backbone = FCEncoder(obs_shape, hidden_size_list) elif len(obs_shape) == 3: - self.target = ConvEncoder(obs_shape, hidden_size_list) - self.predictor = ConvEncoder(latent_shape, hidden_size_list) + target_backbone = ConvEncoder(obs_shape, hidden_size_list) else: raise KeyError( "not support obs_shape for pre-defined encoder: {}, please customize your own RND model". format(obs_shape) ) + + if isinstance(latent_shape, int) or (isinstance(latent_shape, SequenceType) and len(latent_shape) == 1): + predictor_backbone = FCEncoder(latent_shape, hidden_size_list) + elif isinstance(latent_shape, SequenceType) and len(latent_shape) == 3: + predictor_backbone = ConvEncoder(latent_shape, hidden_size_list) + else: + raise KeyError( + "not support latent_shape for pre-defined encoder: {}, please customize your own RND model". 
+ format(latent_shape) + ) + + self.target = nn.Sequential(target_backbone, activation) + self.predictor = nn.Sequential( + predictor_backbone, + activation, + nn.Linear(feature_dim, feature_dim), + activation, + nn.Linear(feature_dim, feature_dim), + ) for param in self.target.parameters(): param.requires_grad = False @@ -106,71 +167,125 @@ class RNDRewardModel(BaseRewardModel): ``reward_norm_max`` | normalization == ==================== ===== ============= ======================================= ======================= """ - config = dict( + rnd_config = dict( # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``. type='rnd', # (str) The intrinsic reward type, including add, new, or assign. intrinsic_reward_type='add', # (float) The step size of gradient descent. learning_rate=1e-3, - # (float) Batch size. - batch_size=64, # (list(int)) Sequence of ``hidden_size`` of reward network. # If obs.shape == 1, use MLP layers. # If obs.shape == 3, use conv layer and final dense layer. hidden_size_list=[64, 64, 128], # (int) How many updates(iterations) to train after collector's one collection. - # Bigger "update_per_collect" means bigger off-policy. - # collect data -> update policy-> collect data -> ... - update_per_collect=100, # (bool) Observation normalization: transform obs to mean 0, std 1. input_norm=True, # (int) Min clip value for observation normalization. - input_norm_clamp_min=-1, + input_norm_clamp_min=-5, # (int) Max clip value for observation normalization. - input_norm_clamp_max=1, + input_norm_clamp_max=5, # Means the relative weight of RND intrinsic_reward. # (float) The weight of intrinsic reward # r = intrinsic_reward_weight * r_i + r_e. - intrinsic_reward_weight=0.01, - # (bool) Whether to normalize extrinsic reward. - # Normalize the reward to [0, extrinsic_reward_norm_max]. - extrinsic_reward_norm=True, - # (int) The upper bound of the reward normalization. - extrinsic_reward_norm_max=1, + # (bool) Whether to normalize extrinsic reward using running statistics. + # (bool) Whether to adjust target value with intrinsic reward contribution. + adjust_value_with_intrinsic=False, + # (float) Discount factor used when adjusting target value. 
+ discount_factor=1.0, + # 新增:图片日志总开关与可视化参数 + enable_image_logging=False, # ← 总开关:是否在TB输出图片(时间线+关键帧等) + peaks_topk=12, # 关键帧个数 + + # —— 新增:自适应权重调度 —— # + use_intrinsic_weight_schedule=True, # 打开自适应权重 + intrinsic_weight_mode='cosine', # 'cosine' | 'linear' | 'constant' + intrinsic_weight_warmup=1000, # 前多少次 estimate 权重=0 + intrinsic_weight_ramp=5000, # 从0升到max所需的 estimate 数 + intrinsic_weight_min=0.0, + intrinsic_weight_max=0.02, + ) - def __init__(self, config: EasyDict, device: str = 'cpu', tb_logger: 'SummaryWriter' = None, - representation_network: nn.Module = None, target_representation_network: nn.Module = None, - use_momentum_representation_network: bool = True) -> None: # noqa + def __init__(self, config: EasyDict, device: str = 'cpu', tb_logger: 'SummaryWriter' = None, exp_name: str = "default_experiment", + instance_name: str = 'RNDModel', representation_network: nn.Module = None, + target_representation_network: nn.Module = None, use_momentum_representation_network: bool = True, + bp_update_sync: bool = True, multi_gpu: bool = False) -> None: # noqa super(RNDRewardModel, self).__init__() - self.cfg = config + self.cfg = EasyDict(deepcopy(RNDRewardModel.rnd_config)) + self.cfg.update(config) self.representation_network = representation_network self.target_representation_network = target_representation_network self.use_momentum_representation_network = use_momentum_representation_network self.input_type = self.cfg.input_type assert self.input_type in ['obs', 'latent_state', 'obs_latent_state'], self.input_type - self.device = device - assert self.device == "cpu" or self.device.startswith("cuda") - self.rnd_buffer_size = config.rnd_buffer_size self.intrinsic_reward_type = self.cfg.intrinsic_reward_type - if tb_logger is None: - from tensorboardX import SummaryWriter - tb_logger = SummaryWriter('rnd_reward_model') - self.tb_logger = tb_logger + self.adjust_value_with_intrinsic = getattr(self.cfg, 'adjust_value_with_intrinsic', False) + self.discount_factor = getattr(self.cfg, 'discount_factor', 1.0) + + self._exp_name = exp_name + self._instance_name = instance_name + self._rank = get_rank() + self._world_size = get_world_size() + self.multi_gpu = multi_gpu + self._bp_update_sync = bp_update_sync + self._device = device + self.activation_type = getattr(self.cfg, 'activation_type', 'ReLU') + self.enable_image_logging = bool(getattr(self.cfg, 'enable_image_logging', False)) + self.use_intrinsic_weight_schedule = bool(getattr(self.cfg, 'use_intrinsic_weight_schedule', False)) + + if self.multi_gpu: + self._device = 'cuda:{}'.format(self._rank % torch.cuda.device_count()) if 'cuda' in device else 'cpu' + else: + self._device = device + + if self._rank == 0: + if tb_logger is not None: + self._logger, _ = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), + name=self._instance_name, + need_tb=False + ) + self._tb_logger = tb_logger + else: + self._logger, _ = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False + ) + self._tb_logger = None + else: + self._logger, _ = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False + ) + self._tb_logger = None + + self._logger.info( + "[RND] device=%s | input_type=%s | hidden=%s", + self._device, self.input_type, str(self.cfg.hidden_size_list) + ) + if self.use_intrinsic_weight_schedule: + self._logger.info( + "[RND] intrinsic weight schedule: ENABLED | mode=%s | warmup=%d | ramp=%d | 
min=%.3f | max=%.3f", + self.cfg.intrinsic_weight_mode, self.cfg.intrinsic_weight_warmup, self.cfg.intrinsic_weight_ramp, + self.cfg.intrinsic_weight_min, self.cfg.intrinsic_weight_max + ) + else: + self._logger.info( + "[RND] intrinsic weight schedule: disabled | fixed_weight=%.3f", self.cfg.intrinsic_weight_max + ) + if self.input_type == 'obs': self.input_shape = self.cfg.obs_shape - self.reward_model = RNDNetwork(self.input_shape, self.cfg.hidden_size_list).to(self.device) + self.reward_model = RNDNetwork(self.input_shape, self.cfg.hidden_size_list, activation_type=self.activation_type).to(self._device) elif self.input_type == 'latent_state': self.input_shape = self.cfg.latent_state_dim - self.reward_model = RNDNetwork(self.input_shape, self.cfg.hidden_size_list).to(self.device) + self.reward_model = RNDNetwork(self.input_shape, self.cfg.hidden_size_list, activation_type=self.activation_type).to(self._device) elif self.input_type == 'obs_latent_state': if self.use_momentum_representation_network: self.reward_model = RNDNetworkRepr(self.cfg.obs_shape, self.cfg.latent_state_dim, self.cfg.hidden_size_list[0:-1], - self.target_representation_network).to(self.device) + self.target_representation_network, activation_type=self.activation_type).to(self._device) else: self.reward_model = RNDNetworkRepr(self.cfg.obs_shape, self.cfg.latent_state_dim, self.cfg.hidden_size_list[0:-1], - self.representation_network).to(self.device) + self.representation_network, activation_type=self.activation_type).to(self._device) assert self.intrinsic_reward_type in ['add', 'new', 'assign'] if self.input_type in ['obs', 'obs_latent_state']: @@ -184,46 +299,286 @@ def __init__(self, config: EasyDict, device: str = 'cpu', tb_logger: 'SummaryWri self._running_mean_std_rnd_reward = RunningMeanStd(epsilon=1e-4) self._running_mean_std_rnd_obs = RunningMeanStd(epsilon=1e-4) + self._running_mean_std_reward = RunningMeanStd(epsilon=1e-4) + self._obs_shape_tuple = self._resolve_obs_shape() self.estimate_cnt_rnd = 0 self.train_cnt_rnd = 0 + self._state_visit_counts = defaultdict(int) + self._initial_reward_samples: List[np.ndarray] = [] + self._initial_consistency_logged = False - def _train_with_data_one_step(self) -> None: - if self.input_type in ['obs', 'obs_latent_state']: - train_data = random.sample(self.train_obs, self.cfg.batch_size) - elif self.input_type == 'latent_state': - train_data = random.sample(self.train_latent_state, self.cfg.batch_size) + def _resolve_obs_shape(self) -> Tuple[int, ...]: + """ + Overview: + Convert the configured observation shape to a tuple for downstream processing. + """ + obs_shape = self.cfg.obs_shape + if isinstance(obs_shape, int): + return (obs_shape,) + if isinstance(obs_shape, (list, tuple)): + return tuple(obs_shape) + raise TypeError(f"Unsupported obs_shape type for RND: {type(obs_shape)}") + + def _flatten_obs_batch(self, obs_batch: np.ndarray) -> np.ndarray: + """ + Overview: + Flatten time/batch dimensions while keeping the per-observation shape intact. + """ + if not isinstance(obs_batch, np.ndarray): + obs_batch = np.asarray(obs_batch) + feature_size = int(np.prod(self._obs_shape_tuple)) + total = obs_batch.size // feature_size + target_shape = (total,) + self._obs_shape_tuple + return obs_batch.reshape(target_shape) + + def _prepare_inputs_from_obs(self, obs_array: np.ndarray) -> torch.Tensor: + """ + Overview: + Convert raw observations into tensors that can be consumed by the RND networks + according to the configured input type. 
+ """ + obs_tensor = to_tensor(obs_array).to(self._device) + if self.input_type == 'latent_state': + with torch.no_grad(): + inputs = self.representation_network(obs_tensor) + else: + inputs = obs_tensor + return inputs - train_data = torch.stack(train_data).to(self.device) + def _update_input_running_stats(self, tensor: torch.Tensor) -> None: + """ + Overview: + Update running mean/std for input normalization using the provided tensor. + """ + if not self.cfg.input_norm or tensor.numel() == 0: + return + self._running_mean_std_rnd_obs.update(tensor.detach().cpu().numpy()) + + def _normalize_intrinsic_rewards(self, tensor: torch.Tensor) -> torch.Tensor: + """ + Overview: + Normalize intrinsic rewards with the running std statistics. + """ + if getattr(self.cfg, 'intrinsic_norm', False): + std = to_tensor(self._running_mean_std_rnd_reward.std).to(self._device) + std = torch.clamp(std, min=1e-6) + normalized = tensor / std + return torch.clamp( + normalized, + min=getattr(self.cfg, 'intrinsic_norm_clamp_min', -5), + max=getattr(self.cfg, 'intrinsic_norm_clamp_max', 5) + ) + return tensor + + def _normalize_inputs(self, tensor: torch.Tensor) -> torch.Tensor: + """ + Overview: + Normalize inputs with the running mean/std statistics (intrinsic input normalization). + """ + if not self.cfg.input_norm or tensor.numel() == 0: + return tensor if self.cfg.input_norm: - # Note: observation normalization: transform obs to mean 0, std 1 - self._running_mean_std_rnd_obs.update(train_data.detach().cpu().numpy()) - normalized_train_data = (train_data - to_tensor(self._running_mean_std_rnd_obs.mean).to( - self.device)) / to_tensor( - self._running_mean_std_rnd_obs.std - ).to(self.device) - train_data = torch.clamp(normalized_train_data, min=self.cfg.input_norm_clamp_min, - max=self.cfg.input_norm_clamp_max) - - predict_feature, target_feature = self.reward_model(train_data) - loss = F.mse_loss(predict_feature, target_feature) + mean = to_tensor(self._running_mean_std_rnd_obs.mean).to(self._device) + std = to_tensor(self._running_mean_std_rnd_obs.std).to(self._device) + std = torch.clamp(std, min=1e-6) + normalized = (tensor - mean) / std + return torch.clamp(normalized, min=self.cfg.input_norm_clamp_min, max=self.cfg.input_norm_clamp_max) + return tensor + + def _normalize_rewards(self, rewards: np.ndarray) -> np.ndarray: + """ + Overview: + Normalize extrinsic rewards using running statistics when enabled. 
+ """ + if rewards.size == 0: + return rewards + normalized = np.asarray(rewards, dtype=np.float32) + if getattr(self.cfg, 'extrinsic_norm', False): + self._running_mean_std_reward.update(normalized) + mean = np.asarray(self._running_mean_std_reward.mean, dtype=np.float32) + std = np.asarray(self._running_mean_std_reward.std, dtype=np.float32) + 1e-6 + normalized = (normalized - mean) / std + normalized = np.clip( + normalized, + a_min=getattr(self.cfg, 'extrinsic_norm_clamp_min', -5), + a_max=getattr(self.cfg, 'extrinsic_norm_clamp_max', 5) + ) + elif getattr(self.cfg, 'extrinsic_sign', False): + normalized = np.sign(normalized) + return normalized + + def _hash_obs(self, obs: np.ndarray) -> int: + return hash(obs.tobytes()) + + def _update_visit_counts(self, obs_array: np.ndarray) -> None: + if obs_array.size == 0: + return + flat = obs_array.reshape(obs_array.shape[0], -1) + for obs in flat: + self._state_visit_counts[self._hash_obs(obs)] += 1 + + def _spearmanr(self, x: np.ndarray, y: np.ndarray) -> float: + if x.size < 2 or y.size < 2: + return 0.0 + x_rank = np.argsort(np.argsort(x)) + y_rank = np.argsort(np.argsort(y)) + x_rank = x_rank.astype(np.float32) + y_rank = y_rank.astype(np.float32) + x_rank -= x_rank.mean() + y_rank -= y_rank.mean() + denom = np.linalg.norm(x_rank) * np.linalg.norm(y_rank) + if denom == 0: + return 0.0 + return float(np.dot(x_rank, y_rank) / denom) + + def _log_initial_bonus_consistency(self) -> None: + if self._initial_consistency_logged or not self._initial_reward_samples: + return + rewards = np.concatenate(self._initial_reward_samples, axis=0) + if rewards.size == 0: + return + rewards = rewards - rewards.min() + rewards = rewards + 1e-8 + p = rewards / rewards.sum() + kl = float(np.sum(p * np.log(p * len(p)))) + if self._tb_logger: + self._tb_logger.add_scalar('rnd_metrics/bcs_initial_kl', kl, 0) + self._initial_consistency_logged = True + self._initial_reward_samples = [] + + def _log_final_metrics(self, intrinsic_rewards: np.ndarray, obs_array: np.ndarray, step: int) -> None: + if intrinsic_rewards.size == 0 or obs_array.size == 0: + return + intrinsic_flat = intrinsic_rewards.reshape(-1) + flat_obs = obs_array.reshape(obs_array.shape[0], -1) + hashes = [self._hash_obs(obs) for obs in flat_obs] + counts = np.array([max(self._state_visit_counts.get(h, 1), 1) for h in hashes], dtype=np.float32) + inv_counts = 1.0 / (counts + 1e-6) + bcs_final = self._spearmanr(intrinsic_flat, inv_counts) + pca_spearman = bcs_final + if self._tb_logger: + self._tb_logger.add_scalar('rnd_metrics/bcs_final_spearman', bcs_final, step) + self._tb_logger.add_scalar('rnd_metrics/pca_spearman', pca_spearman, step) + + def _discount_cumsum(self, rewards: np.ndarray, gamma: float) -> np.ndarray: + if rewards.ndim != 2: + rewards = rewards.reshape(rewards.shape[0], -1) + discounted = np.zeros_like(rewards, dtype=np.float32) + if rewards.shape[1] == 0: + return discounted + discounted[:, -1] = rewards[:, -1] + for t in range(rewards.shape[1] - 2, -1, -1): + discounted[:, t] = rewards[:, t] + gamma * discounted[:, t + 1] + return discounted + + def warmup_with_random_segments(self, data: list) -> None: + """ + Overview: + Use randomly collected segments to bootstrap the input normalization statistics + before the main training loop starts. 
+ """ + if data is None or len(data) == 0: + return + segments = [game_segment.obs_segment for game_segment in data[0] if len(game_segment.obs_segment)] + if not segments: + return + concatenated = np.concatenate(segments, axis=0) + flattened = self._flatten_obs_batch(concatenated) + total = flattened.shape[0] + self._update_visit_counts(flattened) + inputs = self._prepare_inputs_from_obs(flattened) + self._update_input_running_stats(inputs) + inputs = self._normalize_inputs(inputs.clone()) + with torch.no_grad(): + predict_feature, target_feature = self.reward_model(inputs) + mse = F.mse_loss(predict_feature, target_feature, reduction='none').mean(dim=1).detach().cpu().numpy() + if mse.size > 0: + self._initial_reward_samples.append(mse) + self._log_initial_bonus_consistency() + + + def sync_gradients(self, model: torch.nn.Module) -> None: + """ + Overview: + Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training. + Arguments: + - model (:obj:`torch.nn.Module`): The model to synchronize gradients. + + .. note:: + This method is only used in multi-gpu training, and it should be called after ``backward`` method and \ + before ``step`` method. The user can also use ``bp_update_sync`` config to control whether to synchronize \ + gradients allreduce and optimizer updates. + """ + + if self._bp_update_sync: + for name, param in model.named_parameters(): + if param.requires_grad: + if param.grad is not None: + allreduce(param.grad.data) + else: + zero_grad = torch.zeros_like(param.data) + allreduce(zero_grad) + else: + synchronize() - self.tb_logger.add_scalar('rnd_reward_model/rnd_mse_loss', loss, self.train_cnt_rnd) + + def train_with_policy_batch(self, data: list, gradient_steps: Optional[int] = None) -> None: + """ + Overview: + Update the RND predictor using the same batch of transitions sampled for the policy learner. + """ + if data is None or len(data) == 0: + return + obs_batch = data[0][0] + flat_obs = self._flatten_obs_batch(obs_batch) + prepared_inputs = self._prepare_inputs_from_obs(flat_obs) + if prepared_inputs.numel() == 0: + return + self._update_input_running_stats(prepared_inputs) + normalized_input = self._normalize_inputs(prepared_inputs) + predict_feature, target_feature = self.reward_model(normalized_input) + loss = F.mse_loss(predict_feature, target_feature) + if self._tb_logger: + self._tb_logger.add_scalar('rnd_reward_model/rnd_mse_loss', loss, self.train_cnt_rnd) self._optimizer_rnd.zero_grad() loss.backward() + if self.multi_gpu: + self.sync_gradients(self.reward_model.predictor) self._optimizer_rnd.step() + self.train_cnt_rnd += 1 - def train_with_data(self) -> None: - for _ in range(self.cfg.update_per_collect): - # for name, param in self.reward_model.named_parameters(): - # if param.grad is not None: - # print(f"{name}: {torch.isnan(param.grad).any()}, {torch.isinf(param.grad).any()}") - # print(f"{name}: grad min: {param.grad.min()}, grad max: {param.grad.max()}") - # # enable the following line to check whether there is nan or inf in the gradient. 
- # torch.autograd.set_detect_anomaly(True) - self._train_with_data_one_step() - self.train_cnt_rnd += 1 + def _intrinsic_weight(self, step: int) -> float: + """ + 根据当前 estimate 步数返回 RND 权重: + - step < warmup → 0 + - warmup 之后 → 线性/余弦从 min 升到 max + """ + if not self.cfg.use_intrinsic_weight_schedule: + return float(self.cfg.intrinsic_weight_max) + + wmin = float(self.cfg.intrinsic_weight_min) + wmax = float(self.cfg.intrinsic_weight_max) + warmup = int(self.cfg.intrinsic_weight_warmup) + ramp = max(1, int(self.cfg.intrinsic_weight_ramp)) + mode = str(self.cfg.intrinsic_weight_mode).lower() + + if step <= warmup: + return 0.0 + + # 归一化进度 p ∈ [0,1] + t = min(max(step - warmup, 0), ramp) + p = t / float(ramp) + + if mode == 'linear': + w = wmin + (wmax - wmin) * p + elif mode == 'cosine': + w = wmin + 0.5 * (wmax - wmin) * (1.0 - np.cos(np.pi * p)) + else: + w = float(self.cfg.intrinsic_weight_max) + return float(w) + def estimate(self, data: list) -> List[Dict]: """ Rewrite the reward key in each row of the data. @@ -234,96 +589,93 @@ def estimate(self, data: list) -> List[Dict]: obs_batch_orig = data[0][0] target_reward = data[1][0] batch_size = obs_batch_orig.shape[0] - # reshape to (4, 2835, 6) - obs_batch_tmp = np.reshape(obs_batch_orig, (batch_size, self.cfg.obs_shape, 6)) - # reshape to (24, 2835) - obs_batch_tmp = np.reshape(obs_batch_tmp, (batch_size * 6, self.cfg.obs_shape)) + T = target_reward.shape[1] + logging.info("[RND] estimate enter, batch=%s, horizon=%s", batch_size, T) - if self.input_type == 'latent_state': - with torch.no_grad(): - latent_state = self.representation_network(torch.from_numpy(obs_batch_tmp).to(self.device)) - input_data = latent_state - elif self.input_type in ['obs', 'obs_latent_state']: - input_data = to_tensor(obs_batch_tmp).to(self.device) + obs_batch_tmp = self._flatten_obs_batch(obs_batch_orig) + self._update_visit_counts(obs_batch_tmp) + input_data = self._prepare_inputs_from_obs(obs_batch_tmp) # NOTE: deepcopy reward part of data is very important, - # otherwise the reward of data in the replay buffer will be incorrectly modified. - target_reward_augmented = copy.deepcopy(target_reward) - target_reward_augmented = np.reshape(target_reward_augmented, (batch_size * 6, 1)) - - if self.cfg.input_norm: - # add this line to avoid inplace operation on the original tensor. 
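# --- Illustrative sketch (not part of the patch) -------------------------------
# Standalone restatement of the cosine branch of `_intrinsic_weight` above,
# evaluated with the Atari RND config further below in this diff
# (warmup=10000, ramp=20000, min=0.0, max=0.05): zero during warmup, then a
# cosine ramp from min to max over `ramp` estimate calls.
import numpy as np

def intrinsic_weight(step, warmup=10000, ramp=20000, wmin=0.0, wmax=0.05):
    if step <= warmup:
        return 0.0
    p = min(max(step - warmup, 0), ramp) / float(ramp)
    return wmin + 0.5 * (wmax - wmin) * (1.0 - np.cos(np.pi * p))

for s in (5000, 10000, 20000, 30000, 50000):
    print(s, round(intrinsic_weight(s), 4))
# -> 5000: 0.0 | 10000: 0.0 | 20000: 0.025 | 30000: 0.05 | 50000: 0.05
# -------------------------------------------------------------------------------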
- input_data = input_data.clone() - # Note: observation normalization: transform obs to mean 0, std 1 - input_data = (input_data - to_tensor(self._running_mean_std_rnd_obs.mean - ).to(self.device)) / to_tensor(self._running_mean_std_rnd_obs.std).to( - self.device) - input_data = torch.clamp(input_data, min=self.cfg.input_norm_clamp_min, max=self.cfg.input_norm_clamp_max) - else: - input_data = input_data + original_reward = np.reshape(np.array(target_reward, dtype=np.float32), (batch_size * T, 1)) + input_data = self._normalize_inputs(input_data.clone()) + extrinsic_normalized = self._normalize_rewards(original_reward) with torch.no_grad(): predict_feature, target_feature = self.reward_model(input_data) mse = F.mse_loss(predict_feature, target_feature, reduction='none').mean(dim=1) - self._running_mean_std_rnd_reward.update(mse.detach().cpu().numpy()) - - # Note: according to the min-max normalization, transform rnd reward to [0,1] - rnd_reward = (mse - mse.min()) / (mse.max() - mse.min() + 1e-6) - - # save the rnd_reward statistics into tb_logger - self.estimate_cnt_rnd += 1 - self.tb_logger.add_scalar('rnd_reward_model/rnd_reward_max', rnd_reward.max(), self.estimate_cnt_rnd) - self.tb_logger.add_scalar('rnd_reward_model/rnd_reward_mean', rnd_reward.mean(), self.estimate_cnt_rnd) - self.tb_logger.add_scalar('rnd_reward_model/rnd_reward_min', rnd_reward.min(), self.estimate_cnt_rnd) - self.tb_logger.add_scalar('rnd_reward_model/rnd_reward_std', rnd_reward.std(), self.estimate_cnt_rnd) - - rnd_reward = rnd_reward.to(self.device).unsqueeze(1).cpu().numpy() + mse_np = mse.detach().cpu().numpy() + self._running_mean_std_rnd_reward.update(mse_np) + mse_tensor = torch.from_numpy(mse_np).to(self._device) + rnd_reward_tensor = self._normalize_intrinsic_rewards(mse_tensor) + + rnd_reward_np = rnd_reward_tensor.detach().cpu().numpy() + self._log_final_metrics(rnd_reward_np, obs_batch_tmp, self.estimate_cnt_rnd) + self.estimate_cnt_rnd += 1 + if self._tb_logger: + self._tb_logger.add_scalar('rnd_reward_model/rnd_reward_max', rnd_reward_np.max(), self.estimate_cnt_rnd) + self._tb_logger.add_scalar('rnd_reward_model/rnd_reward_mean', rnd_reward_np.mean(), self.estimate_cnt_rnd) + self._tb_logger.add_scalar('rnd_reward_model/rnd_reward_min', rnd_reward_np.min(), self.estimate_cnt_rnd) + self._tb_logger.add_scalar('rnd_reward_model/rnd_reward_std', rnd_reward_np.std(), self.estimate_cnt_rnd) + self._tb_logger.add_scalar('rnd_reward_model/extrinsic_reward_max', extrinsic_normalized.max(), self.estimate_cnt_rnd) + self._tb_logger.add_scalar('rnd_reward_model/extrinsic_reward_mean', extrinsic_normalized.mean(), self.estimate_cnt_rnd) + self._tb_logger.add_scalar('rnd_reward_model/extrinsic_reward_min', extrinsic_normalized.min(), self.estimate_cnt_rnd) + self._tb_logger.add_scalar('rnd_reward_model/extrinsic_reward_std', extrinsic_normalized.std(), self.estimate_cnt_rnd) + + cur_w = self._intrinsic_weight(self.estimate_cnt_rnd) + if self._tb_logger is not None: + self._tb_logger.add_scalar('rnd_reward_model/intrinsic_weight', cur_w, self.estimate_cnt_rnd) + + + rnd_reward_flat = rnd_reward_np.reshape(batch_size * T, 1) if self.intrinsic_reward_type == 'add': - if self.cfg.extrinsic_reward_norm: - target_reward_augmented = target_reward_augmented / self.cfg.extrinsic_reward_norm_max + rnd_reward * self.cfg.intrinsic_reward_weight - else: - target_reward_augmented = target_reward_augmented + rnd_reward * self.cfg.intrinsic_reward_weight + target_reward_augmented = extrinsic_normalized + rnd_reward_flat * 
cur_w elif self.intrinsic_reward_type == 'new': - if self.cfg.extrinsic_reward_norm: - target_reward_augmented = target_reward_augmented / self.cfg.extrinsic_reward_norm_max + target_reward_augmented = rnd_reward_flat * cur_w elif self.intrinsic_reward_type == 'assign': - target_reward_augmented = rnd_reward - - self.tb_logger.add_scalar('augmented_reward/reward_max', np.max(target_reward_augmented), self.estimate_cnt_rnd) - self.tb_logger.add_scalar('augmented_reward/reward_mean', np.mean(target_reward_augmented), - self.estimate_cnt_rnd) - self.tb_logger.add_scalar('augmented_reward/reward_min', np.min(target_reward_augmented), self.estimate_cnt_rnd) - self.tb_logger.add_scalar('augmented_reward/reward_std', np.std(target_reward_augmented), self.estimate_cnt_rnd) - - # reshape to (target_reward_augmented.shape[0], 6, 1) - target_reward_augmented = np.reshape(target_reward_augmented, (batch_size, 6, 1)) + target_reward_augmented = rnd_reward_flat + else: + target_reward_augmented = extrinsic_normalized + + if self._tb_logger is not None: + self._tb_logger.add_scalar('rnd_reward_model/augmented_reward_max', np.max(target_reward_augmented), self.estimate_cnt_rnd) + self._tb_logger.add_scalar('rnd_reward_model/augmented_reward_mean', np.mean(target_reward_augmented),self.estimate_cnt_rnd) + self._tb_logger.add_scalar('rnd_reward_model/augmented_reward_min', np.min(target_reward_augmented), self.estimate_cnt_rnd) + self._tb_logger.add_scalar('rnd_reward_model/augmented_reward_std', np.std(target_reward_augmented), self.estimate_cnt_rnd) + + if self.adjust_value_with_intrinsic: + target_values = np.asarray(data[1][1], dtype=np.float32).reshape(batch_size, -1) + augmented_rewards_seq = target_reward_augmented.reshape(batch_size, T) + value_mask = np.asarray(data[0][3], dtype=np.float32) + value_mask = value_mask[:, :target_values.shape[1]] + if self.intrinsic_reward_type == 'add': + original_rewards_seq = extrinsic_normalized.reshape(batch_size, T) + delta_seq = augmented_rewards_seq - original_rewards_seq + delta_returns = self._discount_cumsum(delta_seq, self.discount_factor) + delta_returns = (delta_returns[:, :target_values.shape[1]] * value_mask).astype(np.float32) + target_values_augmented = target_values + delta_returns + else: + discounted_returns = self._discount_cumsum(augmented_rewards_seq, self.discount_factor) + target_values_augmented = (discounted_returns[:, :target_values.shape[1]] * value_mask).astype(np.float32) + data[1][1] = target_values_augmented + + if self._tb_logger is not None and self.adjust_value_with_intrinsic: + self._tb_logger.add_scalar('rnd_reward_model/extrinsic_value_max', target_values.max(), self.estimate_cnt_rnd) + self._tb_logger.add_scalar('rnd_reward_model/extrinsic_value_mean', target_values.mean(), self.estimate_cnt_rnd) + self._tb_logger.add_scalar('rnd_reward_model/extrinsic_value_min', target_values.min(), self.estimate_cnt_rnd) + self._tb_logger.add_scalar('rnd_reward_model/extrinsic_value_std', target_values.std(), self.estimate_cnt_rnd) + self._tb_logger.add_scalar('rnd_reward_model/augmented_value_max', np.max(target_values_augmented), self.estimate_cnt_rnd) + self._tb_logger.add_scalar('rnd_reward_model/augmented_value_mean', np.mean(target_values_augmented),self.estimate_cnt_rnd) + self._tb_logger.add_scalar('rnd_reward_model/augmented_value_min', np.min(target_values_augmented), self.estimate_cnt_rnd) + self._tb_logger.add_scalar('rnd_reward_model/augmented_value_std', np.std(target_values_augmented), self.estimate_cnt_rnd) + + # reshape to 
(target_reward_augmented.shape[0], T, 1) + target_reward_augmented = np.reshape(target_reward_augmented, (batch_size, T, 1)) + # TODO why? batchsizw * T -> batchsize * T * 1 data[1][0] = target_reward_augmented train_data_augmented = data return train_data_augmented - def collect_data(self, data: list) -> None: - # TODO(pu): now we only collect the first 300 steps of each game segment. - collected_transitions = np.concatenate([game_segment.obs_segment[:300] for game_segment in data[0]], axis=0) - if self.input_type == 'latent_state': - with torch.no_grad(): - self.train_latent_state.extend( - self.representation_network(torch.from_numpy(collected_transitions).to(self.device))) - elif self.input_type == 'obs': - self.train_obs.extend(to_tensor(collected_transitions).to(self.device)) - elif self.input_type == 'obs_latent_state': - self.train_obs.extend(to_tensor(collected_transitions).to(self.device)) - - def clear_old_data(self) -> None: - if self.input_type == 'latent_state': - if len(self.train_latent_state) >= self.cfg.rnd_buffer_size: - self.train_latent_state = self.train_latent_state[-self.cfg.rnd_buffer_size:] - elif self.input_type == 'obs': - if len(self.train_obs) >= self.cfg.rnd_buffer_size: - self.train_obs = self.train_obs[-self.cfg.rnd_buffer_size:] - elif self.input_type == 'obs_latent_state': - if len(self.train_obs) >= self.cfg.rnd_buffer_size: - self.train_obs = self.train_obs[-self.cfg.rnd_buffer_size:] - def state_dict(self) -> Dict: return self.reward_model.state_dict() @@ -335,3 +687,140 @@ def clear_data(self): def train(self): pass + + def collect_data(self, data) -> None: + if data is None or len(data) == 0: + return + segments = [game_segment.obs_segment for game_segment in data[0] if len(game_segment.obs_segment)] + if not segments: + return + concatenated = np.concatenate(segments, axis=0) + flattened = self._flatten_obs_batch(concatenated) + self._update_visit_counts(flattened) + + # ---------------------- 可视化辅助(新增) ---------------------- # + def _select_peaks(self, y: np.ndarray, k: int) -> List[int]: + order = np.argsort(-y) + picked: List[int] = [] + for i in order: + if len(picked) >= k: + break + picked.append(int(i)) + picked.sort() + return picked + + def _obs_to_rgb(self, obs_any: np.ndarray) -> np.ndarray: + x = np.asarray(obs_any) + x = np.squeeze(x) + if x.ndim == 3: + if x.shape[0] in (1, 3): # CHW + if x.shape[0] == 3: + img = np.transpose(x, (1, 2, 0)) + elif x.shape[0] == 1: + img = x[0] + elif x.shape[-1] in (1, 3): # HWC + if x.shape[-1] == 3: + img = x + elif x.shape[-1] == 1: + img = x[..., 0] + else: + pass + img = (img * 255.0).clip(0, 255).astype(np.uint8) + return img + + def UpdateFuncAnimation(self, all_obs_per_episode: List[np.ndarray]) -> None: + """ + Overview: + 给定一条 episode 的完整 obs 序列(list,每个元素为 (C,H,W) 或 (H,W,C) 的 numpy 数组), + 使用当前 RND 模型重新计算这一条轨迹上每一步的 intrinsic reward, + 并画出: + - 上方:若干关键帧(intrinsic reward 较大的若干步的观测) + - 下方:整条时间线上的 intrinsic reward 曲线 + + 图像会被写入 TensorBoard(若 _tb_logger 不为 None): + tag = "rnd_visual/episode_intrinsic_timeline" + step = self.estimate_cnt_rnd + """ + if not all_obs_per_episode: + return + + if not getattr(self, 'enable_image_logging', False) or self._tb_logger is None: + return + + # 1) 堆叠成 (T, ...) 
方便后续处理 + obs_array = np.stack(all_obs_per_episode, axis=0) # (T, C, H, W) 或 (T, H, W, C) + # 2) 展平成 (T, *obs_shape),和 _flatten_obs_batch 保持一致 + flat_obs = self._flatten_obs_batch(obs_array) # (T, *obs_shape) + if flat_obs.size == 0: + return + # 3) 准备输入 + 归一化(与 estimate 中逻辑一致) + inputs = self._prepare_inputs_from_obs(flat_obs) + + # 更新输入 running mean/std,再做标准化 + norm_inputs = self._normalize_inputs(inputs.clone()) + + # 4) 通过 RND 模型得到每一步 intrinsic reward(MSE) + with torch.no_grad(): + predict_feature, target_feature = self.reward_model(norm_inputs) + mse = F.mse_loss(predict_feature, target_feature, reduction='none').mean(dim=1) # (T,) + + mse_np = mse.detach().cpu().numpy() + mse_tensor = torch.from_numpy(mse_np).to(self._device) + rnd_reward_tensor = self._normalize_intrinsic_rewards(mse_tensor) + rnd_rewards = rnd_reward_tensor.detach().cpu().numpy().reshape(-1) # (T,) + + T = rnd_rewards.shape[0] + steps = np.arange(T, dtype=np.int32) + # 5) 选出若干“峰值位置”,对应关键帧 + k_cfg = int(getattr(self.cfg, 'peaks_topk', 10)) + k = max(1, min(k_cfg, T)) + peak_indices = self._select_peaks(rnd_rewards, k=k) + # 6) 把对应 obs 转成 RGB / Gray 图像 + frames: List[np.ndarray] = [] + for idx in peak_indices: + frames.append(self._obs_to_rgb(flat_obs[idx])) + + # 7) 画图:上方一行 key frames,下方一行 reward 曲线 + ncols = k + fig = plt.figure(figsize=(ncols * 1.5, 3.8 + 1.5), dpi=120) + gs = fig.add_gridspec(2, ncols, height_ratios=[1, 2]) + + # 上排:关键帧 + for i in range(k): + ax = fig.add_subplot(gs[0, i]) + img = frames[i] + if img.ndim == 2: + ax.imshow(img, cmap='gray') + else: + ax.imshow(img) + ax.set_axis_off() + ax.set_title(f"t={peak_indices[i]}", fontsize=8, pad=2) + + # 下排:时间线曲线 + ax_line = fig.add_subplot(gs[1, :]) + ax_line.plot(steps, rnd_rewards, linewidth=1.0) + ax_line.set_xlabel("Episode step") + ax_line.set_ylabel("Intrinsic reward") + + for i, idx in enumerate(peak_indices): + ax_line.scatter([steps[idx]], [rnd_rewards[idx]], s=14) + ax_line.annotate( + str(i + 1), + (steps[idx], rnd_rewards[idx]), + textcoords="offset points", + xytext=(0, 8), + ha="center", + fontsize=8, + ) + + fig.tight_layout() + + # 8) 写进 TensorBoard + global_step = int(self.estimate_cnt_rnd) + self._tb_logger.add_figure("rnd_visual/episode_intrinsic_timeline", fig, global_step) + plt.close(fig) + + logging.info( + "[RND] UpdateFuncAnimation: logged episode intrinsic timeline | T=%d | peaks=%d | step=%d", + T, k, global_step, + ) diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 4d3b1b740..07c7c0157 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -219,11 +219,8 @@ def _compute_priorities(self, i: int, pred_values_lst: List[float], search_value # A small constant (1e-6) is added to the results to avoid zero priorities. This # is done because zero priorities could potentially cause issues in some scenarios. 
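# --- Illustrative sketch (not part of the patch) -------------------------------
# A hedged, slightly more defensive variant of the CHW/HWC-to-uint8 conversion
# that `_obs_to_rgb` above performs for the key-frame panels. It assumes float
# observations scaled to [0, 1]; anything else should be rescaled by the caller.
# This is an illustrative sketch, not the repo's helper.
import numpy as np

def obs_to_uint8_image(obs: np.ndarray) -> np.ndarray:
    x = np.squeeze(np.asarray(obs))
    if x.ndim == 3 and x.shape[0] in (1, 3):      # CHW -> HWC (or HW for 1 channel)
        x = x[0] if x.shape[0] == 1 else np.transpose(x, (1, 2, 0))
    elif x.ndim == 3 and x.shape[-1] == 1:        # HW1 -> HW
        x = x[..., 0]
    elif x.ndim not in (2, 3):
        raise ValueError(f"unsupported observation shape: {x.shape}")
    return (x * 255.0).clip(0, 255).astype(np.uint8)

print(obs_to_uint8_image(np.random.rand(3, 96, 96)).shape)  # (96, 96, 3)
# -------------------------------------------------------------------------------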
pred_values = torch.from_numpy(np.array(pred_values_lst[i])).to(self.policy_config.device).float().view(-1) - search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device - ).float().view(-1) - priorities = L1Loss(reduction='none' - )(pred_values, - search_values).detach().cpu().numpy() + 1e-6 + search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device).float().view(-1) + priorities = L1Loss(reduction='none')(pred_values, search_values).detach().cpu().numpy() + 1e-6 else: # priorities is None -> use the max priority for all newly collected data priorities = None diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index 2a70feea5..5ccbde5ab 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -197,6 +197,7 @@ def eval( envstep: int = -1, n_episode: Optional[int] = None, return_trajectory: bool = False, + reward_model = None, ) -> Tuple[bool, float]: """ Overview: @@ -227,7 +228,9 @@ def eval( # initializations init_obs = self._env.ready_obs - + if self.policy_config.use_rnd_model and self.policy_config.model.world_model_cfg.obs_type == 'image': + obs_for_rnd = [[] for i in range(env_nums)] + retry_waiting_time = 0.001 while len(init_obs.keys()) != self._env_num: # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to @@ -306,6 +309,10 @@ def eval( stack_obs = to_ndarray(stack_obs) stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + if self.policy_config.use_rnd_model and self.policy_config.model.world_model_cfg.obs_type == 'image': + for idx, env_id in enumerate(ready_env_id): + obs_for_rnd[env_id].append(stack_obs[idx]) + stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() # ============================================================== @@ -386,7 +393,6 @@ def eval( timestep_dict[env_id] = to_ndarray(obs.get('timestep', -1)) if self.policy_config.use_ture_chance_label_in_chance_encoder: chance_dict[env_id] = to_ndarray(obs['chance']) - dones[env_id] = done if episode_timestep.done: # Env reset is done by env_manager automatically. 
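# --- Illustrative sketch (not part of the patch) -------------------------------
# Tiny standalone check of the priority rule in the collector hunk above:
# priority = |predicted value - search value| + 1e-6, where the epsilon keeps
# priorities strictly positive for prioritized replay sampling.
import torch
from torch.nn import L1Loss

pred_values = torch.tensor([0.5, 1.0, -0.2])
search_values = torch.tensor([0.7, 1.0, 0.3])
priorities = L1Loss(reduction='none')(pred_values, search_values).detach().cpu().numpy() + 1e-6
print(priorities)  # ~[0.2, 1e-6, 0.5]
# -------------------------------------------------------------------------------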
@@ -460,6 +466,13 @@ def eval( duration = self._timer.value episode_return = eval_monitor.get_episode_return() + if self.policy_config.use_rnd_model and self.policy_config.model.world_model_cfg.obs_type == 'image': + max_return = max(episode_return) + max_idx = episode_return.index(max_return) + obs_for_rnd_res = obs_for_rnd[max_idx] + if max_return > 0: + reward_model.UpdateFuncAnimation(obs_for_rnd_res) + info = { 'train_iter': train_iter, 'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter), diff --git a/lzero/worker/muzero_segment_collector.py b/lzero/worker/muzero_segment_collector.py index 46cc016bc..3a25a04a8 100644 --- a/lzero/worker/muzero_segment_collector.py +++ b/lzero/worker/muzero_segment_collector.py @@ -1,7 +1,7 @@ import logging import time from collections import deque, namedtuple -from typing import Optional, Any, List +from typing import Optional, Any, List, Dict import numpy as np import torch diff --git a/zoo/atari/config/atari_env_action_space_map.py b/zoo/atari/config/atari_env_action_space_map.py index e2090586d..db51e063c 100644 --- a/zoo/atari/config/atari_env_action_space_map.py +++ b/zoo/atari/config/atari_env_action_space_map.py @@ -27,4 +27,6 @@ 'SeaquestNoFrameskip-v4': 18, 'BoxingNoFrameskip-v4': 18, 'BreakoutNoFrameskip-v4': 4, + 'MontezumaRevengeNoFrameskip-v4': 18, + 'VentureNoFrameskip-v4': 18, }) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_segment_config.py b/zoo/atari/config/atari_unizero_segment_config.py index dec2ee4d2..e8659da52 100644 --- a/zoo/atari/config/atari_unizero_segment_config.py +++ b/zoo/atari/config/atari_unizero_segment_config.py @@ -21,11 +21,12 @@ def main(env_id, seed): infer_context_length = 4 # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. - buffer_reanalyze_freq = 1/50 + buffer_reanalyze_freq = 1/100000 # Each reanalyze process will reanalyze sequences ( transitions per sequence) reanalyze_batch_size = 160 # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
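# --- Illustrative sketch (not part of the patch) -------------------------------
# Hedged sketch of the best-episode selection used by the evaluator hunk above
# before calling `reward_model.UpdateFuncAnimation`. The extra None-check is an
# assumption for robustness (eval() now defaults to reward_model=None), not the
# repo's exact code.
def maybe_log_rnd_timeline(reward_model, episode_return, obs_for_rnd) -> None:
    if reward_model is None or not episode_return:
        return
    max_return = max(episode_return)
    max_idx = episode_return.index(max_return)
    if max_return > 0:
        reward_model.UpdateFuncAnimation(obs_for_rnd[max_idx])
# -------------------------------------------------------------------------------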
reanalyze_partition = 0.75 + norm_type ="LN" # ====== only for debug ===== # collector_env_num = 2 @@ -41,7 +42,7 @@ def main(env_id, seed): env=dict( stop_value=int(1e6), env_id=env_id, - observation_shape=(3, 64, 64), + observation_shape=(3, 96, 96), gray_scale=False, collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, @@ -54,11 +55,18 @@ def main(env_id, seed): policy=dict( learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000 model=dict( - observation_shape=(3, 64, 64), + observation_shape=(3, 96, 96), action_space_size=action_space_size, reward_support_range=(-300., 301., 1.), value_support_range=(-300., 301., 1.), + norm_type=norm_type, world_model_cfg=dict( + use_new_cache_manager=False, + norm_type=norm_type, + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + # predict_latent_loss_type='mse', + predict_latent_loss_type='cos_sim', # only for latent state layer_norm support_size=601, policy_entropy_weight=5e-3, continuous_action_space=False, @@ -73,21 +81,58 @@ def main(env_id, seed): obs_type='image', env_num=max(collector_env_num, evaluator_env_num), num_simulations=num_simulations, + game_segment_length=game_segment_length, + use_priority=True, rotary_emb=False, + optim_type='AdamW_mix_lr_wdecay', ), ), + optim_type='AdamW_mix_lr_wdecay', + weight_decay=1e-2, # TODO: encoder 5*wd, transformer wd, head 0 + learning_rate=0.0001, # (str) The path of the pretrained model. If None, the model will be initialized by the default model. model_path=None, - use_augmentation=False, + + # (bool) 是否启用自适应策略熵权重 (alpha) + use_adaptive_entropy_weight=True, + # (float) 自适应alpha优化器的学习率 + adaptive_entropy_alpha_lr=1e-3, + target_entropy_start_ratio =0.98, + target_entropy_end_ratio =0.7, + target_entropy_decay_steps = 100000, # 例如,在300k次迭代后达到最终值 + # ==================== START: Encoder-Clip Annealing Config ==================== + # (bool) 是否启用 encoder-clip 值的退火。 + use_encoder_clip_annealing=True, + # (str) 退火类型。可选 'linear' 或 'cosine'。 + encoder_clip_anneal_type='cosine', + # (float) 退火的起始 clip 值 (训练初期,较宽松)。 + encoder_clip_start_value=30.0, + # (float) 退火的结束 clip 值 (训练后期,较严格)。 + encoder_clip_end_value=10.0, + # (int) 完成从起始值到结束值的退火所需的训练迭代步数。 + encoder_clip_anneal_steps=100000, # 例如,在300k次迭代后达到最终值 + + # ==================== START: label smooth ==================== + policy_ls_eps_start=0.05, #good start in Pong and MsPacman + policy_ls_eps_end=0.01, + policy_ls_eps_decay_steps=50000, # 50k + label_smoothing_eps=0.1, #for value + + # ==================== [新增] 范数监控频率 ==================== + # 每隔多少个训练迭代步数,监控一次模型参数的范数。设置为0则禁用。 + monitor_norm_freq=5000, + + # use_augmentation=False, + use_augmentation=True, manual_temperature_decay=False, threshold_training_steps_for_final_temperature=int(2.5e4), - use_priority=False, + use_priority=True, + priority_prob_alpha=1, + priority_prob_beta=1, num_unroll_steps=num_unroll_steps, update_per_collect=None, replay_ratio=replay_ratio, batch_size=batch_size, - optim_type='AdamW', - learning_rate=0.0001, num_simulations=num_simulations, num_segments=num_segments, td_steps=5, @@ -126,7 +171,7 @@ def main(env_id, seed): # ============ use muzero_segment_collector instead of muzero_collector ============= from lzero.entry import train_unizero_segment - main_config.exp_name = 
f'data_lz/data_unizero/{env_id[:-14]}/{env_id[:-14]}_uz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' + main_config.exp_name = f'./data_lz/data_unizero_atari/{env_id[:-14]}/{env_id[:-14]}_uz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' train_unizero_segment([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step) diff --git a/zoo/atari/config/atari_unizero_segment_rnd_config.py b/zoo/atari/config/atari_unizero_segment_rnd_config.py new file mode 100644 index 000000000..fc57ad393 --- /dev/null +++ b/zoo/atari/config/atari_unizero_segment_rnd_config.py @@ -0,0 +1,254 @@ +from easydict import EasyDict +from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map + + +def main(env_id, seed): + action_space_size = atari_env_action_space_map[env_id] + + # ============================================================== + # begin of the most frequently changed config specified by the user + # ============================================================== + collector_env_num = 8 + num_segments = 8 + game_segment_length = 20 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(5e5) + batch_size = 64 + num_layers = 2 + replay_ratio = 0.25 + num_unroll_steps = 10 + infer_context_length = 4 + collect_num_simulations = 50 + eval_num_simulations = 50 + num_channels=128 + num_res_blocks=2 + + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq = 1/100000 + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size = 160 + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
+ reanalyze_partition = 0.75 + norm_type ="LN" + + # ====== only for debug ===== + # collector_env_num = 2 + # num_segments = 2 + # evaluator_env_num = 2 + # num_simulations = 10 + # batch_size = 5 + # ============================================================== + # end of the most frequently changed config specified by the user + # ============================================================== + + atari_unizero_config = dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 96, 96), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + # TODO: only for debug + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + ), + reward_model=dict( + type='rnd_unizero', + intrinsic_reward_type='add', + input_type='obs', # options: ['obs', 'latent_state', 'obs_latent_state'] + activation_type='LeakyReLU', + enable_image_logging=True, + + # —— 新增:自适应权重调度 —— # + use_intrinsic_weight_schedule=True, # 打开自适应权重 + intrinsic_weight_mode='cosine', # 'cosine' | 'linear' | 'constant' + intrinsic_weight_warmup=10000, # 前多少次 estimate 权重=0 + intrinsic_weight_ramp=20000, # 从min升到max所需的 estimate 数 + intrinsic_weight_min=0.0, + intrinsic_weight_max=0.05, + + obs_shape=(3, 96, 96), + latent_state_dim=256, + hidden_size_list=[32, 64, 64], + output_dim=512, + learning_rate=3e-4, + weight_decay=1e-4, + input_norm=True, + input_norm_clamp_max=10, + input_norm_clamp_min=-10, + + intrinsic_norm=True, + intrinsic_norm_clamp_min=-5, + intrinsic_norm_clamp_max=5, + + extrinsic_sign=False, + extrinsic_norm=False, + extrinsic_norm_clamp_min=-5, + extrinsic_norm_clamp_max=5, + adjust_value_with_intrinsic=False, + discount_factor=0.997, + + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000 + model=dict( + num_channels=num_channels, + num_res_blocks=num_res_blocks, + observation_shape=(3, 96, 96), + action_space_size=action_space_size, + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), + norm_type=norm_type, + world_model_cfg=dict( + use_new_cache_manager=False, + norm_type=norm_type, + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', + # predict_latent_loss_type='cos_sim', # only for latent state layer_norm + support_size=601, + policy_entropy_weight=5e-3, + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=8, + embed_dim=768, + obs_type='image', + env_num=max(collector_env_num, evaluator_env_num), + num_simulations=num_simulations, + game_segment_length=game_segment_length, + use_priority=False, + rotary_emb=False, + optim_type='AdamW_mix_lr_wdecay', + ), + ), + collect_num_simulations=collect_num_simulations, + eval_num_simulations=eval_num_simulations, + optim_type='AdamW_mix_lr_wdecay', + weight_decay=1e-2, # TODO: encoder 5*wd, transformer wd, head 0 + learning_rate=0.0001, + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. 
+ model_path=None, + + # (bool) 是否启用自适应策略熵权重 (alpha) + use_adaptive_entropy_weight=False, + # (float) 自适应alpha优化器的学习率 + adaptive_entropy_alpha_lr=1e-3, + target_entropy_start_ratio =0.98, + target_entropy_end_ratio =0.7, + target_entropy_decay_steps = 100000, # 例如,在300k次迭代后达到最终值 + # ==================== START: Encoder-Clip Annealing Config ==================== + # (bool) 是否启用 encoder-clip 值的退火。 + use_encoder_clip_annealing=False, + # (str) 退火类型。可选 'linear' 或 'cosine'。 + encoder_clip_anneal_type='cosine', + # (float) 退火的起始 clip 值 (训练初期,较宽松)。 + encoder_clip_start_value=30.0, + # (float) 退火的结束 clip 值 (训练后期,较严格)。 + encoder_clip_end_value=10.0, + # (int) 完成从起始值到结束值的退火所需的训练迭代步数。 + encoder_clip_anneal_steps=100000, # 例如,在300k次迭代后达到最终值 + + # ==================== START: label smooth ==================== + policy_ls_eps_start=0.05, #good start in Pong and MsPacman + policy_ls_eps_end=0.01, + policy_ls_eps_decay_steps=50000, # 50k + label_smoothing_eps=0.0, #for value + + # ==================== [新增] 范数监控频率 ==================== + # 每隔多少个训练迭代步数,监控一次模型参数的范数。设置为0则禁用。 + monitor_norm_freq=5000, + + use_augmentation=False, + # use_augmentation=True, + manual_temperature_decay=False, + threshold_training_steps_for_final_temperature=int(2.5e4), + use_priority=False, + priority_prob_alpha=1, + priority_prob_beta=1, + num_unroll_steps=num_unroll_steps, + update_per_collect=None, + replay_ratio=replay_ratio, + batch_size=batch_size, + num_simulations=num_simulations, + num_segments=num_segments, + td_steps=5, + train_start_after_envsteps=0, + game_segment_length=game_segment_length, + grad_clip_value=5, + replay_buffer_size=int(1e6), + eval_freq=int(5e3), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq=buffer_reanalyze_freq, + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size=reanalyze_batch_size, + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
+ reanalyze_partition=reanalyze_partition, + # ============= RND specific settings ============= + use_rnd_model=True, + random_collect_data=True, + use_momentum_representation_network=True, + target_model_for_intrinsic_reward_update_type='assign', + target_update_freq_for_intrinsic_reward=1000, + target_update_theta_for_intrinsic_reward=0.005, + bp_update_sync=True, + multi_gpu=False, + ), + ) + atari_unizero_config = EasyDict(atari_unizero_config) + main_config = atari_unizero_config + + atari_unizero_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), + ) + atari_unizero_create_config = EasyDict(atari_unizero_create_config) + create_config = atari_unizero_create_config + + # ============ use muzero_segment_collector instead of muzero_collector ============= + from lzero.entry import train_unizero_segment_with_reward_model + main_config.exp_name = (f'./data_lz/data_unizero_atari_rnd/{env_id[:-14]}_obs_latent_w_10/rnd_{main_config.reward_model.intrinsic_reward_type}_' + f'{main_config.reward_model.input_type}_wmax_{main_config.reward_model.intrinsic_weight_max}_input_norm_{main_config.reward_model.input_norm}_intrinsic_norm_{main_config.reward_model.intrinsic_norm}_use_intrinsic_weight_schedule_{main_config.reward_model.use_intrinsic_weight_schedule}/' + f'{main_config.policy.model.world_model_cfg.predict_latent_loss_type}_adaptive_entropy_{main_config.policy.use_adaptive_entropy_weight}_use_priority_{main_config.policy.use_priority}_encoder_clip_{main_config.policy.use_encoder_clip_annealing}_label_smoothing_{main_config.policy.label_smoothing_eps}_use_aug_{main_config.policy.use_augmentation}_ncha_{num_channels}_nres_{num_res_blocks}/') + # main_config.exp_name = ( + # f'./data_lz/data_unizero_atari_rnd/{env_id[:-14]}/' + # f'{env_id[:-14]}_rnd_w_{main_config.reward_model.intrinsic_reward_weight}_uz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_' + # f'nlayer{num_layers}_numsegments-{num_segments}_' + # f'gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_' + # f'bs{batch_size}_seed{seed}' + # ) + train_unizero_segment_with_reward_model( + [main_config, create_config], + seed=seed, + model_path=main_config.policy.model_path, + max_env_step=max_env_step, + ) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='Process different environments and seeds.') + # parser.add_argument('--env', type=str, help='The environment to use', default='PongNoFrameskip-v4') + parser.add_argument('--env', type=str, help='The environment to use', default='VentureNoFrameskip-v4') + parser.add_argument('--seed', type=int, help='The seed to use', default=0) + args = parser.parse_args() + + main(args.env, args.seed)
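# --- Usage note (not part of the patch) ----------------------------------------
# A hedged usage sketch for the new entry above, assuming it is launched as a
# script from the repository root (the argparse defaults already select
# VentureNoFrameskip-v4, seed 0):
#
#   python zoo/atari/config/atari_unizero_segment_rnd_config.py --env MontezumaRevengeNoFrameskip-v4 --seed 0
#
# Both MontezumaRevengeNoFrameskip-v4 and VentureNoFrameskip-v4 are added to
# atari_env_action_space_map in this diff, so either resolves to 18 actions.
# -------------------------------------------------------------------------------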