16 changes: 14 additions & 2 deletions lzero/model/common.py
@@ -297,7 +297,11 @@ def __init__(

         # Initial convolution: stride 2
         self.conv1 = nn.Conv2d(observation_shape[0], out_channels // 2, kernel_size=3, stride=2, padding=1, bias=False)
-        self.norm1 = build_normalization(norm_type, dim=2)(out_channels // 2)
+        if norm_type == 'BN':
+            self.norm1 = nn.BatchNorm2d(out_channels // 2)
+        elif norm_type == 'LN':
+            self.norm1 = nn.LayerNorm([out_channels // 2, observation_shape[-2] // 2, observation_shape[-1] // 2],
+                                       eps=1e-5)
 
         # Stage 1 with residual blocks
         self.resblocks1 = nn.ModuleList([
@@ -734,7 +738,15 @@ def __init__(
             self.downsample_net = DownSample(observation_shape, num_channels, activation, norm_type)
         else:
             self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False)
-            self.norm = build_normalization(norm_type, dim=2)(num_channels)
+            if norm_type == 'BN':
+                self.norm = nn.BatchNorm2d(num_channels)
+            elif norm_type == 'LN':
+                if downsample:
+                    self.norm = nn.LayerNorm(
+                        [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)],
+                        eps=1e-5)
+                else:
+                    self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5)
 
         self.resblocks = nn.ModuleList([
             ResBlock(in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False)
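Why the LN branch needs explicit shapes while the BN branch does not: `nn.BatchNorm2d` normalizes per channel and only needs the channel count, whereas `nn.LayerNorm` normalizes over the trailing dimensions and must be constructed with the exact `(C, H, W)` of its input. That is why the new code spells out the halved resolution after the stride-2 convolution (and the `/ 16` resolution when the representation network downsamples). A minimal standalone sketch, not part of this PR, assuming an Atari-style `observation_shape` of `(3, 64, 64)` and `out_channels = 64`:

```python
import torch
import torch.nn as nn

observation_shape, out_channels = (3, 64, 64), 64

# Stride-2 conv halves the spatial resolution: (B, 3, 64, 64) -> (B, 32, 32, 32).
conv1 = nn.Conv2d(observation_shape[0], out_channels // 2, kernel_size=3, stride=2, padding=1, bias=False)
y = conv1(torch.randn(8, *observation_shape))

# BatchNorm2d normalizes per channel, so the channel count is all it needs.
bn = nn.BatchNorm2d(out_channels // 2)

# LayerNorm normalizes over the trailing dims, so it must know the full
# post-conv (C, H, W) -- hence observation_shape[-2] // 2 and [-1] // 2 above.
ln = nn.LayerNorm([out_channels // 2, observation_shape[-2] // 2, observation_shape[-1] // 2], eps=1e-5)

assert bn(y).shape == ln(y).shape == y.shape == (8, 32, 32, 32)
```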
22 changes: 11 additions & 11 deletions lzero/policy/unizero.py
@@ -1335,10 +1335,10 @@ def _init_collect(self) -> None:
         self._collect_epsilon = 0.0
         self.collector_env_num = self._cfg.collector_env_num
         if self._cfg.model.model_type == 'conv':
-            self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
+            self.last_batch_obs_collect = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
             self.last_batch_action_collect = [-1 for i in range(self.collector_env_num)]
         elif self._cfg.model.model_type == 'mlp':
-            self.last_batch_obs = torch.full(
+            self.last_batch_obs_collect = torch.full(
                 [self.collector_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id,
             ).to(self._cfg.device)
             self.last_batch_action_collect = [-1 for i in range(self.collector_env_num)]
@@ -1390,7 +1390,7 @@ def _forward_collect(
         output = {i: None for i in ready_env_id}
 
         with torch.no_grad():
-            network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action_collect, data, timestep)
+            network_output = self._collect_model.initial_inference(self.last_batch_obs_collect, self.last_batch_action_collect, data, timestep)
             latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
 
             pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
@@ -1461,7 +1461,7 @@ def _forward_collect(
             }
             batch_action.append(action)
 
-        self.last_batch_obs = data
+        self.last_batch_obs_collect = data
         self.last_batch_action_collect = batch_action
 
         # This logic is a temporary workaround specific to the muzero_segment_collector.
@@ -1505,10 +1505,10 @@ def _init_eval(self) -> None:
         self.evaluator_env_num = self._cfg.evaluator_env_num
 
         if self._cfg.model.model_type == 'conv':
-            self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
+            self.last_batch_obs_eval = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
             self.last_batch_action_eval = [-1 for i in range(self.collector_env_num)]
         elif self._cfg.model.model_type == 'mlp':
-            self.last_batch_obs = torch.full(
+            self.last_batch_obs_eval = torch.full(
                 [self.collector_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id,
             ).to(self._cfg.device)
             self.last_batch_action_eval = [-1 for i in range(self.collector_env_num)]
@@ -1623,13 +1623,13 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in
             - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset.
         """
         if reset_init_data:
-            self.last_batch_obs = initialize_pad_batch(
+            self.last_batch_obs_collect = initialize_pad_batch(
                 self._cfg.model.observation_shape,
                 self._cfg.collector_env_num,
                 self._cfg.device,
                 pad_token_id=self.pad_token_id
             )
-            self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)]
+            self.last_batch_action_collect = [-1 for _ in range(self._cfg.collector_env_num)]
 
 
         # We must handle both single int and list of ints for env_id.
@@ -1696,7 +1696,7 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_
                     self._cfg.device,
                     pad_token_id=self.pad_token_id
                 )
-                logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape)
+                logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:{self.last_batch_obs_eval.shape}')
 
             else:
                 self.last_batch_obs_eval = initialize_pad_batch(
@@ -1705,9 +1705,9 @@
                     self._cfg.device,
                     pad_token_id=self.pad_token_id
                 )
-                logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape)
+                logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:{self.last_batch_obs_eval.shape}')
 
-            self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)]
+            self.last_batch_action_eval = [-1 for _ in range(self._cfg.evaluator_env_num)]
 
         # This logic handles the crucial end-of-episode cache clearing for evaluation.
         # The evaluator calls `_policy.reset([env_id])` when an episode is done.
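A small fix bundled into the `_reset_eval` hunks above: the old `logging.info` calls passed the tensor shape as an extra positional argument while the message contained no %-style placeholders, so the logging module could not format the record; the new calls interpolate the shape inside the f-string. A standalone reproduction, illustrative only and not part of the PR:

```python
import logging
import torch

logging.basicConfig(level=logging.INFO)
shape = torch.zeros(2, 3, 64, 64).shape

# Old form: the extra argument is treated as a %-formatting value, but the
# message has no placeholders, so logging prints "--- Logging error ---".
logging.info('last_batch_obs_eval:', shape)

# Fixed form: interpolate the value inside the f-string itself.
logging.info(f'last_batch_obs_eval:{shape}')
```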
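The broader intent of the `unizero.py` renames is to stop the collector and evaluator from sharing a single `last_batch_obs` cache, so an evaluation run can no longer clobber the collector's previous-step context. A toy sketch of that separation (class name, shapes, and values are illustrative assumptions, not the repository's API):

```python
import torch


class ToyPolicyState:
    """Keeps one cached 'previous batch' per phase instead of a shared one."""

    def __init__(self, collector_env_num: int, evaluator_env_num: int, obs_shape=(3, 64, 64)):
        self.last_batch_obs_collect = torch.zeros(collector_env_num, *obs_shape)
        self.last_batch_action_collect = [-1] * collector_env_num
        self.last_batch_obs_eval = torch.zeros(evaluator_env_num, *obs_shape)
        self.last_batch_action_eval = [-1] * evaluator_env_num

    def forward_collect(self, obs: torch.Tensor, actions: list) -> None:
        # Data collection only touches the collect-side cache.
        self.last_batch_obs_collect = obs
        self.last_batch_action_collect = actions

    def forward_eval(self, obs: torch.Tensor, actions: list) -> None:
        # Evaluation updates its own cache and leaves the collector's untouched.
        self.last_batch_obs_eval = obs
        self.last_batch_action_eval = actions


state = ToyPolicyState(collector_env_num=4, evaluator_env_num=2)
state.forward_eval(torch.ones(2, 3, 64, 64), [0, 1])
assert state.last_batch_obs_collect.sum() == 0  # collect cache unaffected by eval
```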