Skip to content

Commit ec60f8d

Browse files
puyuan1996 and puyuan authored
fix(pu): fix prepare_obs_stack_for_unizero (#328)
Co-authored-by: puyuan <[email protected]>
1 parent f803504 commit ec60f8d

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

lzero/policy/sampled_unizero.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
from lzero.model import ImageTransforms
1414
from lzero.policy import scalar_transform, InverseScalarTransform, phi_transform, \
1515
DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, prepare_obs, \
16-
prepare_obs_stack4_for_unizero
16+
prepare_obs_stack_for_unizero
1717
from lzero.policy.unizero import UniZeroPolicy
1818
from .utils import configure_optimizers_nanogpt
1919
from lzero.entry.utils import initialize_zeros_batch
@@ -385,8 +385,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
385385
target_reward, target_value, target_policy = target_batch
386386

387387
# Prepare observations based on frame stack number
388-
if self._cfg.model.frame_stack_num == 4:
389-
obs_batch, obs_target_batch = prepare_obs_stack4_for_unizero(obs_batch_ori, self._cfg)
388+
if self._cfg.model.frame_stack_num > 1:
389+
obs_batch, obs_target_batch = prepare_obs_stack_for_unizero(obs_batch_ori, self._cfg)
390390
else:
391391
obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg)
392392

lzero/policy/unizero.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
from lzero.model import ImageTransforms
1414
from lzero.policy import scalar_transform, InverseScalarTransform, phi_transform, \
1515
DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, prepare_obs, \
16-
prepare_obs_stack4_for_unizero
16+
prepare_obs_stack_for_unizero
1717
from lzero.policy.muzero import MuZeroPolicy
1818
from .utils import configure_optimizers_nanogpt
1919

@@ -357,8 +357,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
357357
target_reward, target_value, target_policy = target_batch
358358

359359
# Prepare observations based on frame stack number
360-
if self._cfg.model.frame_stack_num == 4:
361-
obs_batch, obs_target_batch = prepare_obs_stack4_for_unizero(obs_batch_ori, self._cfg)
360+
if self._cfg.model.frame_stack_num > 1:
361+
obs_batch, obs_target_batch = prepare_obs_stack_for_unizero(obs_batch_ori, self._cfg)
362362
else:
363363
obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) # TODO: optimize
364364

lzero/policy/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -317,7 +317,7 @@ def configure_optimizers(
317317
return optimizer
318318

319319

320-
def prepare_obs_stack4_for_unizero(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, torch.Tensor]:
320+
def prepare_obs_stack_for_unizero(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, torch.Tensor]:
321321
"""
322322
Overview:
323323
Prepare the observation stack for UniZero model. This function processes the original batch of observations

0 commit comments

Comments
 (0)