Skip to content

Commit 48cdf06

Browse files
committed
fix(pu): fix norm and init shape bug
1 parent 556b2ec commit 48cdf06

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

lzero/model/common.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -297,7 +297,11 @@ def __init__(
297297

298298
# Initial convolution: stride 2
299299
self.conv1 = nn.Conv2d(observation_shape[0], out_channels // 2, kernel_size=3, stride=2, padding=1, bias=False)
300-
self.norm1 = build_normalization(norm_type, dim=2)(out_channels // 2)
300+
if norm_type == 'BN':
301+
self.norm1 = nn.BatchNorm2d(out_channels // 2)
302+
elif norm_type == 'LN':
303+
self.norm1 = nn.LayerNorm([out_channels // 2, observation_shape[-2] // 2, observation_shape[-1] // 2],
304+
eps=1e-5)
301305

302306
# Stage 1 with residual blocks
303307
self.resblocks1 = nn.ModuleList([
@@ -734,7 +738,15 @@ def __init__(
734738
self.downsample_net = DownSample(observation_shape, num_channels, activation, norm_type)
735739
else:
736740
self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False)
737-
self.norm = build_normalization(norm_type, dim=2)(num_channels)
741+
if norm_type == 'BN':
742+
self.norm = nn.BatchNorm2d(num_channels)
743+
elif norm_type == 'LN':
744+
if downsample:
745+
self.norm = nn.LayerNorm(
746+
[num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)],
747+
eps=1e-5)
748+
else:
749+
self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5)
738750

739751
self.resblocks = nn.ModuleList([
740752
ResBlock(in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False)

lzero/policy/unizero.py

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1335,10 +1335,10 @@ def _init_collect(self) -> None:
13351335
self._collect_epsilon = 0.0
13361336
self.collector_env_num = self._cfg.collector_env_num
13371337
if self._cfg.model.model_type == 'conv':
1338-
self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
1338+
self.last_batch_obs_collect = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
13391339
self.last_batch_action_collect = [-1 for i in range(self.collector_env_num)]
13401340
elif self._cfg.model.model_type == 'mlp':
1341-
self.last_batch_obs = torch.full(
1341+
self.last_batch_obs_collect = torch.full(
13421342
[self.collector_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id,
13431343
).to(self._cfg.device)
13441344
self.last_batch_action_collect = [-1 for i in range(self.collector_env_num)]
@@ -1390,7 +1390,7 @@ def _forward_collect(
13901390
output = {i: None for i in ready_env_id}
13911391

13921392
with torch.no_grad():
1393-
network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action_collect, data, timestep)
1393+
network_output = self._collect_model.initial_inference(self.last_batch_obs_collect, self.last_batch_action_collect, data, timestep)
13941394
latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
13951395

13961396
pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
@@ -1461,7 +1461,7 @@ def _forward_collect(
14611461
}
14621462
batch_action.append(action)
14631463

1464-
self.last_batch_obs = data
1464+
self.last_batch_obs_collect = data
14651465
self.last_batch_action_collect = batch_action
14661466

14671467
# This logic is a temporary workaround specific to the muzero_segment_collector.
@@ -1505,10 +1505,10 @@ def _init_eval(self) -> None:
15051505
self.evaluator_env_num = self._cfg.evaluator_env_num
15061506

15071507
if self._cfg.model.model_type == 'conv':
1508-
self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
1508+
self.last_batch_obs_eval = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
15091509
self.last_batch_action_eval = [-1 for i in range(self.collector_env_num)]
15101510
elif self._cfg.model.model_type == 'mlp':
1511-
self.last_batch_obs = torch.full(
1511+
self.last_batch_obs_eval = torch.full(
15121512
[self.collector_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id,
15131513
).to(self._cfg.device)
15141514
self.last_batch_action_eval = [-1 for i in range(self.collector_env_num)]
@@ -1623,13 +1623,13 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in
16231623
- reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset.
16241624
"""
16251625
if reset_init_data:
1626-
self.last_batch_obs = initialize_pad_batch(
1626+
self.last_batch_obs_collect = initialize_pad_batch(
16271627
self._cfg.model.observation_shape,
16281628
self._cfg.collector_env_num,
16291629
self._cfg.device,
16301630
pad_token_id=self.pad_token_id
16311631
)
1632-
self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)]
1632+
self.last_batch_action_collect = [-1 for _ in range(self._cfg.collector_env_num)]
16331633

16341634

16351635
# We must handle both single int and list of ints for env_id.
@@ -1696,7 +1696,7 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_
16961696
self._cfg.device,
16971697
pad_token_id=self.pad_token_id
16981698
)
1699-
logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape)
1699+
logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:{self.last_batch_obs_eval.shape}')
17001700

17011701
else:
17021702
self.last_batch_obs_eval = initialize_pad_batch(
@@ -1705,9 +1705,9 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_
17051705
self._cfg.device,
17061706
pad_token_id=self.pad_token_id
17071707
)
1708-
logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape)
1708+
logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:{self.last_batch_obs_eval.shape}')
17091709

1710-
self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)]
1710+
self.last_batch_action_eval = [-1 for _ in range(self._cfg.evaluator_env_num)]
17111711

17121712
# This logic handles the crucial end-of-episode cache clearing for evaluation.
17131713
# The evaluator calls `_policy.reset([env_id])` when an episode is done.

0 commit comments

Comments (0)