@@ -197,11 +197,11 @@ def sample(
         )
 
         # current_batch = [
-        #   obs_list, action_list, root_sampled_actions_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list
+        #   obs_list, action_list, root_sampled_actions_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list
        # ]
        # target reward, target value
        batch_rewards, batch_target_values = self._compute_target_reward_value(
-            reward_value_context, policy._target_model, current_batch[3]  # current_batch[3] is batch_target_action
+            reward_value_context, policy._target_model, current_batch[3], current_batch[-1]  # current_batch[3] is batch_target_action
        )
 
        batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed(
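
For reference, a minimal standalone sketch of the `current_batch` index convention that the call above relies on (the lists here are empty placeholders; the real ones are built in `_make_batch`): `current_batch[3]` is `bootstrap_action_list`, and after this change `current_batch[-1]` is the newly appended `timestep_list`.

    # Sketch only: the layout of current_batch assumed by the call above.
    obs_list, action_list, root_sampled_actions_list = [], [], []
    bootstrap_action_list, mask_list, batch_index_list = [], [], []
    weights_list, make_time_list, timestep_list = [], [], []
    current_batch = [
        obs_list, action_list, root_sampled_actions_list,
        bootstrap_action_list,  # current_batch[3] -> passed as the bootstrap/target actions
        mask_list, batch_index_list, weights_list, make_time_list,
        timestep_list,          # current_batch[-1] -> passed as the batch timesteps
    ]
    assert current_batch[3] is bootstrap_action_list
    assert current_batch[-1] is timestep_list
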
@@ -250,6 +250,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
         batch_size = len(batch_index_list)
         obs_list, action_list, mask_list = [], [], []
         root_sampled_actions_list = []
+        timestep_list = []
         bootstrap_action_list = []
 
         # prepare the inputs of a batch
@@ -259,7 +260,8 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
 
             actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment +
                                               self._cfg.num_unroll_steps].tolist()
-
+            timestep_tmp = game.timestep_segment[pos_in_game_segment:pos_in_game_segment +
+                                                 self._cfg.num_unroll_steps].tolist()
             # NOTE: self._cfg.num_unroll_steps + 1
             root_sampled_actions_tmp = game.root_sampled_actions[pos_in_game_segment:pos_in_game_segment +
                                                                  self._cfg.num_unroll_steps + 1]
@@ -297,6 +299,10 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
                 reshape=reshape
             )
 
+            # TODO: check the effect
+            timestep_tmp += [
+                0 for _ in range(self._cfg.num_unroll_steps - len(timestep_tmp))
+            ]
             # obtain the input observations
             # pad if length of obs in game_segment is less than stack+num_unroll_steps
             # e.g. stack+num_unroll_steps = 4+5
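
When the sampled position lies near the end of a game segment, `timestep_tmp` holds fewer than `num_unroll_steps` entries, so the added lines fill the remainder with zeros. A minimal standalone sketch of that behaviour (the values are illustrative):

    # Sketch only: zero-padding timestep_tmp up to num_unroll_steps entries.
    num_unroll_steps = 5
    timestep_tmp = [97, 98, 99]  # only 3 transitions remain in the segment
    timestep_tmp += [0 for _ in range(num_unroll_steps - len(timestep_tmp))]
    assert timestep_tmp == [97, 98, 99, 0, 0]
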
@@ -309,6 +315,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
             root_sampled_actions_list.append(root_sampled_actions_tmp)
 
             mask_list.append(mask_tmp)
+            timestep_list.append(timestep_tmp)
 
             # NOTE: for unizero
             bootstrap_action_tmp = game.action_segment[pos_in_game_segment + self._cfg.td_steps:pos_in_game_segment +
@@ -329,7 +336,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
         # ==============================================================
         # formalize the inputs of a batch
         current_batch = [
-            obs_list, action_list, root_sampled_actions_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list
+            obs_list, action_list, root_sampled_actions_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list
         ]
         for i in range(len(current_batch)):
             current_batch[i] = np.asarray(current_batch[i])
@@ -397,13 +404,16 @@ def _prepare_policy_reanalyzed_context(
 
         # for board games
         action_mask_segment, to_play_segment = [], []
+        timestep_segment = []
+
         for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list):
             game_segment_len = len(game_segment)
             game_segment_lens.append(game_segment_len)
             rewards.append(game_segment.reward_segment)
             # for board games
             action_mask_segment.append(game_segment.action_mask_segment)
             to_play_segment.append(game_segment.to_play_segment)
+            timestep_segment.append(game_segment.timestep_segment)
             child_visits.append(game_segment.child_visit_segment)
             root_sampled_actions.append(game_segment.root_sampled_actions)
 
@@ -425,11 +435,11 @@ def _prepare_policy_reanalyzed_context(
 
         policy_re_context = [
             policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, root_sampled_actions, root_values, game_segment_lens,
-            action_mask_segment, to_play_segment
+            action_mask_segment, to_play_segment, timestep_segment
         ]
         return policy_re_context
 
-    def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any, action_batch) -> np.ndarray:
+    def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any, batch_action) -> np.ndarray:
         """
         Overview:
             prepare policy targets from the reanalyzed context of policies
@@ -444,7 +454,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
 
         # for board games
         policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, root_sampled_actions, root_values, game_segment_lens, action_mask_segment, \
-            to_play_segment = policy_re_context  # noqa
+            to_play_segment, timestep_segment = policy_re_context  # noqa
         transition_batch_size = len(policy_obs_list)
         game_segment_batch_size = len(pos_in_game_segment_list)
 
@@ -474,9 +484,9 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
 
             # =============== NOTE: The key difference with MuZero =================
             # calculate the target value
-            # action_batch.shape (32, 10)
+            # batch_action.shape (32, 10)
             # batch_obs.shape torch.Size([352, 3, 64, 64]) 32*11=352
-            m_output = model.initial_inference(batch_obs, action_batch[:self.reanalyze_num])  # NOTE: :self.reanalyze_num
+            m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num])  # NOTE: :self.reanalyze_num
             # =======================================================================
 
             # if not in training, obtain the scalars of the value/reward
@@ -592,7 +602,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
         return batch_target_policies_re, root_sampled_actions
 
 
-    def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, action_batch) -> Tuple[
+    def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, batch_action, batch_timestep) -> Tuple[
         Any, Any]:
         """
         Overview:
@@ -634,7 +644,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
             # =============== NOTE: The key difference with MuZero =================
             # calculate the target value
             # batch_obs.shape torch.Size([352, 3, 64, 64]) 32*11 = 352
-            m_output = model.initial_inference(batch_obs, action_batch)
+            m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep)
             # ======================================================================
 
             # if not in training, obtain the scalars of the value/reward
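
For orientation, a standalone sketch of the shapes implied by the comments above (batch size 32 and ten unroll steps are taken from those comments; the arrays here are placeholders, and the final call is the one introduced in this diff):

    import numpy as np

    # Sketch only: illustrative shapes for the new arguments.
    batch_action = np.zeros((32, 10), dtype=np.int64)    # bootstrap/target actions, current_batch[3]
    batch_timestep = np.zeros((32, 10), dtype=np.int64)  # per-transition timesteps, current_batch[-1]
    # m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep)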