diff --git a/lzero/entry/train_muzero.py b/lzero/entry/train_muzero.py index f33521086..4d5f87632 100644 --- a/lzero/entry/train_muzero.py +++ b/lzero/entry/train_muzero.py @@ -48,10 +48,10 @@ def train_muzero( """ cfg, create_cfg = input_cfg - assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_context', 'muzero_rnn_full_obs', 'sampled_efficientzero', 'sampled_muzero', 'gumbel_muzero', 'stochastic_muzero'], \ + assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_history', 'muzero_context', 'muzero_rnn_full_obs', 'sampled_efficientzero', 'sampled_muzero', 'gumbel_muzero', 'stochastic_muzero'], \ "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'" - if create_cfg.policy.type in ['muzero', 'muzero_context', 'muzero_rnn_full_obs']: + if create_cfg.policy.type in ['muzero', 'muzero_history', 'muzero_context', 'muzero_rnn_full_obs']: from lzero.mcts import MuZeroGameBuffer as GameBuffer elif create_cfg.policy.type == 'efficientzero': from lzero.mcts import EfficientZeroGameBuffer as GameBuffer diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index fe5e28090..6db93df8a 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -37,7 +37,7 @@ def default_config(cls: type) -> EasyDict: # (bool) Whether to use the root value in the reanalyzing part. Please refer to EfficientZero paper for details. use_root_value=False, # (int) The number of samples required for mini inference. - mini_infer_size=10240, + mini_infer_size=20480, # (str) The type of sampled data. The default is 'transition'. Options: 'transition', 'episode'. sample_type='transition', ) diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index 1e4c9d698..079073e02 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -61,6 +61,13 @@ def __init__(self, cfg: dict): self.sample_times = 0 self.active_root_num = 0 + self.history_length = self._cfg.history_length + if self.history_length > 1: + self.num_unroll_steps = self._cfg.num_unroll_steps + self.history_length + else: + self.num_unroll_steps = self._cfg.num_unroll_steps + + def reset_runtime_metrics(self): """ Overview: @@ -138,6 +145,7 @@ def sample( batch_size, self._cfg.reanalyze_ratio ) # target reward, target value + # import ipdb;ipdb.set_trace() batch_rewards, batch_target_values = self._compute_target_reward_value( reward_value_context, policy._target_model ) @@ -191,21 +199,21 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: pos_in_game_segment = pos_in_game_segment_list[i] actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment + - self._cfg.num_unroll_steps].tolist() + self.num_unroll_steps].tolist() # add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid # mask_tmp = [1. for i in range(len(actions_tmp))] - # mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] + # mask_tmp += [0. for _ in range(self.num_unroll_steps + 1 - len(mask_tmp))] # TODO: the child_visits after position in the segment (with padded part) may not be updated # So the corresponding position should not be used in the training mask_tmp = [1. for i in range(min(len(actions_tmp), self._cfg.game_segment_length - pos_in_game_segment))] - mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] + mask_tmp += [0. 
for _ in range(self.num_unroll_steps + 1 - len(mask_tmp))] # pad random action actions_tmp += [ np.random.randint(0, game.action_space_size) - for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) + for _ in range(self.num_unroll_steps - len(actions_tmp)) ] # obtain the input observations @@ -213,7 +221,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: # e.g. stack+num_unroll_steps = 4+5 obs_list.append( game_segment_list[i].get_unroll_obs( - pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True + pos_in_game_segment_list[i], num_unroll_steps=self.num_unroll_steps, padding=True ) ) action_list.append(actions_tmp) @@ -299,7 +307,7 @@ def _prepare_reward_value_context( # prepare the corresponding observations for bootstrapped values o_{t+k} # o[t+ td_steps, t + td_steps + stack frames + num_unroll_steps] # t=2+3 -> o[2+3, 2+3+4+5] -> o[5, 14] - game_obs = game_segment.get_unroll_obs(state_index + td_steps, self._cfg.num_unroll_steps) + game_obs = game_segment.get_unroll_obs(state_index + td_steps, self.num_unroll_steps) rewards_list.append(game_segment.reward_segment) @@ -309,7 +317,7 @@ def _prepare_reward_value_context( truncation_length = game_segment_len - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + for current_index in range(state_index, state_index + self.num_unroll_steps + 1): # get the bootstrapped target obs td_steps_list.append(td_steps) # index of bootstrapped obs o_{t+td_steps} @@ -400,9 +408,9 @@ def _prepare_policy_reanalyzed_context( child_visits.append(game_segment.child_visit_segment) root_values.append(game_segment.root_value_segment) # prepare the corresponding observations - game_obs = game_segment.get_unroll_obs(state_index, self._cfg.num_unroll_steps) + game_obs = game_segment.get_unroll_obs(state_index, self.num_unroll_steps) - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + for current_index in range(state_index, state_index + self.num_unroll_steps + 1): if current_index < game_segment_len: # original policy_mask.append(1) @@ -436,10 +444,16 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1) transition_batch_size = len(value_obs_list) + # if self.history_length>1: + # game_segment_batch_size = len(pos_in_game_segment_list) + # transition_batch_size = transition_batch_size - (self.history_length-1)*game_segment_batch_size + batch_target_values, batch_rewards = [], [] with torch.no_grad(): + # import ipdb;ipdb.set_trace() value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) - + if transition_batch_size > self._cfg.mini_infer_size: + print(f"transition_batch_size > mini_infer_size:{transition_batch_size > self._cfg.mini_infer_size}") # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) network_output = [] @@ -448,7 +462,12 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A end_index = self._cfg.mini_infer_size * (i + 1) m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device) # calculate the target value - m_output = model.initial_inference(m_obs) + # import ipdb;ipdb.set_trace() + # print(f"m_obs.shape: {m_obs.shape}") + try: + m_output = model.initial_inference(m_obs) + except Exception as e: + print(e) if 
not model.training: # if not in training, obtain the scalars of the value/reward @@ -469,6 +488,8 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A # use the predicted values value_list = concat_output_value(network_output) + # print(f"value_list.shape: {value_list.shape}") + # get last state value if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: # TODO(pu): for board_games, very important, to check @@ -498,7 +519,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A truncation_length = game_segment_len_non_re - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + for current_index in range(state_index, state_index + self.num_unroll_steps + 1): bootstrap_index = current_index + td_steps_list[value_index] for i, reward in enumerate(reward_list[current_index:bootstrap_index]): if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: @@ -544,7 +565,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: # for board games policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, root_values, game_segment_lens, action_mask_segment, \ to_play_segment = policy_re_context - # transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1) + # transition_batch_size = game_segment_batch_size * (self.num_unroll_steps + 1) transition_batch_size = len(policy_obs_list) game_segment_batch_size = len(pos_in_game_segment_list) @@ -623,7 +644,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: for state_index, child_visit, game_index in zip(pos_in_game_segment_list, child_visits, batch_index_list): target_policies = [] - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + for current_index in range(state_index, state_index + self.num_unroll_steps + 1): distributions = roots_distributions[policy_index] searched_value = roots_values[policy_index] @@ -694,10 +715,10 @@ def _compute_target_policy_non_reanalyzed( pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment = policy_non_re_context game_segment_batch_size = len(pos_in_game_segment_list) - transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1) + transition_batch_size = game_segment_batch_size * (self.num_unroll_steps + 1) to_play, action_mask = self._preprocess_to_play_and_action_mask( - game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list + game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list, self.num_unroll_steps ) if self._cfg.model.continuous_action_space is True: @@ -710,7 +731,9 @@ def _compute_target_policy_non_reanalyzed( [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) ] else: + # import ipdb;ipdb.set_trace() legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] + # legal_actions = None with torch.no_grad(): policy_index = 0 @@ -721,7 +744,7 @@ def _compute_target_policy_non_reanalyzed( pos_in_game_segment_list): target_policies = [] - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + for current_index in range(state_index, state_index + self.num_unroll_steps + 1): if current_index < game_segment_len: policy_mask.append(1) # NOTE: child_visit is already a distribution diff --git 
a/lzero/mcts/buffer/game_segment.py b/lzero/mcts/buffer/game_segment.py index a509f1360..db8474b00 100644 --- a/lzero/mcts/buffer/game_segment.py +++ b/lzero/mcts/buffer/game_segment.py @@ -57,6 +57,7 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea self.zero_obs_shape = config.model.observation_shape elif len(config.model.observation_shape) == 3: # image obs input, e.g. atari environments + # print(f'NOTE: config.model.image_channel:{config.model.image_channel}') self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1]) self.obs_segment = [] diff --git a/lzero/mcts/utils.py b/lzero/mcts/utils.py index 407f5e2ba..1c514bc60 100644 --- a/lzero/mcts/utils.py +++ b/lzero/mcts/utils.py @@ -97,11 +97,18 @@ def prepare_observation(observation_list, model_type='conv'): Returns: - np.ndarray: Reshaped array of observations. """ - assert model_type in ['conv', 'mlp', 'conv_context', 'mlp_context'], "model_type must be either 'conv' or 'mlp'" + assert model_type in ['conv', 'conv_history', 'mlp', 'conv_context', 'mlp_context'], "model_type must be either 'conv' or 'mlp'" observation_array = np.array(observation_list) + + # try: + # observation_array = np.array(observation_list) + # except Exception as e: + # print(e) + # import ipdb;ipdb.set_trace() + batch_size = observation_array.shape[0] - if model_type in ['conv', 'conv_context']: + if model_type in ['conv', 'conv_history', 'conv_context']: if observation_array.ndim == 3: # Add a channel dimension if it's missing observation_array = observation_array[..., np.newaxis] diff --git a/lzero/model/common.py b/lzero/model/common.py index 76cd591f2..92493f28b 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -1001,7 +1001,7 @@ def __init__( if observation_shape[1] == 96: latent_shape = (observation_shape[1] / 16, observation_shape[2] / 16) elif observation_shape[1] == 64: - latent_shape = (observation_shape[1] / 8, observation_shape[2] / 8) + latent_shape = (int(observation_shape[1] / 8), int(observation_shape[2] / 8)) if norm_type == 'BN': self.norm_value = nn.BatchNorm2d(value_head_channels) diff --git a/lzero/model/muzero_model_history.py b/lzero/model/muzero_model_history.py new file mode 100644 index 000000000..1ef5b6bbd --- /dev/null +++ b/lzero/model/muzero_model_history.py @@ -0,0 +1,911 @@ +""" +Overview: + BTW, users can refer to the unittest of these model templates to learn how to use them. 
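+    This module implements MuZeroHistoryModel, a MuZero variant whose representation network encodes a short
+    window of past observations (``history_length`` frames) step by step and fuses the per-step latents
+    (mean fusion by default, optionally a small transformer) into a single latent state for the prediction
+    and dynamics networks.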
+""" +from typing import Optional, Tuple + +import math +import torch +import torch.nn as nn +from ding.torch_utils import MLP, ResBlock +from ding.utils import MODEL_REGISTRY, SequenceType +from numpy import ndarray +import math +from typing import Sequence, Tuple, List +import torch.nn.init as init + +import torch.nn.functional as F +from .common import MZNetworkOutput, PredictionNetwork, FeatureAndGradientHook, MLP_V2, DownSample, SimNorm +from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean +from lzero.model.muzero_model import MuZeroModel +from lzero.model.muzero_model_mlp import DynamicsNetworkVector, PredictionNetworkMLP +from lzero.model.unizero_world_models.transformer import Transformer, TransformerConfig +import numpy as np + +class RepresentationNetworkMemoryEnv(nn.Module): + def __init__( + self, + image_shape: Sequence = (3, 5, 5), # 单步输入 shape,每一步的 channel 为 image_shape[0] + embedding_size: int = 100, + channels: List[int] = [16, 32, 64], + kernel_sizes: List[int] = [3, 3, 3], + strides: List[int] = [1, 1, 1], + activation: nn.Module = nn.GELU(approximate='tanh'), + normalize_pixel: bool = False, + group_size: int = 8, + history_length: int = 20, + fusion_mode: str = 'mean', # 可选: 'mean', 'transformer', 其它未来方式 + **kwargs, + ): + """ + 表征网络,用于 MemoryEnv,将2D图像 obs 编码为 latent state,并支持对多历史步进行融合。 + 除了对单步图像进行编码(如 image_shape 为 (3, 5, 5)),本网络扩展为: + 1. 根据输入通道数(total_channels)与单步输入通道数(image_shape[0])的比值,划分为多个历史步, + 即输入 x 的 shape 为 [B, total_channels, W, H],其中 total_channels 应为 (history_length * image_shape[0])。 + 2. 分别编码每一步,输出 latent series,形状为 [B, history_length, embedding_size]。 + 3. 根据 fusion_mode 对 history_length 个 latent 进行融合,得到最终 latent state,形状为 [B, embedding_size]。 + """ + super(RepresentationNetworkMemoryEnv, self).__init__() + self.image_shape = image_shape + self.single_step_in_channels = image_shape[0] + self.embedding_size = embedding_size + self.normalize_pixel = normalize_pixel + self.fusion_mode = fusion_mode + + # 构建单步 CNN encoder 网络(和 LatentEncoderForMemoryEnv 保持一致的基本结构) + self.channels = [image_shape[0]] + list(channels) + layers = [] + for i in range(len(self.channels) - 1): + layers.append( + nn.Conv2d( + in_channels=self.channels[i], + out_channels=self.channels[i + 1], + kernel_size=kernel_sizes[i], + stride=strides[i], + padding=kernel_sizes[i] // 2, # 保持 feature map 大小不变 + ) + ) + layers.append(nn.BatchNorm2d(self.channels[i + 1])) + layers.append(activation) + # 自适应池化,输出形状固定为 1x1 + layers.append(nn.AdaptiveAvgPool2d(1)) + self.cnn = nn.Sequential(*layers) + + # 全连接层将 CNN 输出转为 embedding 表征 + self.linear = nn.Linear(self.channels[-1], embedding_size, bias=False) + init.kaiming_normal_(self.linear.weight, mode='fan_out', nonlinearity='relu') + + self.final_norm = nn.LayerNorm(self.embedding_size, eps=1e-5) + # 如果使用 transformer 聚合,则初始化 transformer 模块 + if self.fusion_mode == 'transformer': + # 假设 history_length 在训练时可变,初始化时无法确定, + # 这里采用一个默认的 tokens_per_block 值,后续也可以根据需要对 transformer 配置进行扩展 + transformer_config = TransformerConfig( + tokens_per_block=history_length, # 每个 block 的 token 数量 + max_blocks=1, # 此处只融合一次 + attention="causal", + num_layers=2, # 可根据需求调整 transformer 层数 + num_heads=8, # 可根据需求调整 + embed_dim=self.embedding_size, # 输入的 embed 维度,与 CNN 输出保持一致 + embed_pdrop=0.1, + resid_pdrop=0.1, + attn_pdrop=0.1, + gru_gating=False, + ) + self.transformer = Transformer(transformer_config) + else: + self.transformer = None + + def forward_single(self, x: torch.Tensor) -> torch.Tensor: + """ + 对单步输入进行编码: + x: [B, 
single_step_in_channels, W, H] + 返回: [B, embedding_size] + """ + if self.normalize_pixel: + x = x / 255.0 + x = self.cnn(x.float()) # 输出形状 (B, C, 1, 1) + x = torch.flatten(x, start_dim=1) # 转换为形状 (B, C) + x = self.linear(x) # (B, embedding_size) + x = self.final_norm(x) # 归一化处理 + return x + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: 输入 tensor,形状 [B, total_channels, W, H], + 其中 total_channels 应为 (history_length * single_step_in_channels) + 例如输入 shape 可为: [8, 12, 5, 5],其中 12 = 4 * 3,表示 4 个历史步,每步 3 个 channel。 + Returns: + latent_series: 各步的 latent 表征,形状为 [B, history_length, embedding_size] + latent_fused: 融合后的 latent 表征,形状为 [B, embedding_size] + 当 fusion_mode 为 "transformer" 时,latent_fused 为 transformer 输出序列中最后一个 timestep 的 latent。 + """ + B, total_channels, W, H = x.shape + if total_channels % self.single_step_in_channels != 0: + raise ValueError( + f"总通道数 {total_channels} 不能整除单步通道数 {self.single_step_in_channels}" + ) + history_length = total_channels // self.single_step_in_channels + + latent_series = [] + for t in range(history_length): + # 取第 t 个历史步的数据 + x_t = x[:, t * self.single_step_in_channels:(t + 1) * self.single_step_in_channels, :, :] + latent_t = self.forward_single(x_t) # [B, embedding_size] + latent_series.append(latent_t.unsqueeze(1)) # 在时间维度上扩展 + + latent_series = torch.cat(latent_series, dim=1) # [B, history_length, embedding_size] + + # 根据 fusion_mode 进行融合 + if self.fusion_mode == 'mean': + latent_fused = latent_series.mean(dim=1) + elif self.fusion_mode == 'transformer': + # 如果 latent_series 的历史长度与 transformer 配置不匹配, + # 可通过 padding 或截断保证输入 transformer 的序列长度与其配置一致 + # 这里假设 latent_series.shape[1] 即 history_length 与 tokens_per_block 一致, + # 否则需要进行相关预处理。 + transformer_out = self.transformer(latent_series) # 输出形状 (B, history_length, embedding_size) + # import ipdb;ipdb.set_trace() + + # 取最后一步 latent state 作为聚合结果 + latent_fused = transformer_out[:, -1, :] + else: + # 其它融合方式:例如先拼接后通过全连接层融合 + B, T, E = latent_series.shape + latent_concat = latent_series.view(B, -1) # [B, T * E] + fusion_fc = nn.Linear(T * E, E).to(x.device) + latent_fused = fusion_fc(latent_concat) + + # 返回两种结果,本例中也可以只返回融合后的结果 + # return latent_series, latent_fused + return latent_fused + + + +# 修改后的扩展版本的 RepresentationNetwork +class RepresentationNetwork(nn.Module): + def __init__( + self, + observation_shape: Sequence = (3, 64, 64), # 单步输入 shape, 每一步3个channel + num_res_blocks: int = 1, + num_channels: int = 64, + downsample: bool = True, + activation: nn.Module = nn.ReLU(inplace=True), + norm_type: str = 'BN', + embedding_dim: int = 256, + group_size: int = 8, + use_sim_norm: bool = False, + fusion_mode: str = 'mean', # 可以扩展为其它融合方式 + ) -> None: + """ + 表征网络,将2D图像 obs 编码为 latent state。 + + 除了本来的单步编码(例如 obs_with_history[:,:3,:,:]),该网络扩展为: + 1. 根据输入的第二维(通道维度)划分为多个历史步(每步3个 channel)。 + 2. 分别计算每一步的 latent state,输出 shape 为 [B, T, num_channels, H_out, W_out]。 + 3. 
将 T 步的信息融合(例如均值融合)得到最终 latent state,其 shape 为 [B, num_channels, H_out, W_out]。 + """ + super().__init__() + assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + + # 这里单步输入channels为 observation_shape[0],一般设置为 3 + self.single_step_in_channels = observation_shape[0] + self.downsample = downsample + if self.downsample: + self.downsample_net = DownSample( + observation_shape, + num_channels, + activation=activation, + norm_type=norm_type, + ) + else: + self.conv = nn.Conv2d(self.single_step_in_channels, num_channels, kernel_size=3, stride=1, padding=1, bias=False) + if norm_type == 'BN': + self.norm = nn.BatchNorm2d(num_channels) + elif norm_type == 'LN': + self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5) + + self.resblocks = nn.ModuleList( + [ + ResBlock( + in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False + ) for _ in range(num_res_blocks) + ] + ) + self.activation = activation + self.embedding_dim = embedding_dim + + # self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5) + # group=1 等价于 layer normalization 每个 sample 内部归一化 + self.final_norm = nn.GroupNorm(1, num_channels) + + # 融合模式,当前仅支持均值融合;可以扩展为其它方式,例如使用 1D 卷积融合时间步信息 + self.fusion_mode = fusion_mode + + def forward_single(self, x: torch.Tensor) -> torch.Tensor: + """ + 处理单步输入: + x: [B, single_step_in_channels, W, H] + 返回: [B, num_channels, W_out, H_out] + """ + if self.downsample: + x = self.downsample_net(x) + else: + x = self.conv(x) + x = self.norm(x) + x = self.activation(x) + + for block in self.resblocks: + x = block(x) + x = self.final_norm(x) + + return x + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: 输入 tensor,shape 为 [B, total_channels, W, H],其中 total_channels 应为 (history_length * single_step_in_channels)。 + 例如 collect/eval 阶段的输入 shape: [8, 12, 64, 64],其中 12 = 4 * 3,表示 4 个历史步,每步 3 个 channel。 + + Returns: + latent_series: 各步的 latent state,shape 为 [B, history_length, num_channels, W_out, H_out] + latent_fused: 融合后的 latent state,shape 为 [B, num_channels, W_out, H_out] + """ + B, total_channels, W, H = x.shape + assert total_channels % self.single_step_in_channels == 0, ( + f"Total channels {total_channels} 不能整除单步通道数 {self.single_step_in_channels}" + ) + history_length = total_channels // self.single_step_in_channels + + latent_series = [] + for t in range(history_length): + # 对应第 t 步的数据:取第 t*channel 到 (t+1)*channel + x_t = x[:, t * self.single_step_in_channels:(t + 1) * self.single_step_in_channels, :, :] + latent_t = self.forward_single(x_t) # [B, num_channels, W_out, H_out] + latent_series.append(latent_t.unsqueeze(1)) # 在时间维度上扩展 + + latent_series = torch.cat(latent_series, dim=1) # [B, history_length, num_channels, W_out, W_out] + + # import ipdb;ipdb.set_trace() + + # 根据 fusion_mode 融合历史步信息 + if self.fusion_mode == 'mean': + latent_fused = latent_series.mean(dim=1) # 均值融合, [B, num_channels, W_out, H_out] + else: + # 可增加其它融合方式,比如拼接后通过1x1卷积 + B, T, C, H_out, W_out = latent_series.shape + latent_concat = latent_series.view(B, -1, H_out, W_out) # [B, T * C, H_out, W_out] + fusion_conv = nn.Conv2d(T * C, C, kernel_size=1) + latent_fused = fusion_conv(latent_concat) + + # return latent_series, latent_fused + return latent_fused + + + +# use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's document. 
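+# Illustrative usage of the history-aware representation network defined above (a sketch, not part of the
+# diff's call sites; shapes assume 4 history steps of 3-channel 64x64 frames and the 8x downsampling that
+# common.py applies to 64x64 inputs):
+#     net = RepresentationNetwork(observation_shape=(3, 64, 64), num_channels=64, downsample=True, fusion_mode='mean')
+#     obs_with_history = torch.randn(8, 12, 64, 64)  # 12 = 4 history steps * 3 channels per step
+#     latent = net(obs_with_history)                 # fused latent state, roughly of shape [8, 64, 8, 8]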
+@MODEL_REGISTRY.register('MuZeroHistoryModel') +class MuZeroHistoryModel(MuZeroModel): + + def __init__( + self, + observation_shape: SequenceType = (12, 96, 96), + action_space_size: int = 6, + num_res_blocks: int = 1, + num_channels: int = 64, + reward_head_channels: int = 16, + value_head_channels: int = 16, + policy_head_channels: int = 16, + reward_head_hidden_channels: SequenceType = [32], + value_head_hidden_channels: SequenceType = [32], + policy_head_hidden_channels: SequenceType = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = False, + categorical_distribution: bool = True, + activation: nn.Module = nn.ReLU(inplace=True), + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + downsample: bool = False, + norm_type: Optional[str] = 'BN', + discrete_action_encoding_type: str = 'one_hot', + history_length: int = 5, + fusion_mode= 'mean', # 可选: 'mean', 'transformer', 其它未来方式 + num_unroll_steps: int = 5, + use_sim_norm: bool = False, + analysis_sim_norm: bool = False, + *args, + **kwargs + ): + """ + Overview: + The definition of the model for MuZero w/ Context, a variant of MuZero. + This variant retains the same training settings as MuZero but diverges during inference + by employing a k-step recursively predicted latent representation at the root node, + proposed in the UniZero paper https://arxiv.org/abs/2406.10667. + Arguments: + - observation_shape (:obj:`SequenceType`): Observation space shape, e.g. [C, W, H]=[12, 96, 96] for Atari. + - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. + - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. + - num_channels (:obj:`int`): The channels of hidden states. + - reward_head_channels (:obj:`int`): The channels of reward head. + - value_head_channels (:obj:`int`): The channels of value head. + - policy_head_channels (:obj:`int`): The channels of policy head. + - reward_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - reward_support_size (:obj:`int`): The size of categorical reward output + - value_support_size (:obj:`int`): The size of categorical value output. + - proj_hid (:obj:`int`): The size of projection hidden layer. + - proj_out (:obj:`int`): The size of projection output layer. + - pred_hid (:obj:`int`): The size of prediction hidden layer. + - pred_out (:obj:`int`): The size of prediction output layer. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \ + in MuZero model, default set it to False. + - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical \ + distribution for value and reward. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ + dynamics/prediction mlp, default sets it to True. 
+ - state_norm (:obj:`bool`): Whether to use normalization for hidden states, default set it to False. + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - discrete_action_encoding_type (:obj:`str`): The type of encoding for discrete action. Default sets it to 'one_hot'. options = {'one_hot', 'not_one_hot'} + """ + super(MuZeroHistoryModel, self).__init__() + + self.timestep = 0 + self.history_length = history_length # NOTE + self.num_unroll_steps = num_unroll_steps + + if isinstance(observation_shape, int) or len(observation_shape) == 1: + # for vector obs input, e.g. classical control and box2d environments + # to be compatible with LightZero model/policy, transform to shape: [C, W, H] + observation_shape = [1, observation_shape, 1] + + self.categorical_distribution = categorical_distribution + if self.categorical_distribution: + self.reward_support_size = reward_support_size + self.value_support_size = value_support_size + else: + self.reward_support_size = 1 + self.value_support_size = 1 + + self.action_space_size = action_space_size + print('action_space_size:', action_space_size) + assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type + self.discrete_action_encoding_type = discrete_action_encoding_type + if self.discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = action_space_size + elif self.discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + self.proj_hid = proj_hid + self.proj_out = proj_out + self.pred_hid = pred_hid + self.pred_out = pred_out + self.self_supervised_learning_loss = self_supervised_learning_loss + self.last_linear_layer_init_zero = last_linear_layer_init_zero + self.state_norm = state_norm + self.downsample = downsample + self.analysis_sim_norm = analysis_sim_norm + + if observation_shape[1] == 96: + latent_size = math.ceil(observation_shape[1] / 16) * math.ceil(observation_shape[2] / 16) + elif observation_shape[1] == 64: + latent_size = math.ceil(observation_shape[1] / 8) * math.ceil(observation_shape[2] / 8) + elif observation_shape[1] == 5: + latent_size = 64 + + + flatten_input_size_for_reward_head = ( + (reward_head_channels * latent_size) if downsample else + (reward_head_channels * observation_shape[1] * observation_shape[2]) + ) + flatten_input_size_for_value_head = ( + (value_head_channels * latent_size) if downsample else + (value_head_channels * observation_shape[1] * observation_shape[2]) + ) + flatten_input_size_for_policy_head = ( + (policy_head_channels * latent_size) if downsample else + (policy_head_channels * observation_shape[1] * observation_shape[2]) + ) + + if observation_shape[1] == 5: + # MemoryEnv + embedding_size = 768 + self.representation_network = RepresentationNetworkMemoryEnv( + observation_shape, + embedding_size=embedding_size, + channels= [16, 32, 64], + group_size= 8, + history_length=self.history_length, + fusion_mode=fusion_mode, # 可选: 'mean', 'transformer', 其它未来方式 + ) + self.num_channels = num_channels + self.latent_state_dim = self.num_channels - self.action_encoding_dim + + self.dynamics_network = DynamicsNetworkVector( + action_encoding_dim=self.action_encoding_dim, + num_channels=embedding_size + self.action_encoding_dim, + common_layer_num=2, + 
reward_head_hidden_channels=reward_head_hidden_channels, + output_support_size=self.reward_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + res_connection_in_dynamics=True, + ) + self.vector_ynamics_network = True + self.prediction_network = PredictionNetworkMLP( + action_space_size=action_space_size, + num_channels=embedding_size, + value_head_hidden_channels=value_head_hidden_channels, + policy_head_hidden_channels=policy_head_hidden_channels, + output_support_size=self.value_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type + ) + else: + # atari + self.representation_network = RepresentationNetwork( + observation_shape, + num_res_blocks, + num_channels, + downsample, + activation=activation, + norm_type=norm_type, + embedding_dim=768, + group_size=8, + use_sim_norm=use_sim_norm, # NOTE + ) + # ====== for analysis ====== + if self.analysis_sim_norm: + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) + + self.dynamics_network = DynamicsNetwork( + observation_shape, + self.action_encoding_dim, + num_res_blocks, + num_channels + self.action_encoding_dim, + reward_head_channels, + reward_head_hidden_channels, + self.reward_support_size, + flatten_input_size_for_reward_head, + downsample, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + activation=activation, + norm_type=norm_type, + embedding_dim=768, + group_size=8, + use_sim_norm=use_sim_norm, # NOTE + ) + self.vector_ynamics_network = False + + # import ipdb;ipdb.set_trace() + + self.prediction_network = PredictionNetwork( + observation_shape, + action_space_size, + num_res_blocks, + num_channels, + value_head_channels, + policy_head_channels, + value_head_hidden_channels, + policy_head_hidden_channels, + self.value_support_size, + flatten_input_size_for_value_head, + flatten_input_size_for_policy_head, + downsample, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + activation=activation, + norm_type=norm_type + ) + + if self.self_supervised_learning_loss: + # projection used in EfficientZero + if self.downsample: + # In Atari, if the observation_shape is set to (12, 96, 96), which indicates the original shape of + # (3,96,96), and frame_stack_num is 4. Due to downsample, the encoding of observation (latent_state) is + # (64, 96/16, 96/16), where 64 is the number of channels, 96/16 is the size of the latent state. Thus, + # self.projection_input_dim = 64 * 96/16 * 96/16 = 64*6*6 = 2304 + self.projection_input_dim = num_channels * latent_size + else: + self.projection_input_dim = num_channels * observation_shape[1] * observation_shape[2] + + self.projection = nn.Sequential( + nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) + ) + self.prediction_head = nn.Sequential( + nn.Linear(self.proj_out, self.pred_hid), + nn.BatchNorm1d(self.pred_hid), + activation, + nn.Linear(self.pred_hid, self.pred_out), + ) + + def stack_history_torch(self, obs, history_length): + """ + 对输入观测值 `obs`(形状 [T, 3, 64, 64])进行转换, + 将每个时间步变为堆叠了前 history_length 帧的观测数据, + 如果前面的历史不足 history_length 则使用零填充。 + + 最终返回的 obs_with_history 形状为 [T, 3*history_length, 64, 64]. 
+ + 参数: + obs: PyTorch tensor,形状为 [T, 3, H, W] (例如 H = W = 64) + history_length: 要堆叠的历史步数 + + 返回: + obs_with_history: PyTorch tensor,形状为 [T, 3*history_length, H, W] + """ + T, C, H, W = obs.shape + # seq_batch_size = 3 + seq_batch_size = int(T/(self.num_unroll_steps + self.history_length + 1)) + + # import ipdb;ipdb.set_trace() + + # Step 1: 重构 shape 为 [seq_batch_size, (num_unroll_steps+history_length), 3, 64, 64] + obs = obs.view(seq_batch_size, self.num_unroll_steps + self.history_length + 1, 3, H, W) + # 此时 obs.shape = [3, 7, 3, 64, 64] + + # import torch + # import torchvision.utils as vutils + # import matplotlib.pyplot as plt + # import numpy as np + # from PIL import Image + # # 假设 obs 是一个形状为 [2, 10, 3, 64, 64] 的tensor,且像素值范围为 [0, 1] + # # 示例:随机数据(实际使用时请替换为你的 obs) + # # 将 [2, 10, 3, 64, 64] 转换成 [20, 3, 64, 64] 保持序列顺序,这样每10个图片为一行 + # obs_flat = obs.reshape(-1, *obs.shape[2:]) + # # 使用 torchvision.utils.make_grid 拼接图像,设置 nrow=10 表示每行10个 + # grid = vutils.make_grid(obs_flat, nrow=10, padding=2) + # # 将 tensor 转为 numpy 数组,并调整维度顺序以适应 PIL 显示要求(H, W, C) + # np_grid = grid.permute(1, 2, 0).cpu().numpy() + # # 如果 tensor 的像素值已经在 [0, 1],可以乘以 255 并转为 uint8 类型 + # np_grid = (np_grid * 255).astype(np.uint8) + # # 保存图像到文件 + # Image.fromarray(np_grid).save('sequence_grid.png') + # # 显示图像 + # plt.figure(figsize=(8, 4)) + # plt.imshow(np_grid) + # plt.axis('off') + # plt.title("2行,每行10个timestep的序列图像") + # plt.show() + + # Step 2: 对时间维度应用 sliding window 操作(unfold); + # unfolding 参数: 在 dim=1 上,窗口大小为 history_length,步长为 1. + # unfolding 后形状:[seq_batch_size, (7 - history_length + 1), history_length, 3, 64, 64] + # observation_array = np.array(observation_list) + + windows = obs.unfold(dimension=1, size=self.history_length, step=1) # 形状:[3, 6, 3, 64, 64, 2] + # print("Step 2 windows.shape:", windows.shape) + + # windows.shape torch.Size([3, 7, 3, 64, 64, 2]) -> [3, 7+self.history_length-1, 3, 64, 64, 2] 请在前面补零,补齐为后者的维度 + # 计算需要补零的数量(在前面补上 history_length - 1 个零) + pad_len = self.history_length - 1 + # 构造与 windows 除待补维度外其他维度相同的补零张量,其形状为 [3, pad_len, 3, 64, 64, 2] + padding_tensor = torch.zeros( + (windows.size(0), pad_len, windows.size(2), windows.size(3), windows.size(4), windows.size(5)), + dtype=windows.dtype, + device=windows.device + ) + # 在维度 1 上拼接补零张量 + windows_padded = torch.cat([padding_tensor, windows], dim=1) + + # Step 4: 将窗口中的观测在通道维度上进行拼接 + # 原本每个窗口形状为 [2, 3, 64, 64],将 2 (history_length) 个通道拼接后变为 [6, 64, 64] + # 整体结果 shape 最终为 [seq_batch_size, num_unroll_steps, history_length*3, 64, 64] = [3, 5, 6, 64, 64] + windows_padded = windows_padded.reshape(seq_batch_size, self.num_unroll_steps+self.history_length + 1, history_length * 3, H, W) + + obs_with_history = windows_padded.view(-1, self.history_length * 3, H, W) + + + return obs_with_history + + def initial_inference(self, obs: torch.Tensor, action_batch=None, current_obs_batch=None) -> MZNetworkOutput: + """ + Overview: + Initial inference of MuZero model, which is the first step of the MuZero model. + To perform the initial inference, we first use the representation network to obtain the ``latent_state``. + Then we use the prediction network to predict ``value`` and ``policy_logits`` of the ``latent_state``. + Arguments: + - obs (:obj:`torch.Tensor`): The 2D image observation data. + Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action. 
\ + In initial inference, we set it to zero vector. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + """ + batch_size = obs.size(0) + + # import ipdb;ipdb.set_trace() + + if self.training or action_batch is None: + # train phase + # import ipdb;ipdb.set_trace() + if action_batch is None and obs.shape[1] != 3*self.history_length: + # ======train phase: compute target value ======= + # 已知seq_batch_size=3, self.num_unroll_steps=5, self.history_length=2, + # 已知目前obs.shape == [seq_batch_size,(self.num_unroll_steps+self.history_length),3,64,64] 请先变换为 -> [seq_batch_size,(self.num_unroll_steps+self.history_length),3,64,64] + # 例如[21, 3, 64, 64] -> [3, 7, 3, 64, 64] + # 对于其中每个序列,[i,7, 3,64,64], + # 取self.history_length之后的时间步,每个时间步都保留前面self.history_length的ob, + # 即变为[i,7-self.history_length, 3*self.history_length,64,64] = [i,5,6,64,64] + # 总的数据变换过程为 [21, 3, 64, 64] -> [3*7, 3, 64, 64] -> [3, 7, 3, 64, 64]-> [3, 6, 6, 64, 64] + obs_with_history = self.stack_history_torch(obs, self.history_length) + # print(f"train phase (compute target value) obs_with_history.shape:{obs_with_history.shape}") + + else: + # ======= train phase: init_infer ======= + obs_with_history = obs + # print(f"train phase (init inference) obs_with_history.shape:{obs_with_history.shape}") + + assert obs_with_history.shape[1] == 3*self.history_length + # TODO(pu) + self.latent_state = self.representation_network(obs_with_history) + + self.timestep = 0 + else: + # print(f"collect/eval phase obs_with_history.shape:{obs.shape}") + # ======== collect/eval phase ======== + obs_with_history = obs + + # ===== obs: torch.Tensor, action_batch=None, current_obs_batch=None + assert obs_with_history.shape[1] == 3*self.history_length + # TODO(pu) + self.latent_state = self.representation_network(obs_with_history) + # print(f"collect/eval phase latent_state.shape:{self.latent_state.shape}") + + + # import ipdb;ipdb.set_trace() + + policy_logits, value = self.prediction_network(self.latent_state) + self.timestep += 1 + return MZNetworkOutput( + value, + [0. for _ in range(batch_size)], + policy_logits, + self.latent_state, + ) + + def recurrent_inference(self, latent_state: torch.Tensor, action: torch.Tensor) -> MZNetworkOutput: + """ + Overview: + Recurrent inference of MuZero model, which is the rollout step of the MuZero model. + To perform the recurrent inference, we first use the dynamics network to predict ``next_latent_state``, + ``reward``, by the given current ``latent_state`` and ``action``. + We then use the prediction network to predict the ``value`` and ``policy_logits`` of the current + ``latent_state``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - action (:obj:`torch.Tensor`): The predicted action to rollout. 
+ Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + """ + if self.vector_ynamics_network: + next_latent_state, reward = self._dynamics_vector(latent_state, action) + else: + next_latent_state, reward = self._dynamics(latent_state, action) + + policy_logits, value = self.prediction_network(next_latent_state) + self.latent_state = next_latent_state # NOTE: update latent_state + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) + + def _dynamics_vector(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` + ``reward`` and ``next_reward_hidden_state``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns: + - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. + - next_reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. + - reward (:obj:`torch.Tensor`): The predicted reward for input state. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + """ + # NOTE: the discrete action encoding type is important for some environments + + # discrete action space + if self.discrete_action_encoding_type == 'one_hot': + # Stack latent_state with the one hot encoded action + if len(action.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action = action.unsqueeze(-1) + + # transform action to one-hot encoding. 
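+            # e.g., for a batch of two actions [[2], [0]] with action_space_size=4, the scatter_ call below
+            # produces the one-hot rows [[0., 0., 1., 0.], [1., 0., 0., 0.]].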
+ # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) + action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device) + # transform action to torch.int64 + action = action.long() + action_one_hot.scatter_(1, action, 1) + action_encoding = action_one_hot + elif self.discrete_action_encoding_type == 'not_one_hot': + action_encoding = action / self.action_space_size + if len(action_encoding.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action_encoding = action_encoding.unsqueeze(-1) + + action_encoding = action_encoding.to(latent_state.device).float() + # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or + # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type. + state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + + next_latent_state, reward = self.dynamics_network(state_action_encoding) + + if not self.state_norm: + return next_latent_state, reward + else: + next_latent_state_normalized = renormalize(next_latent_state) + return next_latent_state_normalized, reward + + +class DynamicsNetwork(nn.Module): + + def __init__( + self, + observation_shape: SequenceType, + action_encoding_dim: int = 2, + num_res_blocks: int = 1, + num_channels: int = 64, + reward_head_channels: int = 64, + reward_head_hidden_channels: SequenceType = [32], + output_support_size: int = 601, + flatten_input_size_for_reward_head: int = 64, + downsample: bool = False, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + embedding_dim: int = 256, + group_size: int = 8, + use_sim_norm: bool = False, + ): + """ + Overview: + The definition of dynamics network in MuZero algorithm, which is used to predict next latent state and + reward given current latent state and action. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of input observation, e.g., (12, 96, 96). + - action_encoding_dim (:obj:`int`): The dimension of action encoding. + - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. + - num_channels (:obj:`int`): The channels of input, including obs and action encoding. + - reward_head_channels (:obj:`int`): The channels of reward head. + - reward_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - output_support_size (:obj:`int`): The size of categorical reward output. + - flatten_input_size_for_reward_head (:obj:`int`): The flatten size of output for reward head, i.e., \ + the input size of reward head. + - downsample (:obj:`bool`): Whether to downsample the input observation, default set it to False. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializationss for the last layer of \ + reward mlp, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. 
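+            - embedding_dim (:obj:`int`): The embedding dimension associated with the SimNorm option; accepted \
+                mainly for interface compatibility and not referenced elsewhere in this network.
+            - group_size (:obj:`int`): The group size (``simnorm_dim``) passed to SimNorm when ``use_sim_norm`` is True.
+            - use_sim_norm (:obj:`bool`): Whether to apply SimNorm to the predicted next latent state, default set it to False.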
+ """ + super().__init__() + assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + assert num_channels > action_encoding_dim, f'num_channels:{num_channels} <= action_encoding_dim:{action_encoding_dim}' + + self.num_channels = num_channels + self.flatten_input_size_for_reward_head = flatten_input_size_for_reward_head + + self.action_encoding_dim = action_encoding_dim + self.conv = nn.Conv2d(num_channels, num_channels - self.action_encoding_dim, kernel_size=3, stride=1, padding=1, bias=False) + + if norm_type == 'BN': + self.norm_common = nn.BatchNorm2d(num_channels - self.action_encoding_dim) + elif norm_type == 'LN': + if downsample: + self.norm_common = nn.LayerNorm([num_channels - self.action_encoding_dim, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)]) + else: + self.norm_common = nn.LayerNorm([num_channels - self.action_encoding_dim, observation_shape[-2], observation_shape[-1]]) + + self.resblocks = nn.ModuleList( + [ + ResBlock( + in_channels=num_channels - self.action_encoding_dim, activation=activation, norm_type='BN', res_type='basic', bias=False + ) for _ in range(num_res_blocks) + ] + ) + + self.conv1x1_reward = nn.Conv2d(num_channels - self.action_encoding_dim, reward_head_channels, 1) + + if norm_type == 'BN': + self.norm_reward = nn.BatchNorm2d(reward_head_channels) + elif norm_type == 'LN': + if downsample: + self.norm_reward = nn.LayerNorm([reward_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)]) + else: + self.norm_reward = nn.LayerNorm([reward_head_channels, observation_shape[-2], observation_shape[-1]]) + + self.fc_reward_head = MLP( + self.flatten_input_size_for_reward_head, + hidden_channels=reward_head_hidden_channels[0], + layer_num=len(reward_head_hidden_channels) + 1, + out_channels=output_support_size, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + self.activation = activation + self.use_sim_norm = use_sim_norm + if self.use_sim_norm: + self.sim_norm = SimNorm(simnorm_dim=group_size) + + def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation of the dynamics network. Predict the next latent state given current latent state and action. + Arguments: + - state_action_encoding (:obj:`torch.Tensor`): The state-action encoding, which is the concatenation of \ + latent state and action encoding, with shape (batch_size, num_channels, height, width). + Returns: + - next_latent_state (:obj:`torch.Tensor`): The next latent state, with shape (batch_size, num_channels, \ + height, width). + - reward (:obj:`torch.Tensor`): The predicted reward, with shape (batch_size, output_support_size). 
+ """ + # take the state encoding, state_action_encoding[:, -self.action_encoding_dim:, :, :] is action encoding + state_encoding = state_action_encoding[:, :-self.action_encoding_dim:, :, :] + x = self.conv(state_action_encoding) + x = self.norm_common(x) + + # the residual link: add state encoding to the state_action encoding + x += state_encoding + x = self.activation(x) + + for block in self.resblocks: + x = block(x) + next_latent_state = x + + x = self.conv1x1_reward(next_latent_state) + x = self.norm_reward(x) + x = self.activation(x) + x = x.view(-1, self.flatten_input_size_for_reward_head) + + # use the fully connected layer to predict reward + reward = self.fc_reward_head(x) + + if self.use_sim_norm: + next_latent_state = self.sim_norm(next_latent_state) + + return next_latent_state, reward diff --git a/lzero/model/muzero_model_history_bkp20250320.py b/lzero/model/muzero_model_history_bkp20250320.py new file mode 100644 index 000000000..e6e39682b --- /dev/null +++ b/lzero/model/muzero_model_history_bkp20250320.py @@ -0,0 +1,853 @@ +""" +Overview: + BTW, users can refer to the unittest of these model templates to learn how to use them. +""" +from typing import Optional, Tuple + +import math +import torch +import torch.nn as nn +from ding.torch_utils import MLP, ResBlock +from ding.utils import MODEL_REGISTRY, SequenceType +from numpy import ndarray +import math +from typing import Sequence, Tuple, List +import torch.nn.init as init + +import torch +import torch.nn as nn +import torch.nn.functional as F +from .common import MZNetworkOutput, PredictionNetwork, FeatureAndGradientHook, MLP_V2, DownSample, SimNorm +from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean +from lzero.model.muzero_model import MuZeroModel +from lzero.model.muzero_model_mlp import DynamicsNetworkVector, PredictionNetworkMLP + + +class RepresentationNetworkMemoryEnv(nn.Module): + def __init__( + self, + image_shape: Sequence = (3, 5, 5), # 单步输入 shape,每一步的 channel 为 image_shape[0] + embedding_size: int = 100, + channels: List[int] = [16, 32, 64], + kernel_sizes: List[int] = [3, 3, 3], + strides: List[int] = [1, 1, 1], + activation: nn.Module = nn.GELU(approximate='tanh'), + normalize_pixel: bool = False, + group_size: int = 8, + fusion_mode: str = 'mean', # 当前仅支持均值融合,后续可扩展为其它融合方式 + **kwargs, + ): + """ + 表征网络,用于 MemoryEnv,将2D图像 obs 编码为 latent state,并支持对多历史步进行融合。 + 除了对单步图像进行编码(如 image_shape 为 (3, 5, 5)),本网络扩展为: + 1. 根据输入通道数(total_channels)与单步输入通道数(image_shape[0])的比值,划分为多个历史步, + 即输入 x 的 shape 为 [B, total_channels, W, H],其中 total_channels 应为 (history_length * image_shape[0])。 + 2. 分别编码每一步,输出 latent series,形状为 [B, history_length, embedding_size]。 + 3. 
根据 fusion_mode 对 history_length 个 latent 进行融合,得到最终 latent state,形状为 [B, embedding_size]。 + """ + super(RepresentationNetworkMemoryEnv, self).__init__() + self.image_shape = image_shape + self.single_step_in_channels = image_shape[0] + self.embedding_size = embedding_size + self.normalize_pixel = normalize_pixel + self.fusion_mode = fusion_mode + + # 构建单步 CNN encoder 网络(和 LatentEncoderForMemoryEnv 保持一致的基本结构) + self.channels = [image_shape[0]] + list(channels) + layers = [] + for i in range(len(self.channels) - 1): + layers.append( + nn.Conv2d( + in_channels=self.channels[i], + out_channels=self.channels[i + 1], + kernel_size=kernel_sizes[i], + stride=strides[i], + padding=kernel_sizes[i] // 2, # 保持 feature map 大小不变 + ) + ) + layers.append(nn.BatchNorm2d(self.channels[i + 1])) + layers.append(activation) + # 自适应池化,输出形状固定为 1x1 + layers.append(nn.AdaptiveAvgPool2d(1)) + self.cnn = nn.Sequential(*layers) + + # 全连接层将 CNN 输出转为 embedding 表征 + self.linear = nn.Linear(self.channels[-1], embedding_size, bias=False) + init.kaiming_normal_(self.linear.weight, mode='fan_out', nonlinearity='relu') + + self.sim_norm = SimNorm(simnorm_dim=group_size) + + def forward_single(self, x: torch.Tensor) -> torch.Tensor: + """ + 对单步输入进行编码: + x: [B, single_step_in_channels, W, H] + 返回: [B, embedding_size] + """ + if self.normalize_pixel: + x = x / 255.0 + x = self.cnn(x.float()) # 输出形状 (B, C, 1, 1) + # import ipdb;ipdb.set_trace() + + x = torch.flatten(x, start_dim=1) # 转换为形状 (B, C) + x = self.linear(x) # (B, embedding_size) + x = self.sim_norm(x) # 归一化处理 + return x + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: 输入 tensor,形状 [B, total_channels, W, H],其中 total_channels 应为 (history_length * single_step_in_channels) + 例如输入 shape 可为: [8, 12, 5, 5],其中 12 = 4 * 3,表示 4 个历史步,每步 3 个 channel。 + Returns: + latent_series: 各步的 latent 表征,形状为 [B, history_length, embedding_size] + latent_fused: 融合后的 latent 表征,形状为 [B, embedding_size] + """ + B, total_channels, W, H = x.shape + if total_channels % self.single_step_in_channels != 0: + raise ValueError( + f"总通道数 {total_channels} 不能整除单步通道数 {self.single_step_in_channels}" + ) + history_length = total_channels // self.single_step_in_channels + + latent_series = [] + for t in range(history_length): + # 取第 t 个历史步的数据 + x_t = x[:, t * self.single_step_in_channels:(t + 1) * self.single_step_in_channels, :, :] + latent_t = self.forward_single(x_t) # [B, embedding_size] + latent_series.append(latent_t.unsqueeze(1)) # 在时间维度上扩展 + + latent_series = torch.cat(latent_series, dim=1) # [B, history_length, embedding_size] + + + # 根据 fusion_mode 对所有历史步进行融合 + if self.fusion_mode == 'mean': + latent_fused = latent_series.mean(dim=1) + # import ipdb;ipdb.set_trace() + else: + # 其它融合方式:例如先拼接后通过全连接层融合 + B, T, E = latent_series.shape + latent_concat = latent_series.view(B, -1) # [B, T * E] + fusion_fc = nn.Linear(T * E, E).to(x.device) + latent_fused = fusion_fc(latent_concat) + + # return latent_series, latent_fused + return latent_fused + + +# 修改后的扩展版本的 RepresentationNetwork +class RepresentationNetwork(nn.Module): + def __init__( + self, + observation_shape: Sequence = (3, 64, 64), # 单步输入 shape, 每一步3个channel + num_res_blocks: int = 1, + num_channels: int = 64, + downsample: bool = True, + activation: nn.Module = nn.ReLU(inplace=True), + norm_type: str = 'BN', + embedding_dim: int = 256, + group_size: int = 8, + use_sim_norm: bool = False, + fusion_mode: str = 'mean', # 可以扩展为其它融合方式 + ) -> None: + """ + 表征网络,将2D图像 obs 编码为 latent state。 + + 除了本来的单步编码(例如 
obs_with_history[:,:3,:,:]),该网络扩展为: + 1. 根据输入的第二维(通道维度)划分为多个历史步(每步3个 channel)。 + 2. 分别计算每一步的 latent state,输出 shape 为 [B, T, num_channels, H_out, W_out]。 + 3. 将 T 步的信息融合(例如均值融合)得到最终 latent state,其 shape 为 [B, num_channels, H_out, W_out]。 + """ + super().__init__() + assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + + # 这里单步输入channels为 observation_shape[0],一般设置为 3 + self.single_step_in_channels = observation_shape[0] + self.downsample = downsample + if self.downsample: + self.downsample_net = DownSample( + observation_shape, + num_channels, + activation=activation, + norm_type=norm_type, + ) + else: + self.conv = nn.Conv2d(self.single_step_in_channels, num_channels, kernel_size=3, stride=1, padding=1, bias=False) + if norm_type == 'BN': + self.norm = nn.BatchNorm2d(num_channels) + elif norm_type == 'LN': + self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5) + + self.resblocks = nn.ModuleList( + [ + ResBlock( + in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False + ) for _ in range(num_res_blocks) + ] + ) + self.activation = activation + self.use_sim_norm = use_sim_norm + + if self.use_sim_norm: + self.embedding_dim = embedding_dim + self.sim_norm = SimNorm(simnorm_dim=group_size) + + # 融合模式,当前仅支持均值融合;可以扩展为其它方式,例如使用 1D 卷积融合时间步信息 + self.fusion_mode = fusion_mode + + def forward_single(self, x: torch.Tensor) -> torch.Tensor: + """ + 处理单步输入: + x: [B, single_step_in_channels, W, H] + 返回: [B, num_channels, W_out, H_out] + """ + if self.downsample: + x = self.downsample_net(x) + else: + x = self.conv(x) + x = self.norm(x) + x = self.activation(x) + + for block in self.resblocks: + x = block(x) + + if self.use_sim_norm: + x = self.sim_norm(x) + + return x + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: 输入 tensor,shape 为 [B, total_channels, W, H],其中 total_channels 应为 (history_length * single_step_in_channels)。 + 例如 collect/eval 阶段的输入 shape: [8, 12, 64, 64],其中 12 = 4 * 3,表示 4 个历史步,每步 3 个 channel。 + + Returns: + latent_series: 各步的 latent state,shape 为 [B, history_length, num_channels, W_out, H_out] + latent_fused: 融合后的 latent state,shape 为 [B, num_channels, W_out, H_out] + """ + B, total_channels, W, H = x.shape + assert total_channels % self.single_step_in_channels == 0, ( + f"Total channels {total_channels} 不能整除单步通道数 {self.single_step_in_channels}" + ) + history_length = total_channels // self.single_step_in_channels + + latent_series = [] + for t in range(history_length): + # 对应第 t 步的数据:取第 t*channel 到 (t+1)*channel + x_t = x[:, t * self.single_step_in_channels:(t + 1) * self.single_step_in_channels, :, :] + latent_t = self.forward_single(x_t) # [B, num_channels, W_out, H_out] + latent_series.append(latent_t.unsqueeze(1)) # 在时间维度上扩展 + + latent_series = torch.cat(latent_series, dim=1) # [B, history_length, num_channels, W_out, W_out] + + # import ipdb;ipdb.set_trace() + + # 根据 fusion_mode 融合历史步信息 + if self.fusion_mode == 'mean': + latent_fused = latent_series.mean(dim=1) # 均值融合, [B, num_channels, W_out, H_out] + else: + # 可增加其它融合方式,比如拼接后通过1x1卷积 + B, T, C, H_out, W_out = latent_series.shape + latent_concat = latent_series.view(B, -1, H_out, W_out) # [B, T * C, H_out, W_out] + fusion_conv = nn.Conv2d(T * C, C, kernel_size=1) + latent_fused = fusion_conv(latent_concat) + + # return latent_series, latent_fused + return latent_fused + + + +# use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's 
document. +@MODEL_REGISTRY.register('MuZeroHistoryModel') +class MuZeroHistoryModel(MuZeroModel): + + def __init__( + self, + observation_shape: SequenceType = (12, 96, 96), + action_space_size: int = 6, + num_res_blocks: int = 1, + num_channels: int = 64, + reward_head_channels: int = 16, + value_head_channels: int = 16, + policy_head_channels: int = 16, + reward_head_hidden_channels: SequenceType = [32], + value_head_hidden_channels: SequenceType = [32], + policy_head_hidden_channels: SequenceType = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = False, + categorical_distribution: bool = True, + activation: nn.Module = nn.ReLU(inplace=True), + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + downsample: bool = False, + norm_type: Optional[str] = 'BN', + discrete_action_encoding_type: str = 'one_hot', + history_length: int = 5, + num_unroll_steps: int = 5, + use_sim_norm: bool = False, + analysis_sim_norm: bool = False, + *args, + **kwargs + ): + """ + Overview: + The definition of the model for MuZero w/ Context, a variant of MuZero. + This variant retains the same training settings as MuZero but diverges during inference + by employing a k-step recursively predicted latent representation at the root node, + proposed in the UniZero paper https://arxiv.org/abs/2406.10667. + Arguments: + - observation_shape (:obj:`SequenceType`): Observation space shape, e.g. [C, W, H]=[12, 96, 96] for Atari. + - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. + - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. + - num_channels (:obj:`int`): The channels of hidden states. + - reward_head_channels (:obj:`int`): The channels of reward head. + - value_head_channels (:obj:`int`): The channels of value head. + - policy_head_channels (:obj:`int`): The channels of policy head. + - reward_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - reward_support_size (:obj:`int`): The size of categorical reward output + - value_support_size (:obj:`int`): The size of categorical value output. + - proj_hid (:obj:`int`): The size of projection hidden layer. + - proj_out (:obj:`int`): The size of projection output layer. + - pred_hid (:obj:`int`): The size of prediction hidden layer. + - pred_out (:obj:`int`): The size of prediction output layer. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \ + in MuZero model, default set it to False. + - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical \ + distribution for value and reward. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ + dynamics/prediction mlp, default sets it to True. + - state_norm (:obj:`bool`): Whether to use normalization for hidden states, default set it to False. 
+ - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - discrete_action_encoding_type (:obj:`str`): The type of encoding for discrete action. Default sets it to 'one_hot'. options = {'one_hot', 'not_one_hot'} + """ + super(MuZeroHistoryModel, self).__init__() + + self.timestep = 0 + self.history_length = history_length # NOTE + self.num_unroll_steps = num_unroll_steps + + if isinstance(observation_shape, int) or len(observation_shape) == 1: + # for vector obs input, e.g. classical control and box2d environments + # to be compatible with LightZero model/policy, transform to shape: [C, W, H] + observation_shape = [1, observation_shape, 1] + + self.categorical_distribution = categorical_distribution + if self.categorical_distribution: + self.reward_support_size = reward_support_size + self.value_support_size = value_support_size + else: + self.reward_support_size = 1 + self.value_support_size = 1 + + self.action_space_size = action_space_size + print('action_space_size:', action_space_size) + assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type + self.discrete_action_encoding_type = discrete_action_encoding_type + if self.discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = action_space_size + elif self.discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + self.proj_hid = proj_hid + self.proj_out = proj_out + self.pred_hid = pred_hid + self.pred_out = pred_out + self.self_supervised_learning_loss = self_supervised_learning_loss + self.last_linear_layer_init_zero = last_linear_layer_init_zero + self.state_norm = state_norm + self.downsample = downsample + self.analysis_sim_norm = analysis_sim_norm + + if observation_shape[1] == 96: + latent_size = math.ceil(observation_shape[1] / 16) * math.ceil(observation_shape[2] / 16) + elif observation_shape[1] == 64: + latent_size = math.ceil(observation_shape[1] / 8) * math.ceil(observation_shape[2] / 8) + elif observation_shape[1] == 5: + latent_size = 64 + + + flatten_input_size_for_reward_head = ( + (reward_head_channels * latent_size) if downsample else + (reward_head_channels * observation_shape[1] * observation_shape[2]) + ) + flatten_input_size_for_value_head = ( + (value_head_channels * latent_size) if downsample else + (value_head_channels * observation_shape[1] * observation_shape[2]) + ) + flatten_input_size_for_policy_head = ( + (policy_head_channels * latent_size) if downsample else + (policy_head_channels * observation_shape[1] * observation_shape[2]) + ) + + if observation_shape[1] == 5: + # MemoryEnv + embedding_size = 768 + self.representation_network = RepresentationNetworkMemoryEnv( + observation_shape, + embedding_size=embedding_size, + channels= [16, 32, 64], + group_size= 8, + ) + self.num_channels = num_channels + self.latent_state_dim = self.num_channels - self.action_encoding_dim + + self.dynamics_network = DynamicsNetworkVector( + action_encoding_dim=self.action_encoding_dim, + num_channels=embedding_size + self.action_encoding_dim, + common_layer_num=2, + reward_head_hidden_channels=reward_head_hidden_channels, + output_support_size=self.reward_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + 
res_connection_in_dynamics=True, + ) + self.vector_ynamics_network = True + self.prediction_network = PredictionNetworkMLP( + action_space_size=action_space_size, + num_channels=embedding_size, + value_head_hidden_channels=value_head_hidden_channels, + policy_head_hidden_channels=policy_head_hidden_channels, + output_support_size=self.value_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type + ) + else: + # atari + self.representation_network = RepresentationNetwork( + observation_shape, + num_res_blocks, + num_channels, + downsample, + activation=activation, + norm_type=norm_type, + embedding_dim=768, + group_size=8, + use_sim_norm=use_sim_norm, # NOTE + ) + # ====== for analysis ====== + if self.analysis_sim_norm: + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) + + self.dynamics_network = DynamicsNetwork( + observation_shape, + self.action_encoding_dim, + num_res_blocks, + num_channels + self.action_encoding_dim, + reward_head_channels, + reward_head_hidden_channels, + self.reward_support_size, + flatten_input_size_for_reward_head, + downsample, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + activation=activation, + norm_type=norm_type, + embedding_dim=768, + group_size=8, + use_sim_norm=use_sim_norm, # NOTE + ) + self.vector_ynamics_network = False + + self.prediction_network = PredictionNetwork( + observation_shape, + action_space_size, + num_res_blocks, + num_channels, + value_head_channels, + policy_head_channels, + value_head_hidden_channels, + policy_head_hidden_channels, + self.value_support_size, + flatten_input_size_for_value_head, + flatten_input_size_for_policy_head, + downsample, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + activation=activation, + norm_type=norm_type + ) + + if self.self_supervised_learning_loss: + # projection used in EfficientZero + if self.downsample: + # In Atari, if the observation_shape is set to (12, 96, 96), which indicates the original shape of + # (3,96,96), and frame_stack_num is 4. Due to downsample, the encoding of observation (latent_state) is + # (64, 96/16, 96/16), where 64 is the number of channels, 96/16 is the size of the latent state. Thus, + # self.projection_input_dim = 64 * 96/16 * 96/16 = 64*6*6 = 2304 + self.projection_input_dim = num_channels * latent_size + else: + self.projection_input_dim = num_channels * observation_shape[1] * observation_shape[2] + + self.projection = nn.Sequential( + nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) + ) + self.prediction_head = nn.Sequential( + nn.Linear(self.proj_out, self.pred_hid), + nn.BatchNorm1d(self.pred_hid), + activation, + nn.Linear(self.pred_hid, self.pred_out), + ) + + def stack_history_torch(self, obs, history_length): + """ + 对输入观测值 `obs`(形状 [T, 3, 64, 64])进行转换, + 将每个时间步变为堆叠了前 history_length 帧的观测数据, + 如果前面的历史不足 history_length 则使用零填充。 + + 最终返回的 obs_with_history 形状为 [T, 3*history_length, 64, 64]. 
+
+        Arguments:
+            obs: a PyTorch tensor of shape [T, 3, H, W] (e.g. H = W = 64).
+            history_length: the number of history steps to stack.
+
+        Returns:
+            obs_with_history: a PyTorch tensor of shape [T, 3*history_length, H, W].
+        """
+        T, C, H, W = obs.shape
+        # Each sampled sequence contains (num_unroll_steps + history_length + 1) consecutive frames.
+        seq_batch_size = int(T / (self.num_unroll_steps + self.history_length + 1))
+
+        # Step 1: regroup the flat batch into sequences:
+        # [T, 3, H, W] -> [seq_batch_size, num_unroll_steps + history_length + 1, 3, H, W]
+        obs = obs.view(seq_batch_size, self.num_unroll_steps + self.history_length + 1, 3, H, W)
+
+        # Step 2: apply a sliding window (unfold) over the time dimension with window size
+        # history_length and stride 1. The window offset is appended as the last axis:
+        # [seq_batch_size, num_unroll_steps + 2, 3, H, W, history_length]
+        windows = obs.unfold(dimension=1, size=self.history_length, step=1)
+
+        # Pad (history_length - 1) all-zero windows at the front of the time dimension, so that
+        # every original timestep has a corresponding window:
+        # [seq_batch_size, num_unroll_steps + history_length + 1, 3, H, W, history_length]
+        pad_len = self.history_length - 1
+        padding_tensor = torch.zeros(
+            (windows.size(0), pad_len, windows.size(2), windows.size(3), windows.size(4), windows.size(5)),
+            dtype=windows.dtype,
+            device=windows.device
+        )
+        windows_padded = torch.cat([padding_tensor, windows], dim=1)
+
+        # Step 3: merge each window into the channel dimension. Move the window axis in front of the
+        # per-frame channels first, so that every history step keeps its own 3 channels (oldest step first):
+        # [..., 3, H, W, history_length] -> [..., history_length, 3, H, W] -> [..., history_length * 3, H, W]
+        windows_padded = windows_padded.permute(0, 1, 5, 2, 3, 4).contiguous()
+        windows_padded = windows_padded.reshape(
+            seq_batch_size, self.num_unroll_steps + self.history_length + 1, self.history_length * 3, H, W
+        )
+
+        # Flatten the sequence dimension back: [T, history_length * 3, H, W]
+        obs_with_history = windows_padded.view(-1, self.history_length * 3, H, W)
+
+        return obs_with_history
+
+    def initial_inference(self, obs: torch.Tensor, action_batch=None, current_obs_batch=None) -> MZNetworkOutput:
+        """
+        Overview:
+            Initial inference of MuZero model, which is the first step of the MuZero model.
+            To perform the initial inference, we first use the representation network to obtain the ``latent_state``.
+            Then we use the prediction network to predict ``value`` and ``policy_logits`` of the ``latent_state``.
+        Arguments:
+            - obs (:obj:`torch.Tensor`): The 2D image observation data, with ``history_length`` frames stacked \
+                along the channel dimension.
+        Returns (MZNetworkOutput):
+            - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation.
+            - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action. \
+                In initial inference, we set it to zero vector.
+            - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action.
+            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
+        Shapes:
+            - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size.
+            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
+            - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size.
+            - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size.
+            - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \
+                latent state, W_ is the width of latent state.
+        """
+        batch_size = obs.size(0)
+
+        if self.training or action_batch is None:
+            # ====== train phase ======
+            if action_batch is None and obs.shape[1] != 3 * self.history_length:
+                # ====== train phase: compute target value ======
+                # Here ``obs`` holds the flattened full sequences used for computing target values, i.e. its shape is
+                # [seq_batch_size * (num_unroll_steps + history_length + 1), 3, H, W]. ``stack_history_torch``
+                # regroups the frames per sequence and stacks, for every timestep, the preceding ``history_length``
+                # frames along the channel dimension (zero-padded where the history is missing).
+                # E.g. with seq_batch_size=3, num_unroll_steps=5, history_length=2:
+                # [24, 3, 64, 64] -> [3, 8, 3, 64, 64] -> [3, 8, 6, 64, 64] -> [24, 6, 64, 64].
+                obs_with_history = self.stack_history_torch(obs, self.history_length)
+            else:
+                # ====== train phase: initial inference ======
+                # ``obs`` already contains the stacked history frames prepared by the game buffer.
+                obs_with_history = obs
+
+            assert obs_with_history.shape[1] == 3 * self.history_length
+            # TODO(pu)
+            self.latent_state = self.representation_network(obs_with_history)
+
+            self.timestep = 0
+        else:
+            # ====== collect/eval phase ======
+            # ``obs`` already contains the stacked history maintained by the collector/evaluator
+            # (see ``batch_obs_history_collect`` in the policy); ``action_batch`` and ``current_obs_batch``
+            # carry the last actions and the latest single-step observations.
+            obs_with_history = obs
+
+            assert obs_with_history.shape[1] == 3 * self.history_length
+            # TODO(pu)
+            self.latent_state = self.representation_network(obs_with_history)
+
+        policy_logits, value = self.prediction_network(self.latent_state)
+        self.timestep += 1
+        return MZNetworkOutput(
+            value,
+            [0. for _ in range(batch_size)],
+            policy_logits,
+            self.latent_state,
+        )
+
+    def recurrent_inference(self, latent_state: torch.Tensor, action: torch.Tensor) -> MZNetworkOutput:
+        """
+        Overview:
+            Recurrent inference of MuZero model, which is the rollout step of the MuZero model.
+            To perform the recurrent inference, we first use the dynamics network to predict ``next_latent_state``,
+            ``reward``, by the given current ``latent_state`` and ``action``.
+            We then use the prediction network to predict the ``value`` and ``policy_logits`` of the current
+            ``latent_state``.
+        Arguments:
+            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
+            - action (:obj:`torch.Tensor`): The predicted action to rollout.
+        Returns (MZNetworkOutput):
+            - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation.
+            - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action.
+            - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action.
+            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
+            - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state.
+        Shapes:
+            - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size.
+            - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size.
+            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
+            - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size.
+ - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + """ + if self.vector_ynamics_network: + next_latent_state, reward = self._dynamics_vector(latent_state, action) + else: + next_latent_state, reward = self._dynamics(latent_state, action) + + policy_logits, value = self.prediction_network(next_latent_state) + self.latent_state = next_latent_state # NOTE: update latent_state + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) + + def _dynamics_vector(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` + ``reward`` and ``next_reward_hidden_state``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns: + - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. + - next_reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. + - reward (:obj:`torch.Tensor`): The predicted reward for input state. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + """ + # NOTE: the discrete action encoding type is important for some environments + + # discrete action space + if self.discrete_action_encoding_type == 'one_hot': + # Stack latent_state with the one hot encoded action + if len(action.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action = action.unsqueeze(-1) + + # transform action to one-hot encoding. + # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) + action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device) + # transform action to torch.int64 + action = action.long() + action_one_hot.scatter_(1, action, 1) + action_encoding = action_one_hot + elif self.discrete_action_encoding_type == 'not_one_hot': + action_encoding = action / self.action_space_size + if len(action_encoding.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action_encoding = action_encoding.unsqueeze(-1) + + action_encoding = action_encoding.to(latent_state.device).float() + # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or + # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type. 
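+        # Illustrative shape example (an assumption for clarity, not part of the original patch): for the
+        # MemoryEnv vector path above (embedding_size=768), a discrete action space of size 4 with one-hot
+        # encoding, and batch_size=8, the tensors below would be:
+        #   latent_state:          (8, 768)
+        #   action_encoding:       (8, 4)
+        #   state_action_encoding: (8, 772)   # i.e. embedding_size + action_encoding_dim,
+        #                                     # matching ``num_channels`` of DynamicsNetworkVector.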
+ state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + + next_latent_state, reward = self.dynamics_network(state_action_encoding) + + if not self.state_norm: + return next_latent_state, reward + else: + next_latent_state_normalized = renormalize(next_latent_state) + return next_latent_state_normalized, reward + + +class DynamicsNetwork(nn.Module): + + def __init__( + self, + observation_shape: SequenceType, + action_encoding_dim: int = 2, + num_res_blocks: int = 1, + num_channels: int = 64, + reward_head_channels: int = 64, + reward_head_hidden_channels: SequenceType = [32], + output_support_size: int = 601, + flatten_input_size_for_reward_head: int = 64, + downsample: bool = False, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + embedding_dim: int = 256, + group_size: int = 8, + use_sim_norm: bool = False, + ): + """ + Overview: + The definition of dynamics network in MuZero algorithm, which is used to predict next latent state and + reward given current latent state and action. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of input observation, e.g., (12, 96, 96). + - action_encoding_dim (:obj:`int`): The dimension of action encoding. + - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. + - num_channels (:obj:`int`): The channels of input, including obs and action encoding. + - reward_head_channels (:obj:`int`): The channels of reward head. + - reward_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - output_support_size (:obj:`int`): The size of categorical reward output. + - flatten_input_size_for_reward_head (:obj:`int`): The flatten size of output for reward head, i.e., \ + the input size of reward head. + - downsample (:obj:`bool`): Whether to downsample the input observation, default set it to False. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializationss for the last layer of \ + reward mlp, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. 
+ """ + super().__init__() + assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + assert num_channels > action_encoding_dim, f'num_channels:{num_channels} <= action_encoding_dim:{action_encoding_dim}' + + self.num_channels = num_channels + self.flatten_input_size_for_reward_head = flatten_input_size_for_reward_head + + self.action_encoding_dim = action_encoding_dim + self.conv = nn.Conv2d(num_channels, num_channels - self.action_encoding_dim, kernel_size=3, stride=1, padding=1, bias=False) + + if norm_type == 'BN': + self.norm_common = nn.BatchNorm2d(num_channels - self.action_encoding_dim) + elif norm_type == 'LN': + if downsample: + self.norm_common = nn.LayerNorm([num_channels - self.action_encoding_dim, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)]) + else: + self.norm_common = nn.LayerNorm([num_channels - self.action_encoding_dim, observation_shape[-2], observation_shape[-1]]) + + self.resblocks = nn.ModuleList( + [ + ResBlock( + in_channels=num_channels - self.action_encoding_dim, activation=activation, norm_type='BN', res_type='basic', bias=False + ) for _ in range(num_res_blocks) + ] + ) + + self.conv1x1_reward = nn.Conv2d(num_channels - self.action_encoding_dim, reward_head_channels, 1) + + if norm_type == 'BN': + self.norm_reward = nn.BatchNorm2d(reward_head_channels) + elif norm_type == 'LN': + if downsample: + self.norm_reward = nn.LayerNorm([reward_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)]) + else: + self.norm_reward = nn.LayerNorm([reward_head_channels, observation_shape[-2], observation_shape[-1]]) + + self.fc_reward_head = MLP( + self.flatten_input_size_for_reward_head, + hidden_channels=reward_head_hidden_channels[0], + layer_num=len(reward_head_hidden_channels) + 1, + out_channels=output_support_size, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + self.activation = activation + self.use_sim_norm = use_sim_norm + if self.use_sim_norm: + self.sim_norm = SimNorm(simnorm_dim=group_size) + + def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation of the dynamics network. Predict the next latent state given current latent state and action. + Arguments: + - state_action_encoding (:obj:`torch.Tensor`): The state-action encoding, which is the concatenation of \ + latent state and action encoding, with shape (batch_size, num_channels, height, width). + Returns: + - next_latent_state (:obj:`torch.Tensor`): The next latent state, with shape (batch_size, num_channels, \ + height, width). + - reward (:obj:`torch.Tensor`): The predicted reward, with shape (batch_size, output_support_size). 
+ """ + # take the state encoding, state_action_encoding[:, -self.action_encoding_dim:, :, :] is action encoding + state_encoding = state_action_encoding[:, :-self.action_encoding_dim:, :, :] + x = self.conv(state_action_encoding) + x = self.norm_common(x) + + # the residual link: add state encoding to the state_action encoding + x += state_encoding + x = self.activation(x) + + for block in self.resblocks: + x = block(x) + next_latent_state = x + + x = self.conv1x1_reward(next_latent_state) + x = self.norm_reward(x) + x = self.activation(x) + x = x.view(-1, self.flatten_input_size_for_reward_head) + + # use the fully connected layer to predict reward + reward = self.fc_reward_head(x) + + if self.use_sim_norm: + next_latent_state = self.sim_norm(next_latent_state) + + return next_latent_state, reward diff --git a/lzero/model/muzero_model_mlp.py b/lzero/model/muzero_model_mlp.py index 01f6924b9..39be0eec5 100644 --- a/lzero/model/muzero_model_mlp.py +++ b/lzero/model/muzero_model_mlp.py @@ -105,7 +105,7 @@ def __init__( observation_shape=observation_shape, hidden_channels=self.latent_state_dim, norm_type=norm_type ) - self.dynamics_network = DynamicsNetwork( + self.dynamics_network = DynamicsNetworkVector( action_encoding_dim=self.action_encoding_dim, num_channels=self.latent_state_dim + self.action_encoding_dim, common_layer_num=2, @@ -325,7 +325,7 @@ def get_params_mean(self) -> float: return get_params_mean(self) -class DynamicsNetwork(nn.Module): +class DynamicsNetworkVector(nn.Module): def __init__( self, diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index 62536c892..221f81cd8 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -29,6 +29,8 @@ class TransformerConfig: resid_pdrop: float attn_pdrop: float + gru_gating: bool + @property def max_tokens(self): return self.tokens_per_block * self.max_blocks diff --git a/lzero/policy/muzero_history.py b/lzero/policy/muzero_history.py new file mode 100644 index 000000000..68ef3b6b3 --- /dev/null +++ b/lzero/policy/muzero_history.py @@ -0,0 +1,1133 @@ +import copy +from typing import List, Dict, Any, Tuple, Union, Optional + +import numpy as np +import torch +import torch.optim as optim +import wandb +from ding.model import model_wrap +from ding.policy.base_policy import Policy +from ding.torch_utils import to_tensor +from ding.utils import POLICY_REGISTRY +from torch.nn import L1Loss + +from lzero.entry.utils import initialize_zeros_batch +from lzero.mcts import MuZeroMCTSCtree as MCTSCtree +from lzero.mcts import MuZeroMCTSPtree as MCTSPtree +from lzero.model import ImageTransforms +from lzero.model.utils import cal_dormant_ratio +from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \ + prepare_obs_history, configure_optimizers + + +@POLICY_REGISTRY.register('muzero_history') +class MuZeroHistoryPolicy(Policy): + """ + Overview: + if self._cfg.model.model_type in ["conv", "mlp"]: + The policy class for MuZero. + if self._cfg.model.model_type == ["conv_context", "mlp_context"]: + The policy class for MuZero w/ Context, a variant of MuZero. 
+ This variant retains the same training settings as MuZero but diverges during inference + by employing a k-step recursively predicted latent representation at the root node, + proposed in the UniZero paper https://arxiv.org/abs/2406.10667. + """ + + # The default_config for MuZero policy. + config = dict( + model=dict( + # (str) The model type. For 1-dimensional vector obs, we use mlp model. For the image obs, we use conv model. + model_type='conv', # options={'mlp', 'conv'} + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, + # (tuple) The stacked obs shape. + # observation_shape=(1, 96, 96), # if frame_stack_num=1 + observation_shape=(4, 96, 96), # if frame_stack_num=4 + # (bool) Whether to use the self-supervised learning loss. + self_supervised_learning_loss=False, + # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. + # reference: http://proceedings.mlr.press/v80/imani18a/imani18a.pdf, https://arxiv.org/abs/2403.03950 + categorical_distribution=True, + # (int) The image channel in image observation. + image_channel=1, + # (int) The number of frames to stack together. + frame_stack_num=1, + # (int) The number of res blocks in MuZero model. + num_res_blocks=1, + # (int) The number of channels of hidden states in MuZero model. + num_channels=64, + # (int) The scale of supports used in categorical distribution. + # This variable is only effective when ``categorical_distribution=True``. + support_scale=300, + # (bool) whether to learn bias in the last linear layer in value and policy head. + bias=True, + # (str) The type of action encoding. Options are ['one_hot', 'not_one_hot']. Default to 'one_hot'. + discrete_action_encoding_type='one_hot', + # (bool) whether to use res connection in dynamics. + res_connection_in_dynamics=True, + # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. + norm_type='BN', + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to analyze dormant ratio. + analysis_dormant_ratio=False, + # (bool) Whether to use HarmonyDream to balance weights between different losses. Default to False. + # More details can be found in https://arxiv.org/abs/2310.00344. + harmony_balance=False + ), + # ****** common ****** + # (bool) Whether to use wandb to log the training process. + use_wandb=False, + # (bool) whether to use rnd model. + use_rnd_model=False, + # (bool) Whether to use multi-gpu training. + multi_gpu=False, + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) + # this variable is used in ``collector``. + sampled_algo=False, + # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero) + gumbel_algo=False, + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. + mcts_ctree=True, + # (bool) Whether to use cuda for network. + cuda=True, + # (int) The number of environments used in collecting data. + collector_env_num=8, + # (int) The number of environments used in evaluating policy. + evaluator_env_num=3, + # (str) The type of environment. Options are ['not_board_games', 'board_games']. + env_type='not_board_games', + # (str) The type of action space. Options are ['fixed_action_space', 'varied_action_space']. + action_type='fixed_action_space', + # (str) The type of battle mode. Options are ['play_with_bot_mode', 'self_play_mode']. 
+ battle_mode='play_with_bot_mode', + # (bool) Whether to monitor extra statistics in tensorboard. + monitor_extra_statistics=True, + # (int) The transition number of one ``GameSegment``. + game_segment_length=200, + # (bool): Indicates whether to perform an offline evaluation of the checkpoint (ckpt). + # If set to True, the checkpoint will be evaluated after the training process is complete. + # IMPORTANT: Setting eval_offline to True requires configuring the saving of checkpoints to align with the evaluation frequency. + # This is done by setting the parameter learn.learner.hook.save_ckpt_after_iter to the same value as eval_freq in the train_muzero.py automatically. + eval_offline=False, + # (bool) Whether to calculate the dormant ratio. + cal_dormant_ratio=False, + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to analyze dormant ratio. + analysis_dormant_ratio=False, + + # ****** observation ****** + # (bool) Whether to transform image to string to save memory. + transform2string=False, + # (bool) Whether to use gray scale image. + gray_scale=False, + # (bool) Whether to use data augmentation. + use_augmentation=False, + # (list) The style of augmentation. + augmentation=['shift', 'intensity'], + + # ******* learn ****** + # (bool) Whether to ignore the done flag in the training data. Typically, this value is set to False. + # However, for some environments with a fixed episode length, to ensure the accuracy of Q-value calculations, + # we should set it to True to avoid the influence of the done flag. + ignore_done=False, + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. + # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. + update_per_collect=None, + # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. + replay_ratio=0.25, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. ['SGD', 'Adam'] + optim_type='SGD', + # (float) Learning rate for training policy network. Initial lr for manually decay schedule. + learning_rate=0.2, + # (int) Frequency of target network update. + target_update_freq=100, + # (int) Frequency of target network update. + target_update_freq_for_intrinsic_reward=1000, + # (float) Weight decay for training policy network. + weight_decay=1e-4, + # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). + momentum=0.9, + # (float) The maximum constraint value of gradient norm clipping. + grad_clip_value=10, + # (int) The number of episodes in each collecting stage when use muzero_collector. + n_episode=8, + # (int) The number of num_segments in each collecting stage when use muzero_segment_collector. + num_segments=8, + # (int) the number of simulations in MCTS. + num_simulations=50, + # (float) Discount factor (gamma) for returns. + discount_factor=0.997, + # (int) The number of steps for calculating target q_value. + td_steps=5, + # (int) The number of unroll steps in dynamics network. + num_unroll_steps=5, + # (float) The weight of reward loss. 
+ reward_loss_weight=1, + # (float) The weight of value loss. + value_loss_weight=0.25, + # (float) The weight of policy loss. + policy_loss_weight=1, + # (float) The weight of policy entropy loss. + policy_entropy_weight=0, + # (float) The weight of ssl (self-supervised learning) loss. + ssl_loss_weight=0, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. lr: 0.2 -> 0.02 -> 0.002 + piecewise_decay_lr_scheduler=True, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e4), + # (bool) Whether to use manually decayed temperature. + manual_temperature_decay=False, + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(1e5), + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. + use_ture_chance_label_in_chance_encoder=False, + # (bool) Whether to add noise to roots during reanalyze process. + reanalyze_noise=True, + # (bool) Whether to reuse the root value between batch searches. + reuse_search=False, + # (bool) whether to use the pure policy to collect data. If False, use the MCTS guided with policy. + collect_with_pure_policy=False, + + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=False, + # (float) The degree of prioritization to use. A value of 0 means no prioritization, + # while a value of 1 means full prioritization. + priority_prob_alpha=0.6, + # (float) The degree of correction to use. A value of 0 means no correction, + # while a value of 1 means full correction. + priority_prob_beta=0.4, + + # ****** UCB ****** + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + + # ****** Explore by random collect ****** + # (int) The number of episodes to collect data randomly before training. + random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + # (bool) Whether to use eps greedy exploration in collecting data. + eps_greedy_exploration_in_collect=False, + # (str) The type of decaying epsilon. Options are 'linear', 'exp'. + type='linear', + # (float) The start value of eps. + start=1., + # (float) The end value of eps. + end=0.05, + # (int) The decay steps from start to end eps. + decay=int(1e5), + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default model setting for demonstration. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. + - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. + - import_names (:obj:`List[str]`): The model class path list used in this algorithm. + .. note:: + The user can define and use customized network model but must obey the same interface definition indicated \ + by import_names path. 
For MuZero, ``lzero.model.muzero_model.MuZeroModel`` + """ + if self._cfg.model.model_type in ["conv_history"]: + return 'MuZeroHistoryModel', ['lzero.model.muzero_model_history'] + elif self._cfg.model.model_type == "mlp": + return 'MuZeroModelMLP', ['lzero.model.muzero_model_mlp'] + elif self._cfg.model.model_type in ["conv_context"]: + return 'MuZeroContextModel', ['lzero.model.muzero_context_model'] + else: + raise ValueError("model type {} is not supported".format(self._cfg.model.model_type)) + + def set_train_iter_env_step(self, train_iter, env_step) -> None: + """ + Overview: + Set the train_iter and env_step for the policy. + Arguments: + - train_iter (:obj:`int`): The train_iter for the policy. + - env_step (:obj:`int`): The env_step for the policy. + """ + self.train_iter = train_iter + self.env_step = env_step + + def _init_learn(self) -> None: + """ + Overview: + Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. + """ + assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type + # NOTE: in board_games, for fixed lr 0.003, 'Adam' is better than 'SGD'. + if self._cfg.optim_type == 'SGD': + self._optimizer = optim.SGD( + self._model.parameters(), + lr=self._cfg.learning_rate, + momentum=self._cfg.momentum, + weight_decay=self._cfg.weight_decay, + ) + elif self._cfg.optim_type == 'Adam': + self._optimizer = optim.Adam( + self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay + ) + elif self._cfg.optim_type == 'AdamW': + self._optimizer = configure_optimizers(model=self._model, weight_decay=self._cfg.weight_decay, + learning_rate=self._cfg.learning_rate, device_type=self._cfg.device) + + if self._cfg.piecewise_decay_lr_scheduler: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. 
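+            # Illustrative schedule (assuming the default learning_rate=0.2 and
+            # threshold_training_steps_for_final_lr=5e4 from the config above):
+            #   step < 2.5e4        -> factor 1.0  -> lr 0.2
+            #   2.5e4 <= step < 5e4 -> factor 0.1  -> lr 0.02
+            #   step >= 5e4         -> factor 0.01 -> lr 0.002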
+ lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda) + + # use model_wrapper for specialized demands of different modes + self._target_model = copy.deepcopy(self._model) + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq} + ) + self._learn_model = self._model + + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + + # ============================================================== + # harmonydream (learnable weights for different losses) + # ============================================================== + if self._cfg.model.harmony_balance: + # List of parameter names + harmony_names = ["harmony_dynamics", "harmony_policy", "harmony_value", "harmony_reward", "harmony_entropy"] + # Initialize and name each parameter + for name in harmony_names: + param = torch.nn.Parameter(-torch.log(torch.tensor(1.0))) + setattr(self, name, param) + + if self._cfg.use_rnd_model: + if self._cfg.target_model_for_intrinsic_reward_update_type == 'assign': + self._target_model_for_intrinsic_reward = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq_for_intrinsic_reward} + ) + elif self._cfg.target_model_for_intrinsic_reward_update_type == 'momentum': + self._target_model_for_intrinsic_reward = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta_for_intrinsic_reward} + ) + + # ========= logging for analysis ========= + self.l2_norm_before = 0. + self.l2_norm_after = 0. + self.grad_norm_before = 0. + self.grad_norm_after = 0. + self.dormant_ratio_encoder = 0. + self.dormant_ratio_dynamics = 0. + + if self._cfg.use_wandb: + # TODO: add the model to wandb + wandb.watch(self._learn_model.representation_network, log="all") + + def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward function for learning policy in learn mode, which is the core of the learning process. + The data is sampled from replay buffer. + The loss is calculated by the loss function and the loss is backpropagated to update the model. + Arguments: + - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. + The first tensor is the current_batch, the second tensor is the target_batch. + Returns: + - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ + current learning loss and learning statistics. 
+ """ + self._learn_model.train() + self._target_model.train() + if self._cfg.use_rnd_model: + self._target_model_for_intrinsic_reward.train() + + current_batch, target_batch = data + obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch + target_reward, target_value, target_policy = target_batch + + # import ipdb;ipdb.set_trace() + # TODO: check SSL=False obs_target_batch = None + obs_batch, obs_target_batch = prepare_obs_history(obs_batch_ori, self._cfg) + + # do augmentations + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # shape: (batch_size, num_unroll_steps, action_dim) + # NOTE: .long() is only for discrete action space. + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() + data_list = [mask_batch, target_reward, + target_value, target_policy, weights + ] + [mask_batch, target_reward, target_value, target_policy, + weights] = to_torch_float_tensor(data_list, self._cfg.device) + + target_reward = target_reward.view(self._cfg.batch_size, -1) + target_value = target_value.view(self._cfg.batch_size, -1) + + # TODO: check + # 第一步的latent state 对应采样的序列中的第self.history_length-1=2-1=步(序列从1开始计算),例如采样长度为5+2+1=8,第一步的latent state对应第2步(序列从1开始计算) + target_policy = target_policy[:,self.history_length-1:,:] + target_reward = target_reward[:,self.history_length:] + target_value = target_value[:,self.history_length-1:] + + assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) + + # ``scalar_transform`` to transform the original value to the scaled value, + # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + + # transform a scalar to its categorical_distribution. After this transformation, each scalar is + # represented as the linear combination of its two adjacent supports. + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # ============================================================== + # the core initial_inference in MuZero policy. + # ============================================================== + network_output = self._learn_model.initial_inference(obs_batch) + + # value_prefix shape: (batch_size, 10), the ``value_prefix`` at the first step is zero padding. + latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) + + # ========= logging for analysis ========= + # calculate dormant ratio of encoder + if self._cfg.cal_dormant_ratio: + self.dormant_ratio_encoder = cal_dormant_ratio(self._learn_model.representation_network, obs_batch.detach(), + percentage=self._cfg.dormant_threshold) + # calculate L2 norm of latent state + latent_state_l2_norms = torch.norm(latent_state.view(latent_state.shape[0], -1), p=2, dim=1).mean() + # ========= logging for analysis =============== + + # transform the scaled value or its categorical representation to its original value, + # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + original_value = self.inverse_scalar_transform_handle(value) + + # Note: The following lines are just for debugging. 
+ predicted_rewards = [] + if self._cfg.monitor_extra_statistics: + predicted_values, predicted_policies = original_value.detach().cpu(), torch.softmax( + policy_logits, dim=1 + ).detach().cpu() + + # calculate the new priorities for each transition. + value_priority = L1Loss(reduction='none')(original_value.squeeze(-1), target_value[:, 0]) + value_priority = value_priority.data.cpu().numpy() + 1e-6 + + # ============================================================== + # calculate policy and value loss for the first step. + # ============================================================== + policy_loss = cross_entropy_loss(policy_logits, target_policy[:, 0]) + value_loss = cross_entropy_loss(value, target_value_categorical[:, 0]) + + prob = torch.softmax(policy_logits, dim=-1) + entropy = -(prob * prob.log()).sum(-1) + policy_entropy_loss = -entropy + + reward_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + target_policy_entropy = 0 + + # ============================================================== + # the core recurrent_inference in MuZero policy. + # ============================================================== + for step_k in range(self._cfg.num_unroll_steps): + # unroll with the dynamics function: predict the next ``latent_state``, ``reward``, + # given current ``latent_state`` and ``action``. + # And then predict policy_logits and value with the prediction function. + # import ipdb;ipdb.set_trace() + network_output = self._learn_model.recurrent_inference(latent_state, action_batch[:, step_k+self.history_length-1]) + latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) + + # ========= logging for analysis =============== + if step_k == self._cfg.num_unroll_steps - 1 and self._cfg.cal_dormant_ratio: + # calculate dormant ratio of encoder + action_tmp = action_batch[:, step_k+self.history_length-1] + if len(action_tmp.shape) == 1: + action = action.unsqueeze(-1) + # transform action to one-hot encoding. + # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) + action_one_hot = torch.zeros(action_tmp.shape[0], policy_logits.shape[-1], device=action_tmp.device) + # transform action to torch.int64 + action_tmp = action_tmp.long() + action_one_hot.scatter_(1, action_tmp, 1) + action_encoding_tmp = action_one_hot.unsqueeze(-1).unsqueeze(-1) + action_encoding = action_encoding_tmp.expand( + latent_state.shape[0], policy_logits.shape[-1], latent_state.shape[2], latent_state.shape[3] + ) + state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + self.dormant_ratio_dynamics = cal_dormant_ratio(self._learn_model.dynamics_network, + state_action_encoding.detach(), + percentage=self._cfg.dormant_threshold) + # ========= logging for analysis =============== + + # transform the scaled value or its categorical representation to its original value, + # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + original_value = self.inverse_scalar_transform_handle(value) + + if self._cfg.model.self_supervised_learning_loss: + # ============================================================== + # calculate consistency loss for the next ``num_unroll_steps`` unroll steps. + # ============================================================== + if self._cfg.ssl_loss_weight > 0: + # obtain the oracle latent states from representation function. 
+ beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) + network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) + + latent_state = to_tensor(latent_state) + representation_state = to_tensor(network_output.latent_state) + + # NOTE: no grad for the representation_state branch + dynamic_proj = self._learn_model.project(latent_state, with_grad=True) + observation_proj = self._learn_model.project(representation_state, with_grad=False) + temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] + consistency_loss += temp_loss + + # NOTE: the target policy, target_value_categorical, target_reward_categorical is calculated in + # game buffer now. + # ============================================================== + # calculate policy loss for the next ``num_unroll_steps`` unroll steps. + # NOTE: the +=. + # ============================================================== + policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_k + 1]) + + # Here we take the hypothetical step k = step_k + 1 + prob = torch.softmax(policy_logits, dim=-1) + entropy = -(prob * prob.log()).sum(-1) + policy_entropy_loss += -entropy + + target_normalized_visit_count = target_policy[:, step_k + 1] + + # ******* NOTE: target_policy_entropy is only for debug. ****** + non_masked_indices = torch.nonzero(mask_batch[:, step_k + 1]).squeeze(-1) + # Check if there are any unmasked rows + if len(non_masked_indices) > 0: + target_normalized_visit_count_masked = torch.index_select( + target_normalized_visit_count, 0, non_masked_indices + ) + target_policy_entropy += -((target_normalized_visit_count_masked + 1e-6) * ( + target_normalized_visit_count_masked + 1e-6).log()).sum(-1).mean() + else: + # Set target_policy_entropy to log(|A|) if all rows are masked + target_policy_entropy += torch.log(torch.tensor(target_normalized_visit_count.shape[-1])) + + value_loss += cross_entropy_loss(value, target_value_categorical[:, step_k + 1]) + reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_k]) + + if self._cfg.monitor_extra_statistics: + original_rewards = self.inverse_scalar_transform_handle(reward) + original_rewards_cpu = original_rewards.detach().cpu() + + predicted_values = torch.cat( + (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) + ) + predicted_rewards.append(original_rewards_cpu) + predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu())) + + # ============================================================== + # the core learn model update step. + # ============================================================== + # weighted loss with masks (some invalid states which are out of trajectory.) + # Nan appear when consistency loss or policy entropy loss uses harmony parameter as coefficient. 
+ + # Please refer to https://github.com/thuml/HarmonyDream/blob/main/wmlib-torch/wmlib/agents/dreamerv2.py#L161 + # ["harmony_dynamics", "harmony_policy", "harmony_value", "harmony_reward", "harmony_entropy"] + if self._cfg.model.harmony_balance: + loss = ( + (consistency_loss.mean() * self._cfg.ssl_loss_weight) + + (policy_loss.mean() / torch.exp(self.harmony_policy)) + + (value_loss.mean() / torch.exp(self.harmony_value)) + + (reward_loss.mean() / torch.exp(self.harmony_reward)) + ) + weighted_total_loss = loss.mean() + weighted_total_loss += ( + torch.log(torch.exp(self.harmony_policy) + 1) + + torch.log(torch.exp(self.harmony_value) + 1) + + torch.log(torch.exp(self.harmony_reward) + 1) + ) + else: + loss = ( + self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss + + self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * reward_loss + + self._cfg.policy_entropy_weight * policy_entropy_loss + ) + weighted_total_loss = (weights * loss).mean() + + gradient_scale = 1 / self._cfg.num_unroll_steps + weighted_total_loss.register_hook(lambda grad: grad * gradient_scale) + self._optimizer.zero_grad() + weighted_total_loss.backward() + + # ============= for analysis ============= + if self._cfg.analysis_sim_norm: + del self.l2_norm_before + del self.l2_norm_after + del self.grad_norm_before + del self.grad_norm_after + self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze() + self._target_model.encoder_hook.clear_data() + # ============= for analysis ============= + + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) + total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(), + self._cfg.grad_clip_value) + self._optimizer.step() + if self._cfg.piecewise_decay_lr_scheduler: + self.lr_scheduler.step() + + # ============================================================== + # the core target model update step. 
+ # ============================================================== + self._target_model.update(self._learn_model.state_dict()) + if self._cfg.use_rnd_model: + self._target_model_for_intrinsic_reward.update(self._learn_model.state_dict()) + + if self._cfg.monitor_extra_statistics: + predicted_rewards = torch.stack(predicted_rewards).transpose(1, 0).squeeze(-1) + predicted_rewards = predicted_rewards.reshape(-1).unsqueeze(-1) + + return_log_dict = { + 'collect_mcts_temperature': self._collect_mcts_temperature, + 'collect_epsilon': self.collect_epsilon, + 'cur_lr': self._optimizer.param_groups[0]['lr'], + 'weighted_total_loss': weighted_total_loss.item(), + 'total_loss': loss.mean().item(), + 'policy_loss': policy_loss.mean().item(), + 'policy_entropy': - policy_entropy_loss.mean().item() / (self._cfg.num_unroll_steps + 1), + 'target_policy_entropy': target_policy_entropy.item() / (self._cfg.num_unroll_steps + 1), + 'reward_loss': reward_loss.mean().item(), + 'value_loss': value_loss.mean().item(), + 'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps, + 'target_reward': target_reward.mean().item(), + 'target_value': target_value.mean().item(), + 'transformed_target_reward': transformed_target_reward.mean().item(), + 'transformed_target_value': transformed_target_value.mean().item(), + 'predicted_rewards': predicted_rewards.mean().item(), + 'predicted_values': predicted_values.mean().item(), + 'total_grad_norm_before_clip': total_grad_norm_before_clip.item(), + # ============================================================== + # priority related + # ============================================================== + 'value_priority': value_priority.mean().item(), + 'value_priority_orig': value_priority, # torch.tensor compatible with ddp settings + + 'analysis/dormant_ratio_encoder': self.dormant_ratio_encoder, + 'analysis/dormant_ratio_dynamics': self.dormant_ratio_dynamics, + 'analysis/latent_state_l2_norms': latent_state_l2_norms.item(), + 'analysis/l2_norm_before': self.l2_norm_before, + 'analysis/l2_norm_after': self.l2_norm_after, + 'analysis/grad_norm_before': self.grad_norm_before, + 'analysis/grad_norm_after': self.grad_norm_after, + } + + # ["harmony_dynamics", "harmony_policy", "harmony_value", "harmony_reward", "harmony_entropy"] + if self._cfg.model.harmony_balance: + harmony_dict = { + "harmony_dynamics": self.harmony_dynamics.item(), + "harmony_dynamics_exp_recip": (1 / torch.exp(self.harmony_dynamics)).item(), + "harmony_policy": self.harmony_policy.item(), + "harmony_policy_exp_recip": (1 / torch.exp(self.harmony_policy)).item(), + "harmony_value": self.harmony_value.item(), + "harmony_value_exp_recip": (1 / torch.exp(self.harmony_value)).item(), + "harmony_reward": self.harmony_reward.item(), + "harmony_reward_exp_recip": (1 / torch.exp(self.harmony_reward)).item(), + "harmony_entropy": self.harmony_entropy.item(), + "harmony_entropy_exp_recip": (1 / torch.exp(self.harmony_entropy)).item(), + } + return_log_dict.update(harmony_dict) + + if self._cfg.use_wandb: + wandb.log({'learner_step/' + k: v for k, v in return_log_dict.items()}, step=self.env_step) + wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step) + + return return_log_dict + + def _init_collect(self) -> None: + """ + Overview: + Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. 
+ """ + self._collect_model = self._model + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(self._cfg) + else: + self._mcts_collect = MCTSPtree(self._cfg) + self._collect_mcts_temperature = 1. + self.collect_epsilon = 0.0 + self.collector_env_num = self._cfg.collector_env_num + if self._cfg.model.model_type in ["conv_context"]: + self.batch_obs_history_collect = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], self._cfg.model.observation_shape[1],self._cfg.model.observation_shape[2]]).to(self._cfg.device) + self.last_batch_action_collect = [-1 for i in range(self.collector_env_num)] + if self._cfg.model.model_type in [ "conv_history"]: + self.batch_obs_history_collect = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0]*self._cfg.model.history_length, self._cfg.model.observation_shape[1],self._cfg.model.observation_shape[2]]).to(self._cfg.device) + self.batch_obs_with_history_ready_collect = self.batch_obs_history_collect + self.last_batch_action_collect = [-1 for i in range(self.collector_env_num)] + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id: np.array = None, + ) -> Dict: + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - epsilon (:obj:`float`): The epsilon of the eps greedy exploration. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - epsilon: :math:`(1, )`. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. 
+ """ + self._collect_model.eval() + self._collect_mcts_temperature = temperature + self.collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + # if active_collect_env_num < self.collector_env_num: + # print(f"active_collect_env_num:{active_collect_env_num}") + # import ipdb;ipdb.set_trace() + + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + if self._cfg.model.model_type in ["conv_context", "conv_history"]: + + ready_env_ids = sorted(ready_env_id) + # 假设 data 的顺序与 ready_env_ids 对应,即 data[i] 为环境 ready_env_ids[i] 最新的观测。 + for idx, env_id in enumerate(ready_env_ids): + # self.last_batch_obs[env_id]: shape [total_channels, H, W] + # data[idx]: shape [num_obs_channels, H, W] + # 拼接后通道数为 total_channels + num_obs_channels + combined_obs = torch.cat([self.batch_obs_history_collect[env_id], data[idx]], dim=0) + # 仅保留最新的 total_channels 个通道 + self.batch_obs_history_collect[env_id] = combined_obs[-self.history_channels:] + # 从全局历史张量中取出当前 ready 环境对应的更新后的观测 + self.batch_obs_with_history_ready_collect = self.batch_obs_history_collect[ready_env_ids] + + if self._cfg.model.model_type in ["conv", "mlp"]: + network_output = self._collect_model.initial_inference(data) + elif self._cfg.model.model_type in ["conv_context", "conv_history"]: + network_output = self._collect_model.initial_inference(self.batch_obs_with_history_ready_collect, self.last_batch_action_collect, + data) + + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + if not self._cfg.collect_with_pure_policy: + # the only difference between collect and eval is the dirichlet noise + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + # if len(reward_roots) != len(policy_logits): + # import ipdb;ipdb.set_trace() + # if len(reward_roots) != len(noises): + # import ipdb;ipdb.set_trace() + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + if self._cfg.eps.eps_greedy_exploration_in_collect: + # eps greedy collect + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self.collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # normal collect + # 
NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + if self._cfg.model.model_type in ["conv_context", "conv_history"]: + batch_action.append(action) + + if self._cfg.model.model_type in ["conv_context", "conv_history"]: + self.last_batch_action_collect = batch_action + + else: + for i, env_id in enumerate(ready_env_id): + policy_values = torch.softmax(torch.tensor([policy_logits[i][a] for a in legal_actions[i]]), + dim=0).tolist() + policy_values = policy_values / np.sum(policy_values) + action_index_in_legal_action_set = np.random.choice(len(legal_actions[i]), p=policy_values) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[env_id] = { + 'action': action, + 'searched_value': pred_values[i], + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + + return output + + def _init_eval(self) -> None: + """ + Overview: + Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. + """ + self._eval_model = self._model + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(self._cfg) + else: + self._mcts_eval = MCTSPtree(self._cfg) + + self.evaluator_env_num = self._cfg.evaluator_env_num + + if self._cfg.model.model_type in ["conv_context"]: + self.batch_obs_history_eval = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0], self._cfg.model.observation_shape[1],self._cfg.model.observation_shape[2]]).to(self._cfg.device) + self.last_batch_action_eval = [-1 for _ in range(self.evaluator_env_num)] + if self._cfg.model.model_type in [ "conv_history"]: + self.batch_obs_history_eval = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0]*self._cfg.model.history_length, self._cfg.model.observation_shape[1],self._cfg.model.observation_shape[2]]).to(self._cfg.device) + self.last_batch_obs_ready_eval = self.batch_obs_history_eval + self.last_batch_action_eval = [-1 for i in range(self.evaluator_env_num)] + + num_obs_channels = self._cfg.model.observation_shape[0] + self.history_length = self._cfg.model.history_length + self.history_channels = num_obs_channels * self.history_length + + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, ) -> Dict: + """ + Overview: + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. + Choosing the action with the highest value (argmax) rather than sampling during the eval mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. 
+            - ready_env_id (:obj:`list`): The id of the env that is ready to collect.
+        Shape:
+            - data (:obj:`torch.Tensor`):
+                - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \
+                    S is the number of stacked frames, H is the height of the image, W is the width of the image.
+                - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size.
+            - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env.
+            - to_play: :math:`(N, 1)`, where N is the number of collect_env.
+            - ready_env_id: None
+        Returns:
+            - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \
+                ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``.
+        """
+        self._eval_model.eval()
+        active_eval_env_num = data.shape[0]
+        if ready_env_id is None:
+            ready_env_id = np.arange(active_eval_env_num)
+        output = {i: None for i in ready_env_id}
+
+        with torch.no_grad():
+
+            if self._cfg.model.model_type in ["conv_context", "conv_history"]:
+                # First update the global observation history:
+                # for every env in ready_env_id, concatenate the latest observation ``data`` onto the stored history
+                # and keep only the channels of the last ``history_length`` time steps.
+                # To keep the per-env order consistent, sort ready_env_id first (in case it is not monotonically increasing).
+                ready_env_ids = sorted(ready_env_id)
+                # Assume data is ordered consistently with ready_env_ids, i.e. data[i] is the latest observation of env ready_env_ids[i].
+                for idx, env_id in enumerate(ready_env_ids):
+                    # self.batch_obs_history_eval[env_id]: shape [total_channels, H, W]
+                    # data[idx]: shape [num_obs_channels, H, W]
+                    # After concatenation, the channel dim is total_channels + num_obs_channels.
+                    combined_obs = torch.cat([self.batch_obs_history_eval[env_id], data[idx]], dim=0)
+
+                    # Keep only the most recent ``history_channels`` channels.
+                    self.batch_obs_history_eval[env_id] = combined_obs[-self.history_channels:]
+                # Gather the updated observations of the currently ready envs from the global history tensor.
+                self.last_batch_obs_ready_eval = self.batch_obs_history_eval[ready_env_ids]
+
+            if self._cfg.model.model_type in ["conv", "mlp"]:
+                network_output = self._eval_model.initial_inference(data)
+            elif self._cfg.model.model_type in ["conv_context", "conv_history"]:
+                # Pass the updated observations of the ready envs to initial_inference.
+                # NOTE: this assumes ``self.last_batch_action_eval`` has already been maintained (e.g. the action history recorded at the previous step).
+                network_output = self._eval_model.initial_inference(self.last_batch_obs_ready_eval, self.last_batch_action_eval, data)
+
+            latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
+
+            if not self._eval_model.training:
+                # if not in training, obtain the scalars of the value/reward
+                pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()  # shape(B, 1)
+                latent_state_roots = latent_state_roots.detach().cpu().numpy()
+                policy_logits = policy_logits.detach().cpu().numpy().tolist()  # list shape(B, A)
+
+            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)]
+            if self._cfg.mcts_ctree:
+                # cpp mcts_tree
+                roots = MCTSCtree.roots(active_eval_env_num, legal_actions)
+            else:
+                # python mcts_tree
+                roots = MCTSPtree.roots(active_eval_env_num, legal_actions)
+            roots.prepare_no_noise(reward_roots, policy_logits, to_play)
+            self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play)
+
+            # list of list, shape: ``{list: batch_size} -> {list: action_space_size}``
+            roots_visit_count_distributions = roots.get_distributions()
+            roots_values = roots.get_values()  # shape: {list: batch_size}
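+            # Illustrative sketch (not part of the original logic): the per-env history buffer updated at the
+            # top of this method behaves like a fixed-size channel FIFO. Assuming image_channel=1,
+            # frame_stack_num=1 and history_length=4 (so num_obs_channels=1 and history_channels=4):
+            #     >>> history = torch.zeros(4, 96, 96)                   # [history_channels, H, W]
+            #     >>> new_obs = torch.ones(1, 96, 96)                    # [num_obs_channels, H, W]
+            #     >>> combined = torch.cat([history, new_obs], dim=0)    # [5, 96, 96]
+            #     >>> history = combined[-4:]                            # oldest channel dropped, newest appended
+            # so after ``history_length`` steps the buffer holds the ``history_length`` most recent observations.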
+ + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than + # sampling during the evaluation phase. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the + # entire action set. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + if self._cfg.model.model_type in ["conv_context", "conv_history"]: + batch_action.append(action) + + if self._cfg.model.model_type in ["conv_context", "conv_history"]: + self.last_batch_action_eval = batch_action + + + return output + + def _reset_collect(self, data_id: Optional[List[int]] = None) -> None: + """ + Overview: + Reset the observation and action for the collector environment. + Arguments: + - data_id (`Optional[List[int]]`): List of data ids to reset (not used in this implementation). + """ + if self._cfg.model.model_type in ["conv_context", "conv_history"]: + # self.batch_obs_history_collect = initialize_zeros_batch( + # self._cfg.model.observation_shape, + # self._cfg.collector_env_num, + # self._cfg.device + # ) + self.batch_obs_history_collect = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0]*self._cfg.model.history_length, self._cfg.model.observation_shape[1],self._cfg.model.observation_shape[2]]).to(self._cfg.device) + + self.last_batch_action_collect = [-1 for _ in range(self._cfg.collector_env_num)] + else: + raise ValueError(f"Unsupported model type in collect: {self._cfg.model.model_type}") + + def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: + """ + Overview: + Reset the observation and action for the evaluator environment. + Arguments: + - data_id (:obj:`Optional[List[int]]`): List of data ids to reset (not used in this implementation). + """ + if self._cfg.model.model_type in ["conv_context", "conv_history"]: + # self.batch_obs_history_eval = initialize_zeros_batch( + # self._cfg.model.observation_shape, + # self._cfg.evaluator_env_num, + # self._cfg.device + # ) + self.batch_obs_history_eval = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0]*self._cfg.model.history_length, self._cfg.model.observation_shape[1],self._cfg.model.observation_shape[2]]).to(self._cfg.device) + self.last_batch_action_eval = [-1 for _ in range(self._cfg.evaluator_env_num)] + else: + raise ValueError(f"Unsupported model type in eval: {self._cfg.model.model_type}") + + def _get_target_obs_index_in_step_k(self, step): + """ + Overview: + Get the begin index and end index of the target obs in step k. + Arguments: + - step (:obj:`int`): The current step k. + Returns: + - beg_index (:obj:`int`): The begin index of the target obs in step k. + - end_index (:obj:`int`): The end index of the target obs in step k. 
+ Examples: + >>> self._cfg.model.model_type = 'conv' + >>> self._cfg.model.image_channel = 3 + >>> self._cfg.model.frame_stack_num = 4 + >>> self._get_target_obs_index_in_step_k(0) + >>> (0, 12) + """ + if self._cfg.model.model_type in ['conv', 'conv_context', 'conv_history']: + beg_index = self._cfg.model.image_channel * step + end_index = self._cfg.model.image_channel * (step + self._cfg.model.frame_stack_num) + elif self._cfg.model.model_type in ['mlp', 'mlp_context', 'mlp_history']: + beg_index = self._cfg.model.observation_shape * step + end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num) + return beg_index, end_index + + def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Register the variables to be monitored in learn mode. The registered variables will be logged in + tensorboard according to the return value ``_forward_learn``. + """ + return_list = [ + 'analysis/dormant_ratio_encoder', + 'analysis/dormant_ratio_dynamics', + 'analysis/latent_state_l2_norms', + 'analysis/l2_norm_before', + 'analysis/l2_norm_after', + 'analysis/grad_norm_before', + 'analysis/grad_norm_after', + + 'collect_mcts_temperature', + 'cur_lr', + 'weighted_total_loss', + 'total_loss', + 'policy_loss', + 'policy_entropy', + 'target_policy_entropy', + 'reward_loss', + 'value_loss', + 'consistency_loss', + 'value_priority', + 'target_reward', + 'target_value', + 'predicted_rewards', + 'predicted_values', + 'transformed_target_reward', + 'transformed_target_value', + 'total_grad_norm_before_clip', + ] + # ["harmony_dynamics", "harmony_policy", "harmony_value", "harmony_reward", "harmony_entropy"] + if self._cfg.model.harmony_balance: + harmony_list = [ + 'harmony_dynamics', 'harmony_dynamics_exp_recip', + 'harmony_policy', 'harmony_policy_exp_recip', + 'harmony_value', 'harmony_value_exp_recip', + 'harmony_reward', 'harmony_reward_exp_recip', + 'harmony_entropy', 'harmony_entropy_exp_recip', + ] + return_list.extend(harmony_list) + return return_list + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model, target_model and optimizer. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. + """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer': self._optimizer.state_dict(), + } + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. + """ + self._learn_model.load_state_dict(state_dict['model']) + self._target_model.load_state_dict(state_dict['target_model']) + self._optimizer.load_state_dict(state_dict['optimizer']) + + def __del__(self): + if self._cfg.model.analysis_sim_norm: + # Remove hooks after training. 
+ self._collect_model.encoder_hook.remove_hooks() + self._target_model.encoder_hook.remove_hooks() + + def _process_transition(self, obs, policy_output, timestep): + # be compatible with DI-engine Policy class + pass + + def _get_train_sample(self, data): + # be compatible with DI-engine Policy class + pass + diff --git a/lzero/policy/muzero_history_bkp.py b/lzero/policy/muzero_history_bkp.py new file mode 100644 index 000000000..312e7f82b --- /dev/null +++ b/lzero/policy/muzero_history_bkp.py @@ -0,0 +1,1131 @@ +import copy +from typing import List, Dict, Any, Tuple, Union, Optional + +import numpy as np +import torch +import torch.optim as optim +import wandb +from ding.model import model_wrap +from ding.policy.base_policy import Policy +from ding.torch_utils import to_tensor +from ding.utils import POLICY_REGISTRY +from torch.nn import L1Loss + +from lzero.entry.utils import initialize_zeros_batch +from lzero.mcts import MuZeroMCTSCtree as MCTSCtree +from lzero.mcts import MuZeroMCTSPtree as MCTSPtree +from lzero.model import ImageTransforms +from lzero.model.utils import cal_dormant_ratio +from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \ + prepare_obs_history, configure_optimizers + + +@POLICY_REGISTRY.register('muzero_history') +class MuZeroHistoryPolicy(Policy): + """ + Overview: + if self._cfg.model.model_type in ["conv", "mlp"]: + The policy class for MuZero. + if self._cfg.model.model_type == ["conv_context", "mlp_context"]: + The policy class for MuZero w/ Context, a variant of MuZero. + This variant retains the same training settings as MuZero but diverges during inference + by employing a k-step recursively predicted latent representation at the root node, + proposed in the UniZero paper https://arxiv.org/abs/2406.10667. + """ + + # The default_config for MuZero policy. + config = dict( + model=dict( + # (str) The model type. For 1-dimensional vector obs, we use mlp model. For the image obs, we use conv model. + model_type='conv', # options={'mlp', 'conv'} + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, + # (tuple) The stacked obs shape. + # observation_shape=(1, 96, 96), # if frame_stack_num=1 + observation_shape=(4, 96, 96), # if frame_stack_num=4 + # (bool) Whether to use the self-supervised learning loss. + self_supervised_learning_loss=False, + # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. + # reference: http://proceedings.mlr.press/v80/imani18a/imani18a.pdf, https://arxiv.org/abs/2403.03950 + categorical_distribution=True, + # (int) The image channel in image observation. + image_channel=1, + # (int) The number of frames to stack together. + frame_stack_num=1, + # (int) The number of res blocks in MuZero model. + num_res_blocks=1, + # (int) The number of channels of hidden states in MuZero model. + num_channels=64, + # (int) The scale of supports used in categorical distribution. + # This variable is only effective when ``categorical_distribution=True``. + support_scale=300, + # (bool) whether to learn bias in the last linear layer in value and policy head. + bias=True, + # (str) The type of action encoding. Options are ['one_hot', 'not_one_hot']. Default to 'one_hot'. + discrete_action_encoding_type='one_hot', + # (bool) whether to use res connection in dynamics. 
+ res_connection_in_dynamics=True, + # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. + norm_type='BN', + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to analyze dormant ratio. + analysis_dormant_ratio=False, + # (bool) Whether to use HarmonyDream to balance weights between different losses. Default to False. + # More details can be found in https://arxiv.org/abs/2310.00344. + harmony_balance=False + ), + # ****** common ****** + # (bool) Whether to use wandb to log the training process. + use_wandb=False, + # (bool) whether to use rnd model. + use_rnd_model=False, + # (bool) Whether to use multi-gpu training. + multi_gpu=False, + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) + # this variable is used in ``collector``. + sampled_algo=False, + # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero) + gumbel_algo=False, + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. + mcts_ctree=True, + # (bool) Whether to use cuda for network. + cuda=True, + # (int) The number of environments used in collecting data. + collector_env_num=8, + # (int) The number of environments used in evaluating policy. + evaluator_env_num=3, + # (str) The type of environment. Options are ['not_board_games', 'board_games']. + env_type='not_board_games', + # (str) The type of action space. Options are ['fixed_action_space', 'varied_action_space']. + action_type='fixed_action_space', + # (str) The type of battle mode. Options are ['play_with_bot_mode', 'self_play_mode']. + battle_mode='play_with_bot_mode', + # (bool) Whether to monitor extra statistics in tensorboard. + monitor_extra_statistics=True, + # (int) The transition number of one ``GameSegment``. + game_segment_length=200, + # (bool): Indicates whether to perform an offline evaluation of the checkpoint (ckpt). + # If set to True, the checkpoint will be evaluated after the training process is complete. + # IMPORTANT: Setting eval_offline to True requires configuring the saving of checkpoints to align with the evaluation frequency. + # This is done by setting the parameter learn.learner.hook.save_ckpt_after_iter to the same value as eval_freq in the train_muzero.py automatically. + eval_offline=False, + # (bool) Whether to calculate the dormant ratio. + cal_dormant_ratio=False, + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to analyze dormant ratio. + analysis_dormant_ratio=False, + + # ****** observation ****** + # (bool) Whether to transform image to string to save memory. + transform2string=False, + # (bool) Whether to use gray scale image. + gray_scale=False, + # (bool) Whether to use data augmentation. + use_augmentation=False, + # (list) The style of augmentation. + augmentation=['shift', 'intensity'], + + # ******* learn ****** + # (bool) Whether to ignore the done flag in the training data. Typically, this value is set to False. + # However, for some environments with a fixed episode length, to ensure the accuracy of Q-value calculations, + # we should set it to True to avoid the influence of the done flag. + ignore_done=False, + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... 
+ # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. + # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. + update_per_collect=None, + # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. + replay_ratio=0.25, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. ['SGD', 'Adam'] + optim_type='SGD', + # (float) Learning rate for training policy network. Initial lr for manually decay schedule. + learning_rate=0.2, + # (int) Frequency of target network update. + target_update_freq=100, + # (int) Frequency of target network update. + target_update_freq_for_intrinsic_reward=1000, + # (float) Weight decay for training policy network. + weight_decay=1e-4, + # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). + momentum=0.9, + # (float) The maximum constraint value of gradient norm clipping. + grad_clip_value=10, + # (int) The number of episodes in each collecting stage when use muzero_collector. + n_episode=8, + # (int) The number of num_segments in each collecting stage when use muzero_segment_collector. + num_segments=8, + # (int) the number of simulations in MCTS. + num_simulations=50, + # (float) Discount factor (gamma) for returns. + discount_factor=0.997, + # (int) The number of steps for calculating target q_value. + td_steps=5, + # (int) The number of unroll steps in dynamics network. + num_unroll_steps=5, + # (float) The weight of reward loss. + reward_loss_weight=1, + # (float) The weight of value loss. + value_loss_weight=0.25, + # (float) The weight of policy loss. + policy_loss_weight=1, + # (float) The weight of policy entropy loss. + policy_entropy_weight=0, + # (float) The weight of ssl (self-supervised learning) loss. + ssl_loss_weight=0, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. lr: 0.2 -> 0.02 -> 0.002 + piecewise_decay_lr_scheduler=True, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e4), + # (bool) Whether to use manually decayed temperature. + manual_temperature_decay=False, + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(1e5), + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. + use_ture_chance_label_in_chance_encoder=False, + # (bool) Whether to add noise to roots during reanalyze process. + reanalyze_noise=True, + # (bool) Whether to reuse the root value between batch searches. + reuse_search=False, + # (bool) whether to use the pure policy to collect data. If False, use the MCTS guided with policy. + collect_with_pure_policy=False, + + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=False, + # (float) The degree of prioritization to use. 
A value of 0 means no prioritization, + # while a value of 1 means full prioritization. + priority_prob_alpha=0.6, + # (float) The degree of correction to use. A value of 0 means no correction, + # while a value of 1 means full correction. + priority_prob_beta=0.4, + + # ****** UCB ****** + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + + # ****** Explore by random collect ****** + # (int) The number of episodes to collect data randomly before training. + random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + # (bool) Whether to use eps greedy exploration in collecting data. + eps_greedy_exploration_in_collect=False, + # (str) The type of decaying epsilon. Options are 'linear', 'exp'. + type='linear', + # (float) The start value of eps. + start=1., + # (float) The end value of eps. + end=0.05, + # (int) The decay steps from start to end eps. + decay=int(1e5), + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default model setting for demonstration. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. + - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. + - import_names (:obj:`List[str]`): The model class path list used in this algorithm. + .. note:: + The user can define and use customized network model but must obey the same interface definition indicated \ + by import_names path. For MuZero, ``lzero.model.muzero_model.MuZeroModel`` + """ + if self._cfg.model.model_type in ["conv_history"]: + return 'MuZeroHistoryModel', ['lzero.model.muzero_model_history'] + elif self._cfg.model.model_type == "mlp": + return 'MuZeroModelMLP', ['lzero.model.muzero_model_mlp'] + elif self._cfg.model.model_type in ["conv_context"]: + return 'MuZeroContextModel', ['lzero.model.muzero_context_model'] + else: + raise ValueError("model type {} is not supported".format(self._cfg.model.model_type)) + + def set_train_iter_env_step(self, train_iter, env_step) -> None: + """ + Overview: + Set the train_iter and env_step for the policy. + Arguments: + - train_iter (:obj:`int`): The train_iter for the policy. + - env_step (:obj:`int`): The env_step for the policy. + """ + self.train_iter = train_iter + self.env_step = env_step + + def _init_learn(self) -> None: + """ + Overview: + Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. + """ + assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type + # NOTE: in board_games, for fixed lr 0.003, 'Adam' is better than 'SGD'. 
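+        # Illustrative note (an assumption based on the default config above, not extra logic): with the SGD
+        # branch below, learning_rate=0.2, piecewise_decay_lr_scheduler=True and
+        # threshold_training_steps_for_final_lr=int(5e4), the LambdaLR constructed below yields
+        #     lr = 0.2    for train_iter <  25_000
+        #     lr = 0.02   for 25_000 <= train_iter < 50_000
+        #     lr = 0.002  for train_iter >= 50_000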
+ if self._cfg.optim_type == 'SGD': + self._optimizer = optim.SGD( + self._model.parameters(), + lr=self._cfg.learning_rate, + momentum=self._cfg.momentum, + weight_decay=self._cfg.weight_decay, + ) + elif self._cfg.optim_type == 'Adam': + self._optimizer = optim.Adam( + self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay + ) + elif self._cfg.optim_type == 'AdamW': + self._optimizer = configure_optimizers(model=self._model, weight_decay=self._cfg.weight_decay, + learning_rate=self._cfg.learning_rate, device_type=self._cfg.device) + + if self._cfg.piecewise_decay_lr_scheduler: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. + lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda) + + # use model_wrapper for specialized demands of different modes + self._target_model = copy.deepcopy(self._model) + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq} + ) + self._learn_model = self._model + + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + + # ============================================================== + # harmonydream (learnable weights for different losses) + # ============================================================== + if self._cfg.model.harmony_balance: + # List of parameter names + harmony_names = ["harmony_dynamics", "harmony_policy", "harmony_value", "harmony_reward", "harmony_entropy"] + # Initialize and name each parameter + for name in harmony_names: + param = torch.nn.Parameter(-torch.log(torch.tensor(1.0))) + setattr(self, name, param) + + if self._cfg.use_rnd_model: + if self._cfg.target_model_for_intrinsic_reward_update_type == 'assign': + self._target_model_for_intrinsic_reward = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq_for_intrinsic_reward} + ) + elif self._cfg.target_model_for_intrinsic_reward_update_type == 'momentum': + self._target_model_for_intrinsic_reward = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta_for_intrinsic_reward} + ) + + # ========= logging for analysis ========= + self.l2_norm_before = 0. + self.l2_norm_after = 0. + self.grad_norm_before = 0. + self.grad_norm_after = 0. + self.dormant_ratio_encoder = 0. + self.dormant_ratio_dynamics = 0. + + if self._cfg.use_wandb: + # TODO: add the model to wandb + wandb.watch(self._learn_model.representation_network, log="all") + + def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward function for learning policy in learn mode, which is the core of the learning process. 
+ The data is sampled from replay buffer. + The loss is calculated by the loss function and the loss is backpropagated to update the model. + Arguments: + - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. + The first tensor is the current_batch, the second tensor is the target_batch. + Returns: + - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ + current learning loss and learning statistics. + """ + self._learn_model.train() + self._target_model.train() + if self._cfg.use_rnd_model: + self._target_model_for_intrinsic_reward.train() + + current_batch, target_batch = data + obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch + target_reward, target_value, target_policy = target_batch + + # import ipdb;ipdb.set_trace() + # TODO + obs_batch, obs_target_batch = prepare_obs_history(obs_batch_ori, self._cfg) + + # do augmentations + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # shape: (batch_size, num_unroll_steps, action_dim) + # NOTE: .long() is only for discrete action space. + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() + data_list = [mask_batch, target_reward, + target_value, target_policy, weights + ] + [mask_batch, target_reward, target_value, target_policy, + weights] = to_torch_float_tensor(data_list, self._cfg.device) + + target_reward = target_reward.view(self._cfg.batch_size, -1) + target_value = target_value.view(self._cfg.batch_size, -1) + + assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) + + # ``scalar_transform`` to transform the original value to the scaled value, + # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + + # transform a scalar to its categorical_distribution. After this transformation, each scalar is + # represented as the linear combination of its two adjacent supports. + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # ============================================================== + # the core initial_inference in MuZero policy. + # ============================================================== + network_output = self._learn_model.initial_inference(obs_batch) + + # value_prefix shape: (batch_size, 10), the ``value_prefix`` at the first step is zero padding. + latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) + + # ========= logging for analysis ========= + # calculate dormant ratio of encoder + if self._cfg.cal_dormant_ratio: + self.dormant_ratio_encoder = cal_dormant_ratio(self._learn_model.representation_network, obs_batch.detach(), + percentage=self._cfg.dormant_threshold) + # calculate L2 norm of latent state + latent_state_l2_norms = torch.norm(latent_state.view(latent_state.shape[0], -1), p=2, dim=1).mean() + # ========= logging for analysis =============== + + # transform the scaled value or its categorical representation to its original value, + # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. 
+        original_value = self.inverse_scalar_transform_handle(value)
+
+        # Note: The following lines are just for debugging.
+        predicted_rewards = []
+        if self._cfg.monitor_extra_statistics:
+            predicted_values, predicted_policies = original_value.detach().cpu(), torch.softmax(
+                policy_logits, dim=1
+            ).detach().cpu()
+
+        # calculate the new priorities for each transition.
+        value_priority = L1Loss(reduction='none')(original_value.squeeze(-1), target_value[:, 0])
+        value_priority = value_priority.data.cpu().numpy() + 1e-6
+
+        # ==============================================================
+        # calculate policy and value loss for the first step.
+        # ==============================================================
+        policy_loss = cross_entropy_loss(policy_logits, target_policy[:, 0])
+        value_loss = cross_entropy_loss(value, target_value_categorical[:, 0])
+
+        prob = torch.softmax(policy_logits, dim=-1)
+        entropy = -(prob * prob.log()).sum(-1)
+        policy_entropy_loss = -entropy
+
+        reward_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device)
+        consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device)
+        target_policy_entropy = 0
+
+        # ==============================================================
+        # the core recurrent_inference in MuZero policy.
+        # ==============================================================
+        for step_k in range(self._cfg.num_unroll_steps):
+            # unroll with the dynamics function: predict the next ``latent_state``, ``reward``,
+            # given current ``latent_state`` and ``action``.
+            # And then predict policy_logits and value with the prediction function.
+            network_output = self._learn_model.recurrent_inference(latent_state, action_batch[:, step_k+self.history_length-1])
+            latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output)
+
+            # ========= logging for analysis ===============
+            if step_k == self._cfg.num_unroll_steps - 1 and self._cfg.cal_dormant_ratio:
+                # calculate dormant ratio of encoder
+                action_tmp = action_batch[:, step_k+self.history_length-1]
+                if len(action_tmp.shape) == 1:
+                    # ensure shape (batch_size, 1) so that the scatter_ below works.
+                    action_tmp = action_tmp.unsqueeze(-1)
+                # transform action to one-hot encoding.
+                # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4)
+                action_one_hot = torch.zeros(action_tmp.shape[0], policy_logits.shape[-1], device=action_tmp.device)
+                # transform action to torch.int64
+                action_tmp = action_tmp.long()
+                action_one_hot.scatter_(1, action_tmp, 1)
+                action_encoding_tmp = action_one_hot.unsqueeze(-1).unsqueeze(-1)
+                action_encoding = action_encoding_tmp.expand(
+                    latent_state.shape[0], policy_logits.shape[-1], latent_state.shape[2], latent_state.shape[3]
+                )
+                state_action_encoding = torch.cat((latent_state, action_encoding), dim=1)
+                self.dormant_ratio_dynamics = cal_dormant_ratio(self._learn_model.dynamics_network,
+                                                                state_action_encoding.detach(),
+                                                                percentage=self._cfg.dormant_threshold)
+            # ========= logging for analysis ===============
+
+            # transform the scaled value or its categorical representation to its original value,
+            # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf.
+            original_value = self.inverse_scalar_transform_handle(value)
+
+            if self._cfg.model.self_supervised_learning_loss:
+                # ==============================================================
+                # calculate consistency loss for the next ``num_unroll_steps`` unroll steps.
+                # ==============================================================
+                if self._cfg.ssl_loss_weight > 0:
+                    # obtain the oracle latent states from representation function.
+ beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) + network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) + + latent_state = to_tensor(latent_state) + representation_state = to_tensor(network_output.latent_state) + + # NOTE: no grad for the representation_state branch + dynamic_proj = self._learn_model.project(latent_state, with_grad=True) + observation_proj = self._learn_model.project(representation_state, with_grad=False) + temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] + consistency_loss += temp_loss + + # NOTE: the target policy, target_value_categorical, target_reward_categorical is calculated in + # game buffer now. + # ============================================================== + # calculate policy loss for the next ``num_unroll_steps`` unroll steps. + # NOTE: the +=. + # ============================================================== + policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_k + 1]) + + # Here we take the hypothetical step k = step_k + 1 + prob = torch.softmax(policy_logits, dim=-1) + entropy = -(prob * prob.log()).sum(-1) + policy_entropy_loss += -entropy + + target_normalized_visit_count = target_policy[:, step_k + 1] + + # ******* NOTE: target_policy_entropy is only for debug. ****** + non_masked_indices = torch.nonzero(mask_batch[:, step_k + 1]).squeeze(-1) + # Check if there are any unmasked rows + if len(non_masked_indices) > 0: + target_normalized_visit_count_masked = torch.index_select( + target_normalized_visit_count, 0, non_masked_indices + ) + target_policy_entropy += -((target_normalized_visit_count_masked + 1e-6) * ( + target_normalized_visit_count_masked + 1e-6).log()).sum(-1).mean() + else: + # Set target_policy_entropy to log(|A|) if all rows are masked + target_policy_entropy += torch.log(torch.tensor(target_normalized_visit_count.shape[-1])) + + value_loss += cross_entropy_loss(value, target_value_categorical[:, step_k + 1]) + reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_k]) + + if self._cfg.monitor_extra_statistics: + original_rewards = self.inverse_scalar_transform_handle(reward) + original_rewards_cpu = original_rewards.detach().cpu() + + predicted_values = torch.cat( + (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) + ) + predicted_rewards.append(original_rewards_cpu) + predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu())) + + # ============================================================== + # the core learn model update step. + # ============================================================== + # weighted loss with masks (some invalid states which are out of trajectory.) + # Nan appear when consistency loss or policy entropy loss uses harmony parameter as coefficient. 
+ + # Please refer to https://github.com/thuml/HarmonyDream/blob/main/wmlib-torch/wmlib/agents/dreamerv2.py#L161 + # ["harmony_dynamics", "harmony_policy", "harmony_value", "harmony_reward", "harmony_entropy"] + if self._cfg.model.harmony_balance: + loss = ( + (consistency_loss.mean() * self._cfg.ssl_loss_weight) + + (policy_loss.mean() / torch.exp(self.harmony_policy)) + + (value_loss.mean() / torch.exp(self.harmony_value)) + + (reward_loss.mean() / torch.exp(self.harmony_reward)) + ) + weighted_total_loss = loss.mean() + weighted_total_loss += ( + torch.log(torch.exp(self.harmony_policy) + 1) + + torch.log(torch.exp(self.harmony_value) + 1) + + torch.log(torch.exp(self.harmony_reward) + 1) + ) + else: + loss = ( + self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss + + self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * reward_loss + + self._cfg.policy_entropy_weight * policy_entropy_loss + ) + weighted_total_loss = (weights * loss).mean() + + gradient_scale = 1 / self._cfg.num_unroll_steps + weighted_total_loss.register_hook(lambda grad: grad * gradient_scale) + self._optimizer.zero_grad() + weighted_total_loss.backward() + + # ============= for analysis ============= + if self._cfg.analysis_sim_norm: + del self.l2_norm_before + del self.l2_norm_after + del self.grad_norm_before + del self.grad_norm_after + self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze() + self._target_model.encoder_hook.clear_data() + # ============= for analysis ============= + + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) + total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(), + self._cfg.grad_clip_value) + self._optimizer.step() + if self._cfg.piecewise_decay_lr_scheduler: + self.lr_scheduler.step() + + # ============================================================== + # the core target model update step. 
+ # ============================================================== + self._target_model.update(self._learn_model.state_dict()) + if self._cfg.use_rnd_model: + self._target_model_for_intrinsic_reward.update(self._learn_model.state_dict()) + + if self._cfg.monitor_extra_statistics: + predicted_rewards = torch.stack(predicted_rewards).transpose(1, 0).squeeze(-1) + predicted_rewards = predicted_rewards.reshape(-1).unsqueeze(-1) + + return_log_dict = { + 'collect_mcts_temperature': self._collect_mcts_temperature, + 'collect_epsilon': self.collect_epsilon, + 'cur_lr': self._optimizer.param_groups[0]['lr'], + 'weighted_total_loss': weighted_total_loss.item(), + 'total_loss': loss.mean().item(), + 'policy_loss': policy_loss.mean().item(), + 'policy_entropy': - policy_entropy_loss.mean().item() / (self._cfg.num_unroll_steps + 1), + 'target_policy_entropy': target_policy_entropy.item() / (self._cfg.num_unroll_steps + 1), + 'reward_loss': reward_loss.mean().item(), + 'value_loss': value_loss.mean().item(), + 'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps, + 'target_reward': target_reward.mean().item(), + 'target_value': target_value.mean().item(), + 'transformed_target_reward': transformed_target_reward.mean().item(), + 'transformed_target_value': transformed_target_value.mean().item(), + 'predicted_rewards': predicted_rewards.mean().item(), + 'predicted_values': predicted_values.mean().item(), + 'total_grad_norm_before_clip': total_grad_norm_before_clip.item(), + # ============================================================== + # priority related + # ============================================================== + 'value_priority': value_priority.mean().item(), + 'value_priority_orig': value_priority, # torch.tensor compatible with ddp settings + + 'analysis/dormant_ratio_encoder': self.dormant_ratio_encoder, + 'analysis/dormant_ratio_dynamics': self.dormant_ratio_dynamics, + 'analysis/latent_state_l2_norms': latent_state_l2_norms.item(), + 'analysis/l2_norm_before': self.l2_norm_before, + 'analysis/l2_norm_after': self.l2_norm_after, + 'analysis/grad_norm_before': self.grad_norm_before, + 'analysis/grad_norm_after': self.grad_norm_after, + } + + # ["harmony_dynamics", "harmony_policy", "harmony_value", "harmony_reward", "harmony_entropy"] + if self._cfg.model.harmony_balance: + harmony_dict = { + "harmony_dynamics": self.harmony_dynamics.item(), + "harmony_dynamics_exp_recip": (1 / torch.exp(self.harmony_dynamics)).item(), + "harmony_policy": self.harmony_policy.item(), + "harmony_policy_exp_recip": (1 / torch.exp(self.harmony_policy)).item(), + "harmony_value": self.harmony_value.item(), + "harmony_value_exp_recip": (1 / torch.exp(self.harmony_value)).item(), + "harmony_reward": self.harmony_reward.item(), + "harmony_reward_exp_recip": (1 / torch.exp(self.harmony_reward)).item(), + "harmony_entropy": self.harmony_entropy.item(), + "harmony_entropy_exp_recip": (1 / torch.exp(self.harmony_entropy)).item(), + } + return_log_dict.update(harmony_dict) + + if self._cfg.use_wandb: + wandb.log({'learner_step/' + k: v for k, v in return_log_dict.items()}, step=self.env_step) + wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step) + + return return_log_dict + + def _init_collect(self) -> None: + """ + Overview: + Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. 
+ """ + self._collect_model = self._model + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(self._cfg) + else: + self._mcts_collect = MCTSPtree(self._cfg) + self._collect_mcts_temperature = 1. + self.collect_epsilon = 0.0 + self.collector_env_num = self._cfg.collector_env_num + if self._cfg.model.model_type in ["conv_context"]: + self.last_batch_obs_collect = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], self._cfg.model.observation_shape[1],self._cfg.model.observation_shape[2]]).to(self._cfg.device) + self.last_batch_action_collect = [-1 for i in range(self.collector_env_num)] + if self._cfg.model.model_type in [ "conv_history"]: + self.last_batch_obs_collect = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0]*self._cfg.model.history_length, self._cfg.model.observation_shape[1],self._cfg.model.observation_shape[2]]).to(self._cfg.device) + self.last_batch_obs_ready_collect = self.last_batch_obs_collect + self.last_batch_action_collect = [-1 for i in range(self.collector_env_num)] + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id: np.array = None, + ) -> Dict: + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - epsilon (:obj:`float`): The epsilon of the eps greedy exploration. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - epsilon: :math:`(1, )`. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. 
+ """ + self._collect_model.eval() + self._collect_mcts_temperature = temperature + self.collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if active_collect_env_num < self.collector_env_num: + print(f"active_collect_env_num:{active_collect_env_num}") + # import ipdb;ipdb.set_trace() + + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + if self._cfg.model.model_type in ["conv", "mlp"]: + network_output = self._collect_model.initial_inference(data) + elif self._cfg.model.model_type in ["conv_context", "conv_history"]: + network_output = self._collect_model.initial_inference(self.last_batch_obs_ready_collect, self.last_batch_action_collect, + data) + + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + if not self._cfg.collect_with_pure_policy: + # the only difference between collect and eval is the dirichlet noise + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + if len(reward_roots) != len(policy_logits): + import ipdb;ipdb.set_trace() + if len(reward_roots) != len(noises): + import ipdb;ipdb.set_trace() + + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + if self._cfg.eps.eps_greedy_exploration_in_collect: + # eps greedy collect + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self.collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # normal collect + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
+                    action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+                    output[env_id] = {
+                        'action': action,
+                        'visit_count_distributions': distributions,
+                        'visit_count_distribution_entropy': visit_count_distribution_entropy,
+                        'searched_value': value,
+                        'predicted_value': pred_values[i],
+                        'predicted_policy_logits': policy_logits[i],
+                    }
+                    if self._cfg.model.model_type in ["conv_context", "conv_history"]:
+                        batch_action.append(action)
+
+                if self._cfg.model.model_type in ["conv_context", "conv_history"]:
+                    self.last_batch_action_collect = batch_action
+
+                    # First update the global history buffer ``self.last_batch_obs_collect``:
+                    # for every environment in ``ready_env_id``, concatenate the newest observation ``data``
+                    # onto the stored history and keep only the channels of the most recent ``history_length`` steps.
+                    # Sort ``ready_env_id`` first so that the environment order is consistent
+                    # (in case it is not already in increasing order).
+                    ready_env_ids = sorted(ready_env_id)
+
+                    # Assume the order of ``data`` matches ``ready_env_ids``, i.e. ``data[idx]`` is the newest
+                    # observation of environment ``ready_env_ids[idx]``.
+                    for idx, env_id in enumerate(ready_env_ids):
+                        # self.last_batch_obs_collect[env_id]: shape [history_channels, H, W]
+                        # data[idx]: shape [num_obs_channels, H, W]
+                        # After concatenation the channel number is history_channels + num_obs_channels.
+                        combined_obs = torch.cat([self.last_batch_obs_collect[env_id], data[idx]], dim=0)
+                        # Keep only the most recent ``history_channels`` channels.
+                        self.last_batch_obs_collect[env_id] = combined_obs[-self.history_channels:]
+
+                    # Take the updated observations of the currently ready environments from the global history tensor.
+                    self.last_batch_obs_ready_collect = self.last_batch_obs_collect[ready_env_ids]
+            else:
+                for i, env_id in enumerate(ready_env_id):
+                    policy_values = torch.softmax(torch.tensor([policy_logits[i][a] for a in legal_actions[i]]),
+                                                  dim=0).tolist()
+                    policy_values = policy_values / np.sum(policy_values)
+                    action_index_in_legal_action_set = np.random.choice(len(legal_actions[i]), p=policy_values)
+                    action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+                    output[env_id] = {
+                        'action': action,
+                        'searched_value': pred_values[i],
+                        'predicted_value': pred_values[i],
+                        'predicted_policy_logits': policy_logits[i],
+                    }
+
+        return output
+
+    def _init_eval(self) -> None:
+        """
+        Overview:
+            Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils.
+ """ + self._eval_model = self._model + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(self._cfg) + else: + self._mcts_eval = MCTSPtree(self._cfg) + + self.evaluator_env_num = self._cfg.evaluator_env_num + + if self._cfg.model.model_type in ["conv_context"]: + self.last_batch_obs_eval = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0], self._cfg.model.observation_shape[1],self._cfg.model.observation_shape[2]]).to(self._cfg.device) + self.last_batch_action_eval = [-1 for _ in range(self.evaluator_env_num)] + if self._cfg.model.model_type in [ "conv_history"]: + self.last_batch_obs_eval = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0]*self._cfg.model.history_length, self._cfg.model.observation_shape[1],self._cfg.model.observation_shape[2]]).to(self._cfg.device) + self.last_batch_obs_ready_eval = self.last_batch_obs_eval + self.last_batch_action_eval = [-1 for i in range(self.evaluator_env_num)] + + num_obs_channels = self._cfg.model.observation_shape[0] + self.history_length = self._cfg.model.history_length + self.history_channels = num_obs_channels * self.history_length + + # elif self._cfg.model.model_type == 'mlp_context': + # self.last_batch_obs = torch.zeros([3, self._cfg.model.observation_shape]).to(self._cfg.device) + # self.last_batch_action = [-1 for _ in range(3)] + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, ) -> Dict: + """ + Overview: + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. + Choosing the action with the highest value (argmax) rather than sampling during the eval mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. 
+        """
+        self._eval_model.eval()
+        active_eval_env_num = data.shape[0]
+        if ready_env_id is None:
+            ready_env_id = np.arange(active_eval_env_num)
+        output = {i: None for i in ready_env_id}
+
+        with torch.no_grad():
+            if self._cfg.model.model_type in ["conv", "mlp"]:
+                network_output = self._eval_model.initial_inference(data)
+            elif self._cfg.model.model_type in ["conv_history"]:
+                # Pass the updated observations of the ready environments to ``initial_inference``.
+                # NOTE: ``self.last_batch_action_eval`` is assumed to be maintained by this policy,
+                # i.e. it holds the actions recorded at the previous step.
+                network_output = self._eval_model.initial_inference(self.last_batch_obs_ready_eval, self.last_batch_action_eval, data)
+
+            latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
+
+            if not self._eval_model.training:
+                # if not in training, obtain the scalars of the value/reward
+                pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()  # shape(B, 1)
+                latent_state_roots = latent_state_roots.detach().cpu().numpy()
+                policy_logits = policy_logits.detach().cpu().numpy().tolist()  # list shape(B, A)
+
+            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)]
+            if self._cfg.mcts_ctree:
+                # cpp mcts_tree
+                roots = MCTSCtree.roots(active_eval_env_num, legal_actions)
+            else:
+                # python mcts_tree
+                roots = MCTSPtree.roots(active_eval_env_num, legal_actions)
+            roots.prepare_no_noise(reward_roots, policy_logits, to_play)
+            self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play)
+
+            # list of list, shape: ``{list: batch_size} -> {list: action_space_size}``
+            roots_visit_count_distributions = roots.get_distributions()
+            roots_values = roots.get_values()  # shape: {list: batch_size}
+
+            batch_action = []
+            for i, env_id in enumerate(ready_env_id):
+                distributions, value = roots_visit_count_distributions[i], roots_values[i]
+                # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents
+                # the index within the legal action set, rather than the index in the entire action set.
+                # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than
+                # sampling during the evaluation phase.
+                action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
+                    distributions, temperature=1, deterministic=True
+                )
+                # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the
+                # entire action set.
+                action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+
+                output[env_id] = {
+                    'action': action,
+                    'visit_count_distributions': distributions,
+                    'visit_count_distribution_entropy': visit_count_distribution_entropy,
+                    'searched_value': value,
+                    'predicted_value': pred_values[i],
+                    'predicted_policy_logits': policy_logits[i],
+                }
+                if self._cfg.model.model_type in ["conv_history"]:
+                    batch_action.append(action)
+
+            if self._cfg.model.model_type in ["conv_history"]:
+                self.last_batch_action_eval = batch_action
+
+                # First update the global history buffer ``self.last_batch_obs_eval``:
+                # for every environment in ``ready_env_id``, concatenate the newest observation ``data``
+                # onto the stored history and keep only the channels of the most recent ``history_length`` steps.
+                # Sort ``ready_env_id`` first so that the environment order is consistent
+                # (in case it is not already in increasing order).
+                ready_env_ids = sorted(ready_env_id)
+                # Assume the order of ``data`` matches ``ready_env_ids``, i.e. ``data[idx]`` is the newest
+                # observation of environment ``ready_env_ids[idx]``.
+                for idx, env_id in enumerate(ready_env_ids):
+                    # self.last_batch_obs_eval[env_id]: shape [history_channels, H, W]
+                    # data[idx]: shape [num_obs_channels, H, W]
+                    # After concatenation the channel number is history_channels + num_obs_channels.
+                    combined_obs = torch.cat([self.last_batch_obs_eval[env_id], data[idx]], dim=0)
+                    # Keep only the most recent ``history_channels`` channels.
+                    self.last_batch_obs_eval[env_id] = combined_obs[-self.history_channels:]
+                # Take the updated observations of the currently ready environments from the global history tensor.
+                self.last_batch_obs_ready_eval = self.last_batch_obs_eval[ready_env_ids]
+
+        return output
+
+    def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
+        """
+        Overview:
+            Reset the observation and action for the collector environment.
+        Arguments:
+            - data_id (:obj:`Optional[List[int]]`): List of data ids to reset (not used in this implementation).
+        """
+        if self._cfg.model.model_type in ["conv_context", "conv_history"]:
+            self.last_batch_obs_collect = torch.zeros(
+                [
+                    self.collector_env_num,
+                    self._cfg.model.observation_shape[0] * self._cfg.model.history_length,
+                    self._cfg.model.observation_shape[1],
+                    self._cfg.model.observation_shape[2],
+                ]
+            ).to(self._cfg.device)
+            self.last_batch_action_collect = [-1 for _ in range(self._cfg.collector_env_num)]
+        else:
+            raise ValueError(f"Unsupported model type in collect: {self._cfg.model.model_type}")
+
+    def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
+        """
+        Overview:
+            Reset the observation and action for the evaluator environment.
+        Arguments:
+            - data_id (:obj:`Optional[List[int]]`): List of data ids to reset (not used in this implementation).
+        """
+        if self._cfg.model.model_type in ["conv_context", "conv_history"]:
+            self.last_batch_obs_eval = torch.zeros(
+                [
+                    self.evaluator_env_num,
+                    self._cfg.model.observation_shape[0] * self._cfg.model.history_length,
+                    self._cfg.model.observation_shape[1],
+                    self._cfg.model.observation_shape[2],
+                ]
+            ).to(self._cfg.device)
+            self.last_batch_action_eval = [-1 for _ in range(self._cfg.evaluator_env_num)]
+        else:
+            raise ValueError(f"Unsupported model type in eval: {self._cfg.model.model_type}")
+
+    def _get_target_obs_index_in_step_k(self, step):
+        """
+        Overview:
+            Get the begin index and end index of the target obs in step k.
+        Arguments:
+            - step (:obj:`int`): The current step k.
+        Returns:
+            - beg_index (:obj:`int`): The begin index of the target obs in step k.
+            - end_index (:obj:`int`): The end index of the target obs in step k.
+ Examples: + >>> self._cfg.model.model_type = 'conv' + >>> self._cfg.model.image_channel = 3 + >>> self._cfg.model.frame_stack_num = 4 + >>> self._get_target_obs_index_in_step_k(0) + >>> (0, 12) + """ + if self._cfg.model.model_type in ['conv', 'conv_context', 'conv_history']: + beg_index = self._cfg.model.image_channel * step + end_index = self._cfg.model.image_channel * (step + self._cfg.model.frame_stack_num) + elif self._cfg.model.model_type in ['mlp', 'mlp_context', 'mlp_history']: + beg_index = self._cfg.model.observation_shape * step + end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num) + return beg_index, end_index + + def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Register the variables to be monitored in learn mode. The registered variables will be logged in + tensorboard according to the return value ``_forward_learn``. + """ + return_list = [ + 'analysis/dormant_ratio_encoder', + 'analysis/dormant_ratio_dynamics', + 'analysis/latent_state_l2_norms', + 'analysis/l2_norm_before', + 'analysis/l2_norm_after', + 'analysis/grad_norm_before', + 'analysis/grad_norm_after', + + 'collect_mcts_temperature', + 'cur_lr', + 'weighted_total_loss', + 'total_loss', + 'policy_loss', + 'policy_entropy', + 'target_policy_entropy', + 'reward_loss', + 'value_loss', + 'consistency_loss', + 'value_priority', + 'target_reward', + 'target_value', + 'predicted_rewards', + 'predicted_values', + 'transformed_target_reward', + 'transformed_target_value', + 'total_grad_norm_before_clip', + ] + # ["harmony_dynamics", "harmony_policy", "harmony_value", "harmony_reward", "harmony_entropy"] + if self._cfg.model.harmony_balance: + harmony_list = [ + 'harmony_dynamics', 'harmony_dynamics_exp_recip', + 'harmony_policy', 'harmony_policy_exp_recip', + 'harmony_value', 'harmony_value_exp_recip', + 'harmony_reward', 'harmony_reward_exp_recip', + 'harmony_entropy', 'harmony_entropy_exp_recip', + ] + return_list.extend(harmony_list) + return return_list + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model, target_model and optimizer. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. + """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer': self._optimizer.state_dict(), + } + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. + """ + self._learn_model.load_state_dict(state_dict['model']) + self._target_model.load_state_dict(state_dict['target_model']) + self._optimizer.load_state_dict(state_dict['optimizer']) + + def __del__(self): + if self._cfg.model.analysis_sim_norm: + # Remove hooks after training. 
+            self._collect_model.encoder_hook.remove_hooks()
+            self._target_model.encoder_hook.remove_hooks()
+
+    def _process_transition(self, obs, policy_output, timestep):
+        # be compatible with DI-engine Policy class
+        pass
+
+    def _get_train_sample(self, data):
+        # be compatible with DI-engine Policy class
+        pass
+
diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py
index f2cba7161..194e074d2 100644
--- a/lzero/policy/utils.py
+++ b/lzero/policy/utils.py
@@ -359,6 +359,55 @@ def prepare_obs_stack_for_unizero(obs_batch_ori: np.ndarray, cfg: EasyDict) -> T
     return obs_batch, obs_target_batch
 
 
+def prepare_obs_history(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Overview:
+        Prepare the observations for the model by converting the original batch of observations
+        to a PyTorch tensor, and then slicing it to create the batches used for the initial inference
+        and for calculating the consistency loss if required.
+
+    Arguments:
+        - obs_batch_ori (:obj:`np.ndarray`): The original observations in a batch style.
+        - cfg (:obj:`EasyDict`): The configuration dictionary containing model settings.
+
+    Returns:
+        - obs_batch (:obj:`torch.Tensor`): The tensor containing the observations for the initial inference.
+        - obs_target_batch (:obj:`torch.Tensor`): The tensor containing the observations for calculating
+          the consistency loss, if applicable.
+    """
+    # Slice the observations according to ``history_length``.
+    history_length = cfg.model.history_length
+
+    # Convert the numpy array of original observations to a PyTorch tensor and transfer it to the specified device.
+    obs_batch_ori = torch.from_numpy(obs_batch_ori).to(cfg.device)
+
+    # Calculate the dimension size to slice based on the model configuration.
+    # For convolutional models ('conv', 'conv_context', 'conv_history'), use the number of frames to stack times the number of channels.
+    # For multi-layer perceptron models ('mlp'), use the number of frames to stack times the size of the observation space.
+    stack_dim = cfg.model.frame_stack_num * (
+        cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context', 'conv_history'] else cfg.model.observation_shape)
+
+    # Slice the original observation tensor to obtain the batch for the initial inference.
+    obs_batch = obs_batch_ori[:, :stack_dim * history_length]
+
+    # Initialize the target batch for consistency loss as `None`. It will only be set if consistency loss calculation is enabled.
+    obs_target_batch = None
+    # If the model configuration specifies the use of self-supervised learning loss, prepare the target batch for the consistency loss.
+    if cfg.model.self_supervised_learning_loss:
+        # Determine the starting dimension to exclude based on the model type.
+        # For 'conv', exclude the first 'image_channel' dimensions.
+        # For 'mlp', exclude the first 'observation_shape' dimensions.
+        exclude_dim = cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context', 'conv_history'] else cfg.model.observation_shape
+
+        # Slice the original observation tensor to obtain the batch for consistency loss calculation.
+        obs_target_batch = obs_batch_ori[:, exclude_dim * history_length:]
+
+    # Return the prepared batches: one for the initial inference and one for the consistency loss calculation (if applicable).
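+    # Shape illustration (hypothetical example values): with image_channel=3, frame_stack_num=1 and
+    # history_length=4, stack_dim is 3, so ``obs_batch`` keeps the first 3 * 4 = 12 channels that form
+    # the stacked history fed to ``initial_inference``; ``obs_target_batch`` (built only when
+    # self_supervised_learning_loss is enabled) instead drops those first 12 channels and keeps the
+    # remainder as targets for the consistency loss.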
+ return obs_batch, obs_target_batch + + def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, torch.Tensor]: """ Overview: diff --git a/zoo/atari/config/atari_muzero_history_stack1_config.py b/zoo/atari/config/atari_muzero_history_stack1_config.py new file mode 100644 index 000000000..27d7f56ed --- /dev/null +++ b/zoo/atari/config/atari_muzero_history_stack1_config.py @@ -0,0 +1,124 @@ +from easydict import EasyDict +from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map +env_id = 'PongNoFrameskip-v4' # You can specify any Atari game here +action_space_size = atari_env_action_space_map[env_id] + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = None +replay_ratio = 0.1 +# replay_ratio = 0.25 +batch_size = 256 +max_env_step = int(5e5) +reanalyze_ratio = 0. +history_length=4 +# history_length=3 +# history_length=2 + +# only for debug +# collector_env_num = 4 +# n_episode = 4 +# evaluator_env_num = 3 +# num_simulations = 3 +# update_per_collect = 2 +# replay_ratio = 0. +# batch_size = 3 +# max_env_step = int(5e5) +# reanalyze_ratio = 0. +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +atari_muzero_config = dict( + exp_name=f'data_lz/data_muzero_history_20250324/{env_id[:-14]}_muzero_HL{history_length}_useaug_final-LN_ns{num_simulations}_upc{update_per_collect}_rer{reanalyze_ratio}_stack1_seed0', + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + frame_stack_num=1, + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + collect_max_episode_steps=int(2e4), + eval_max_episode_steps=int(1e4), + # debug + # collect_max_episode_steps=int(1200), + # eval_max_episode_steps=int(1200), + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + ), + policy=dict( + model=dict( + observation_shape=(3, 64, 64), + image_channel=3, + frame_stack_num=1, + gray_scale=False, + action_space_size=action_space_size, + downsample=True, + # self_supervised_learning_loss=True, + self_supervised_learning_loss=False, + discrete_action_encoding_type='one_hot', + norm_type='BN', + # norm_type='LN', + model_type='conv_history', + history_length=history_length, + num_unroll_steps=5, + ), + history_length=history_length, + num_unroll_steps=5, + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. 
+ model_path=None, + cuda=True, + env_type='not_board_games', + game_segment_length=400, + # game_segment_length=20, # TODO: debug + random_collect_episode_num=0, + use_augmentation=True, + # use_augmentation=False, + update_per_collect=update_per_collect, + replay_ratio=replay_ratio, + batch_size=batch_size, + # optim_type='Adam', + # learning_rate=0.0001, + # piecewise_decay_lr_scheduler=False, + optim_type='SGD', + piecewise_decay_lr_scheduler=True, + learning_rate=0.2, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + ssl_loss_weight=2, + n_episode=n_episode, + eval_freq=int(5e3), + # eval_freq=int(1e2), # TODO + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) +atari_muzero_config = EasyDict(atari_muzero_config) +main_config = atari_muzero_config + +atari_muzero_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero_history', + import_names=['lzero.policy.muzero_history'], + ), +) +atari_muzero_create_config = EasyDict(atari_muzero_create_config) +create_config = atari_muzero_create_config + +if __name__ == "__main__": + from lzero.entry import train_muzero + train_muzero([main_config, create_config], seed=0, model_path=main_config.policy.model_path, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/memory/config/memory_muzero_history_config.py b/zoo/memory/config/memory_muzero_history_config.py new file mode 100644 index 000000000..3ddc6881d --- /dev/null +++ b/zoo/memory/config/memory_muzero_history_config.py @@ -0,0 +1,139 @@ +from easydict import EasyDict + +env_id = 'visual_match' # The name of the environment, options: 'visual_match', 'key_to_door' +memory_length = 60 +max_env_step = int(5e5) + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +seed = 0 +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 8 +num_simulations = 50 +update_per_collect = None # for others +replay_ratio = 0.25 +# batch_size = 256 +reanalyze_ratio = 0 +td_steps = 5 +game_segment_length = 30+memory_length + +# num_unroll_steps = 16+memory_length +# TODO +num_unroll_steps = 5 + +policy_entropy_weight = 1e-4 +threshold_training_steps_for_final_temperature = int(1e5) +eps_greedy_exploration_in_collect = True +# history_length = 20 +# history_length = 40 +# history_length = 60 +history_length = 70 +batch_size = 128 + +# debug +# num_simulations = 3 +# batch_size = 3 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +memory_muzero_config = dict( + exp_name=f'data_lz/data_muzero_history_20250324/{env_id}_memlen-{memory_length}_muzero_HL{history_length}_transformer_ns{num_simulations}_upc{update_per_collect}_seed{seed}', + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 5, 5), + frame_stack_num=1, + image_channel=3, + gray_scale=False, + # rgb_img_observation=False, # Whether to return RGB image observation + rgb_img_observation=True, # Whether to return RGB image observation + scale_rgb_img_observation=True, # Whether to scale the RGB image observation to [0, 1] + # flatten_observation=True, # Whether to flatten the observation + 
flatten_observation=False,  # Whether to flatten the observation
+        max_frames={
+            # "explore": 15,  # for key_to_door
+            "explore": 1,  # for visual_match
+            "distractor": memory_length,
+            "reward": 15
+        },  # Maximum frames per phase
+        collector_env_num=collector_env_num,
+        evaluator_env_num=evaluator_env_num,
+        n_evaluator_episode=evaluator_env_num,
+        manager=dict(shared_memory=False, ),
+    ),
+    policy=dict(
+        sample_type='episode',  # NOTE: very important for memory env
+        num_unroll_steps=num_unroll_steps,
+        history_length=history_length,
+        model=dict(
+            observation_shape=(3, 5, 5),
+            image_channel=3,
+            frame_stack_num=1,
+            gray_scale=False,
+            action_space_size=4,
+            analysis_sim_norm=False,
+            # model_type='mlp',
+            latent_state_dim=128,
+            discrete_action_encoding_type='one_hot',
+            norm_type='BN',
+            # norm_type='LN',
+            # self_supervised_learning_loss=True,  # NOTE: default is False.
+            self_supervised_learning_loss=False,
+            downsample=False,
+            model_type='conv_history',
+            history_length=history_length,
+            fusion_mode='transformer',  # options: 'mean', 'transformer'; other fusion modes may be added in the future
+            num_unroll_steps=num_unroll_steps,
+        ),
+        eps=dict(
+            eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect,
+            decay=int(5e4),  # NOTE: 50k env steps for key_to_door
+        ),
+        policy_entropy_weight=policy_entropy_weight,
+        td_steps=td_steps,
+        cuda=True,
+        env_type='not_board_games',
+        game_segment_length=game_segment_length,
+        update_per_collect=update_per_collect,
+        batch_size=batch_size,
+        optim_type='AdamW',
+        piecewise_decay_lr_scheduler=False,
+        learning_rate=0.0001,
+        ssl_loss_weight=2,  # NOTE: default is 0.
+        num_simulations=num_simulations,
+        reanalyze_ratio=reanalyze_ratio,
+        n_episode=n_episode,
+        eval_freq=int(5e3),
+        replay_buffer_size=int(1e6),  # the size/capacity of replay_buffer, in the terms of transitions.
+        collector_env_num=collector_env_num,
+        evaluator_env_num=evaluator_env_num,
+    ),
+)
+
+memory_muzero_config = EasyDict(memory_muzero_config)
+main_config = memory_muzero_config
+
+memory_muzero_create_config = dict(
+    env=dict(
+        type='memory_lightzero',
+        import_names=['zoo.memory.envs.memory_lightzero_env'],
+    ),
+    env_manager=dict(type='subprocess'),
+    policy=dict(
+        type='muzero_history',
+        import_names=['lzero.policy.muzero_history'],
+    ),
+    collector=dict(
+        type='episode_muzero',
+        import_names=['lzero.worker.muzero_collector'],
+    )
+)
+memory_muzero_create_config = EasyDict(memory_muzero_create_config)
+create_config = memory_muzero_create_config
+
+if __name__ == "__main__":
+    from lzero.entry import train_muzero
+    train_muzero([main_config, create_config], seed=seed, max_env_step=max_env_step)
diff --git a/zoo/memory/envs/memory_lightzero_env.py b/zoo/memory/envs/memory_lightzero_env.py
index f47b19c98..dad3dd1f9 100644
--- a/zoo/memory/envs/memory_lightzero_env.py
+++ b/zoo/memory/envs/memory_lightzero_env.py
@@ -142,15 +142,26 @@ def reset(self) -> np.ndarray:
         self._gif_images = []
         self._gif_images_numpy = []
+        # if self._save_replay or self._render or self.rgb_img_observation:
+        #     # Convert observation to RGB format
+        #     obs_rgb = np.zeros((5, 5, 3), dtype=np.uint8)
+        #     for char, color in self._game._colours.items():
+        #         # NOTE: self._game._colours is a dictionary that maps the characters in the game to their corresponding (true) colors, ranging in [0,999].
+        #         # Because the np.uint8 type array will perform a modulo 256 operation (taking the remainder), that is to say,
+        #         # any value greater than 255 will be subtracted by an integer multiple of 256 until the value falls within the range of 0-255.
+        #         # For example, 1000 will become 232 (because 1000 % 256 = 232)
+        #         obs_rgb[observation.reshape(5, 5) == ord(char)] = color
+
         if self._save_replay or self._render or self.rgb_img_observation:
             # Convert observation to RGB format
             obs_rgb = np.zeros((5, 5, 3), dtype=np.uint8)
             for char, color in self._game._colours.items():
-                # NOTE: self._game._colours is a dictionary that maps the characters in the game to their corresponding (true) colors, ranging in [0,999].
-                # Because the np.uint8 type array will perform a modulo 256 operation (taking the remainder), that is to say,
-                # any value greater than 255 will be subtracted by an integer multiple of 256 until the value falls within the range of 0-255.
-                # For example, 1000 will become 232 (because 1000 % 256 = 232)
-                obs_rgb[observation.reshape(5, 5) == ord(char)] = color
+                # NOTE: the colour values in ``self._game._colours`` lie in the range [0, 999];
+                # assigning them directly to a np.uint8 array will raise an error in future NumPy versions.
+                # Convert them with np.array(value).astype(np.uint8), or apply an explicit modulo
+                # operation to obtain a value within the uint8 range.
+                obs_rgb[observation.reshape(5, 5) == ord(char)] = np.array(color).astype(np.uint8)
+                # Alternatively, using an explicit modulo operation:
+                # obs_rgb[observation.reshape(5, 5) == ord(char)] = color % 256
 
         if self._save_replay or self._render:
             img = Image.fromarray(obs_rgb)
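A minimal standalone sketch of the per-environment rolling observation-history update performed in ``_forward_collect``/``_forward_eval``, assuming 3-channel 64x64 observations and history_length = 4 (illustrative values only):

import torch

num_obs_channels, history_length = 3, 4  # illustrative values
history_channels = num_obs_channels * history_length

# Zero-initialised history for a single environment, as in _init_collect/_init_eval.
last_obs = torch.zeros(history_channels, 64, 64)

# At every step, append the newest observation and keep only the most recent
# `history_channels` channels, i.e. the last `history_length` frames.
new_obs = torch.rand(num_obs_channels, 64, 64)
last_obs = torch.cat([last_obs, new_obs], dim=0)[-history_channels:]
assert last_obs.shape == (history_channels, 64, 64)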