Skip to content

Commit eeca7d4

Browse files
SteamMachinistIaroslav Roshchupkin
andauthored
fix(SteamMachinist): fix inconsistent parameter name in MuZeroCollector and MuZeroPolicy (#463)
* fix: rename param x->data in policy forward call * fix: fix uninitialized variable collected_step --------- Co-authored-by: Iaroslav Roshchupkin <iroshchupkin@neoflex.ru>
1 parent 81db0b2 commit eeca7d4

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

lzero/worker/muzero_collector.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,7 @@ def collect(
340340

341341
# --- Initializations ---
342342
collected_episode = 0
343+
collected_step = 0
343344
env_nums = self._env_num
344345
retry_waiting_time = 0.05
345346

@@ -411,7 +412,7 @@ def collect(
411412
# Policy Forward Pass
412413
# ==============================================================
413414
policy_input = {
414-
'x': stack_obs_tensor,
415+
'data': stack_obs_tensor,
415416
'action_mask': action_mask,
416417
'temperature': temperature,
417418
'to_play': to_play,
@@ -679,4 +680,4 @@ def _output_log(self, train_iter: int) -> None:
679680

680681
if self.policy_config.use_wandb:
681682
wandb_log_data = {tb_prefix_step + k: v for k, v in info.items()}
682-
wandb.log(wandb_log_data, step=self._total_envstep_count)
683+
wandb.log(wandb_log_data, step=self._total_envstep_count)

0 commit comments

Comments
 (0)