diff --git a/lzero/model/common.py b/lzero/model/common.py
index 31f254963..c33a03c26 100644
--- a/lzero/model/common.py
+++ b/lzero/model/common.py
@@ -297,7 +297,11 @@ def __init__(
 
         # Initial convolution: stride 2
         self.conv1 = nn.Conv2d(observation_shape[0], out_channels // 2, kernel_size=3, stride=2, padding=1, bias=False)
-        self.norm1 = build_normalization(norm_type, dim=2)(out_channels // 2)
+        if norm_type == 'BN':
+            self.norm1 = nn.BatchNorm2d(out_channels // 2)
+        elif norm_type == 'LN':
+            self.norm1 = nn.LayerNorm([out_channels // 2, observation_shape[-2] // 2, observation_shape[-1] // 2],
+                                      eps=1e-5)
 
         # Stage 1 with residual blocks
         self.resblocks1 = nn.ModuleList([
@@ -734,7 +738,15 @@ def __init__(
             self.downsample_net = DownSample(observation_shape, num_channels, activation, norm_type)
         else:
             self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False)
-            self.norm = build_normalization(norm_type, dim=2)(num_channels)
+            if norm_type == 'BN':
+                self.norm = nn.BatchNorm2d(num_channels)
+            elif norm_type == 'LN':
+                if downsample:
+                    self.norm = nn.LayerNorm(
+                        [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)],
+                        eps=1e-5)
+                else:
+                    self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5)
 
         self.resblocks = nn.ModuleList([
             ResBlock(in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False)
diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py
index f2dc027c1..992e07ffe 100755
--- a/lzero/policy/unizero.py
+++ b/lzero/policy/unizero.py
@@ -1335,10 +1335,10 @@ def _init_collect(self) -> None:
         self._collect_epsilon = 0.0
         self.collector_env_num = self._cfg.collector_env_num
         if self._cfg.model.model_type == 'conv':
-            self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
+            self.last_batch_obs_collect = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
             self.last_batch_action_collect = [-1 for i in range(self.collector_env_num)]
         elif self._cfg.model.model_type == 'mlp':
-            self.last_batch_obs = torch.full(
+            self.last_batch_obs_collect = torch.full(
                 [self.collector_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id,
             ).to(self._cfg.device)
             self.last_batch_action_collect = [-1 for i in range(self.collector_env_num)]
@@ -1390,7 +1390,7 @@ def _forward_collect(
         output = {i: None for i in ready_env_id}
 
         with torch.no_grad():
-            network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action_collect, data, timestep)
+            network_output = self._collect_model.initial_inference(self.last_batch_obs_collect, self.last_batch_action_collect, data, timestep)
             latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
 
             pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
@@ -1461,7 +1461,7 @@ def _forward_collect(
                 }
                 batch_action.append(action)
 
-            self.last_batch_obs = data
+            self.last_batch_obs_collect = data
             self.last_batch_action_collect = batch_action
 
             # This logic is a temporary workaround specific to the muzero_segment_collector.
@@ -1505,10 +1505,10 @@ def _init_eval(self) -> None:
 
         self.evaluator_env_num = self._cfg.evaluator_env_num
         if self._cfg.model.model_type == 'conv':
-            self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
+            self.last_batch_obs_eval = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
             self.last_batch_action_eval = [-1 for i in range(self.collector_env_num)]
         elif self._cfg.model.model_type == 'mlp':
-            self.last_batch_obs = torch.full(
+            self.last_batch_obs_eval = torch.full(
                 [self.collector_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id,
             ).to(self._cfg.device)
             self.last_batch_action_eval = [-1 for i in range(self.collector_env_num)]
@@ -1623,13 +1623,13 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in
             - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset.
         """
 
         if reset_init_data:
-            self.last_batch_obs = initialize_pad_batch(
+            self.last_batch_obs_collect = initialize_pad_batch(
                 self._cfg.model.observation_shape,
                 self._cfg.collector_env_num,
                 self._cfg.device,
                 pad_token_id=self.pad_token_id
             )
-            self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)]
+            self.last_batch_action_collect = [-1 for _ in range(self._cfg.collector_env_num)]
 
         # We must handle both single int and list of ints for env_id.
@@ -1696,7 +1696,7 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_
                 self._cfg.device,
                 pad_token_id=self.pad_token_id
             )
-            logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape)
+            logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:{self.last_batch_obs_eval.shape}')
 
         else:
             self.last_batch_obs_eval = initialize_pad_batch(
@@ -1705,9 +1705,9 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_
                 self._cfg.device,
                 pad_token_id=self.pad_token_id
             )
-            logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape)
+            logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:{self.last_batch_obs_eval.shape}')
 
-        self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)]
+        self.last_batch_action_eval = [-1 for _ in range(self._cfg.evaluator_env_num)]
 
         # This logic handles the crucial end-of-episode cache clearing for evaluation.
         # The evaluator calls `_policy.reset([env_id])` when an episode is done.
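Note on the common.py hunks: unlike nn.BatchNorm2d, which only needs the channel count, nn.LayerNorm normalizes over a fixed trailing shape, so the patch must spell out the (C, H, W) that actually comes out of the preceding layer. The stride-2 conv1 halves each spatial dimension, and the full DownSample stack shrinks them 16x (hence math.ceil(observation_shape[-2] / 16)). Below is a minimal sketch of the halved-shape case, assuming an illustrative 3x64x64 observation and out_channels=64 (not values taken from the repo's configs):

import torch
import torch.nn as nn

observation_shape, out_channels = (3, 64, 64), 64

# Same layer as the patched DownSample.__init__: a kernel-3, stride-2,
# padding-1 conv maps H x W to ceil(H / 2) x ceil(W / 2), i.e. 64 -> 32.
conv1 = nn.Conv2d(observation_shape[0], out_channels // 2, kernel_size=3, stride=2, padding=1, bias=False)

# LayerNorm must therefore be given the post-conv (C, H, W), not just C.
norm1 = nn.LayerNorm([out_channels // 2, observation_shape[-2] // 2, observation_shape[-1] // 2], eps=1e-5)

x = torch.randn(8, *observation_shape)
y = norm1(conv1(x))
assert y.shape == (8, out_channels // 2, 32, 32)  # a mismatched normalized_shape would raise at forward time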
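On the two logging.info fixes in _reset_eval: the old calls passed the tensor shape as an extra positional argument after an f-string. logging treats positional args as %-format parameters, and since the already-interpolated f-string contains no % placeholders, the handler hits "TypeError: not all arguments converted during string formatting" at emit time and the record never prints cleanly. A standalone repro, with names as illustrative stand-ins for the policy's attributes:

import logging
logging.basicConfig(level=logging.INFO)

shape = (2, 3, 64, 64)  # stand-in for self.last_batch_obs_eval.shape

# Before: logging attempts `msg % args` when emitting; with no %s in the
# message it prints "--- Logging error ---" and a TypeError traceback
# instead of the intended record (the program itself keeps running).
logging.info('after _reset_eval: last_batch_obs_eval:', shape)

# After: interpolate the value into the f-string, as the patch does.
logging.info(f'after _reset_eval: last_batch_obs_eval:{shape}')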