@@ -1335,10 +1335,10 @@ def _init_collect(self) -> None:
         self._collect_epsilon = 0.0
         self.collector_env_num = self._cfg.collector_env_num
         if self._cfg.model.model_type == 'conv':
-            self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
+            self.last_batch_obs_collect = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
             self.last_batch_action_collect = [-1 for i in range(self.collector_env_num)]
         elif self._cfg.model.model_type == 'mlp':
-            self.last_batch_obs = torch.full(
+            self.last_batch_obs_collect = torch.full(
                 [self.collector_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id,
             ).to(self._cfg.device)
             self.last_batch_action_collect = [-1 for i in range(self.collector_env_num)]
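For reference, a standalone sketch of the initial batches these two branches build. The sizes below (`collector_env_num`, the observation shapes, `pad_token_id`) are illustrative, not values from the config:

```python
import torch

collector_env_num = 8          # illustrative value
device = 'cpu'

# 'conv' branch: image observations, channel-first, fixed 64x64 resolution.
obs_shape_conv = (3, 64, 64)   # illustrative value
last_batch_obs_collect = torch.zeros(
    [collector_env_num, obs_shape_conv[0], 64, 64]
).to(device)

# 'mlp' branch: flat (e.g. tokenized) observations, filled with the pad token.
obs_shape_mlp, pad_token_id = 25, 0   # illustrative values
last_batch_obs_collect_mlp = torch.full(
    [collector_env_num, obs_shape_mlp], fill_value=pad_token_id
).to(device)

last_batch_action_collect = [-1 for _ in range(collector_env_num)]
print(last_batch_obs_collect.shape)      # torch.Size([8, 3, 64, 64])
print(last_batch_obs_collect_mlp.shape)  # torch.Size([8, 25])
```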
@@ -1390,7 +1390,7 @@ def _forward_collect(
         output = {i: None for i in ready_env_id}

         with torch.no_grad():
-            network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action_collect, data, timestep)
+            network_output = self._collect_model.initial_inference(self.last_batch_obs_collect, self.last_batch_action_collect, data, timestep)
             latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)

             pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
@@ -1461,7 +1461,7 @@ def _forward_collect(
                 }
                 batch_action.append(action)

-            self.last_batch_obs = data
+            self.last_batch_obs_collect = data
             self.last_batch_action_collect = batch_action

             # This logic is a temporary workaround specific to the muzero_segment_collector.
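The rename matters because `_init_eval` (next hunk) previously overwrote the same `self.last_batch_obs` attribute, so an evaluation run in the middle of collection could replace the collector's previous-batch context with a tensor sized for the evaluator. A minimal sketch of the separated bookkeeping, not the project's API:

```python
import torch

class PhaseContext:
    """Per-phase 'previous batch' buffer, mirroring the collect/eval split above."""

    def __init__(self, env_num: int, obs_dim: int) -> None:
        self.last_obs = torch.zeros(env_num, obs_dim)
        self.last_action = [-1] * env_num

    def update(self, obs: torch.Tensor, actions: list) -> None:
        # Record the batch that conditions the next initial_inference call.
        self.last_obs = obs
        self.last_action = actions

# Separate contexts: an eval run can no longer clobber the collector's buffers.
collect_ctx = PhaseContext(env_num=8, obs_dim=25)   # collector_env_num (illustrative)
eval_ctx = PhaseContext(env_num=3, obs_dim=25)      # evaluator_env_num (illustrative)
```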
@@ -1505,10 +1505,10 @@ def _init_eval(self) -> None:
         self.evaluator_env_num = self._cfg.evaluator_env_num

         if self._cfg.model.model_type == 'conv':
-            self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
+            self.last_batch_obs_eval = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
             self.last_batch_action_eval = [-1 for i in range(self.collector_env_num)]
         elif self._cfg.model.model_type == 'mlp':
-            self.last_batch_obs = torch.full(
+            self.last_batch_obs_eval = torch.full(
                 [self.collector_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id,
             ).to(self._cfg.device)
             self.last_batch_action_eval = [-1 for i in range(self.collector_env_num)]
@@ -1623,13 +1623,13 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in
             - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset.
         """
         if reset_init_data:
-            self.last_batch_obs = initialize_pad_batch(
+            self.last_batch_obs_collect = initialize_pad_batch(
                 self._cfg.model.observation_shape,
                 self._cfg.collector_env_num,
                 self._cfg.device,
                 pad_token_id=self.pad_token_id
             )
-            self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)]
+            self.last_batch_action_collect = [-1 for _ in range(self._cfg.collector_env_num)]


         # We must handle both single int and list of ints for env_id.
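`initialize_pad_batch` is defined elsewhere in the repository; the following is a plausible reimplementation inferred only from the call sites in this diff and the inline initialization in `_init_collect`. The real helper may differ, for example in how it derives the image resolution:

```python
import torch

def initialize_pad_batch(observation_shape, env_num, device, pad_token_id=0):
    """Hypothetical sketch: build an 'empty' previous-observation batch."""
    if isinstance(observation_shape, (tuple, list)):
        # Image-like observations: zero-filled [env_num, C, H, W].
        return torch.zeros([env_num, *observation_shape]).to(device)
    # Flat observations (e.g. token ids): filled with the pad token.
    return torch.full([env_num, observation_shape], fill_value=pad_token_id).to(device)

# Usage mirroring the reset above (illustrative values):
last_batch_obs_collect = initialize_pad_batch((3, 64, 64), env_num=8, device='cpu')
last_batch_action_collect = [-1 for _ in range(8)]
```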
@@ -1696,7 +1696,7 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_
                     self._cfg.device,
                     pad_token_id=self.pad_token_id
                 )
-                logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape)
+                logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:{self.last_batch_obs_eval.shape}')

            else:
                self.last_batch_obs_eval = initialize_pad_batch(
@@ -1705,9 +1705,9 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_
                     self._cfg.device,
                     pad_token_id=self.pad_token_id
                 )
-                logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape)
+                logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:{self.last_batch_obs_eval.shape}')

-            self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)]
+            self.last_batch_action_eval = [-1 for _ in range(self._cfg.evaluator_env_num)]

         # This logic handles the crucial end-of-episode cache clearing for evaluation.
         # The evaluator calls `_policy.reset([env_id])` when an episode is done.
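The logging change in the two `_reset_eval` hunks fixes a separate bug: the extra positional argument to `logging.info` is treated as a %-format argument, and since the message contains no placeholder the shape is never printed (the logging module reports a formatting error instead). A standalone repro of the before/after pattern:

```python
import logging
import torch

logging.basicConfig(level=logging.INFO)
shape = torch.zeros(3, 25).shape

# Old pattern (deleted lines): the second argument has no %-placeholder to fill,
# so the handler emits '--- Logging error ---' to stderr instead of the shape.
logging.info('after _reset_eval: last_batch_obs_eval:', shape)

# New pattern (added lines): interpolate the value directly into the f-string.
logging.info(f'after _reset_eval: last_batch_obs_eval:{shape}')
```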