
Commit d9957ef

Authored by puyuan1996, dyyoungg, ruiheng123, puyuan, and PaParaZz1
feature(pu/wrh): add RoPE that uses the true timestep as pos_index (#266)
* feature(pu): add rope in unizero's transformer
* feature(wrh): add RoPE for unizero
* fix(pu): use true timestep index in rope
* polish(pu): add rope_embed support for cartpole
* polish(pu): rename step_index to timestep
* fix(pu): fix start_pos *2 bug
* fix(pu): fix rope option in obs/act case
* feature(pu): adapt rope to dmc in continuous action space
* fix(pu): fix start_pos bug in obs_embeddings
* fix(pu): fix rope pos_index in training phase and add readme for pos_embed
* fix(pu): fix rope pos_index in reanalyze_phase
* fix(pu): fix rope pos_index in recurrent_inference
* fix(pu): fix rope pos_index to use search_depth in recurrent_inference and reanalyze_phase
* polish(pu): polish configs
* polish(pu): polish rope start_pos and config
* polish(pu): polish variable names, redundant code, and comment style in rope
* polish(pu): delete wrongly added file
* polish(pu): polish unizero_world_model readme

---------

Co-authored-by: dyyoungg <[email protected]>
Co-authored-by: wrh12345 <[email protected]>
Co-authored-by: puyuan <[email protected]>
Co-authored-by: PaParaZz1 <[email protected]>
1 parent cdde886 commit d9957ef

27 files changed (+911, -308 lines)
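The thread running through this commit is that UniZero's transformer applies rotary position embeddings (RoPE) indexed by the true environment timestep (carried through as start_pos / pos_index) rather than by a token's offset inside the current context window. The snippet below is a minimal, generic sketch of RoPE driven by explicit timestep indices; the function names, shapes, and cache layout are illustrative assumptions and do not mirror LightZero's world-model code.

```python
# Minimal sketch of rotary position embedding (RoPE) driven by explicit
# timestep indices instead of positions within the current window.
# Illustrative only; names and shapes do not mirror LightZero's implementation.
import torch


def build_rope_cache(head_dim: int, max_timestep: int, base: float = 10000.0):
    # Per channel-pair frequencies: theta_i = base^(-2i/d)
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_timestep).float()            # absolute timesteps 0..T-1
    angles = torch.outer(t, inv_freq)                 # (T, head_dim/2)
    return torch.cos(angles), torch.sin(angles)


def apply_rope(x: torch.Tensor, timesteps: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # x: (batch, seq_len, head_dim); timesteps: (batch, seq_len) true env timesteps,
    # e.g. start_pos + offset, so the same token keeps the same angle across calls.
    cos_t, sin_t = cos[timesteps], sin[timesteps]      # (batch, seq_len, head_dim/2)
    x1, x2 = x[..., 0::2], x[..., 1::2]                # split channels into pairs
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos_t - x2 * sin_t
    out[..., 1::2] = x1 * sin_t + x2 * cos_t
    return out


if __name__ == "__main__":
    cos, sin = build_rope_cache(head_dim=64, max_timestep=512)
    q = torch.randn(2, 10, 64)
    # Two sequences starting at different true timesteps (e.g. start_pos = 0 and 100).
    timesteps = torch.stack([torch.arange(0, 10), torch.arange(100, 110)])
    print(apply_rope(q, timesteps, cos, sin).shape)    # torch.Size([2, 10, 64])
```

Indexing by absolute timestep keeps a token's rotation angle stable across collection, recurrent inference, and reanalysis, which is the consistency the pos_index fixes in the commit message are aimed at.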

lzero/entry/train_unizero.py

Lines changed: 2 additions & 2 deletions
@@ -163,7 +163,7 @@ def train_unizero(
         collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
 
         # Evaluate policy performance
-        if evaluator.should_eval(learner.train_iter):
+        if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter):
             logging.info(f"Training iteration {learner.train_iter}: Starting evaluation...")
             stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
             logging.info(f"Training iteration {learner.train_iter}: Evaluation completed, stop condition: {stop}, current reward: {reward}")
@@ -209,7 +209,7 @@ def train_unizero(
         # Execute multiple training rounds
         for i in range(update_per_collect):
             train_data = replay_buffer.sample(batch_size, policy)
-            if cfg.policy.reanalyze_ratio > 0 and i % 20 == 0:
+            if replay_buffer._cfg.reanalyze_ratio > 0 and i % 20 == 0:
                 policy.recompute_pos_emb_diff_and_clear_cache()
 
             if cfg.policy.use_wandb:
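Read in isolation, the two guards changed above do the following: an evaluation now always runs at training iteration 0, before the first parameter update, and the periodic clearing of the positional-embedding cache is gated by the replay buffer's own reanalyze_ratio. The stub below is only a hedged restatement of that control flow with made-up stand-ins for the learner, evaluator, and buffer objects.

```python
# Hedged restatement of the two guards changed above; names are illustrative.
class StubEvaluator:
    def __init__(self, eval_freq: int = 100):
        self.eval_freq = eval_freq

    def should_eval(self, train_iter: int) -> bool:
        # Periodic schedule that, by itself, would skip iteration 0.
        return train_iter > 0 and train_iter % self.eval_freq == 0


def should_run_eval(train_iter: int, evaluator: StubEvaluator) -> bool:
    # Mirrors `learner.train_iter == 0 or evaluator.should_eval(...)`:
    # a baseline evaluation always happens before the first update.
    return train_iter == 0 or evaluator.should_eval(train_iter)


def cache_clear_steps(update_per_collect: int, buffer_reanalyze_ratio: float) -> list:
    # Mirrors the second hunk: when the buffer-side reanalyze_ratio is positive,
    # the positional-embedding cache is cleared every 20 update rounds.
    return [i for i in range(update_per_collect)
            if buffer_reanalyze_ratio > 0 and i % 20 == 0]


if __name__ == "__main__":
    ev = StubEvaluator(eval_freq=100)
    print([should_run_eval(it, ev) for it in (0, 50, 100)])   # [True, False, True]
    print(cache_clear_steps(45, buffer_reanalyze_ratio=0.25))  # [0, 20, 40]
```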

lzero/entry/train_unizero_segment.py

Lines changed: 1 addition & 1 deletion
@@ -175,7 +175,7 @@ def train_unizero_segment(
             reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq
         else:
             # Reanalyze buffer each <1/buffer_reanalyze_freq> train_epoch
-            if train_epoch % int(1/cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition):
+            if train_epoch > 0 and train_epoch % int(1/cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition):
                 with timer:
                     # Each reanalyze process will reanalyze <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence)
                     replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
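The condition above can be restated as a small pure function to make the arithmetic easier to check. The parameter names below are illustrative; the threshold logic is copied from the hunk, and the train_epoch > 0 clause is the actual fix (epoch 0 would otherwise satisfy train_epoch % k == 0 before any data worth reanalyzing exists).

```python
# Hedged, self-contained restatement of the reanalyze trigger above.
def should_reanalyze(
    train_epoch: int,
    buffer_reanalyze_freq: float,   # e.g. 0.5 -> reanalyze every 2 epochs
    num_transitions: int,           # replay_buffer.get_num_of_transitions()
    num_unroll_steps: int,
    reanalyze_batch_size: int,
    reanalyze_partition: float,
) -> bool:
    epoch_interval = int(1 / buffer_reanalyze_freq)
    enough_data = num_transitions // num_unroll_steps > int(reanalyze_batch_size / reanalyze_partition)
    # `train_epoch > 0` is the change in this commit: skip the spurious trigger at epoch 0.
    return train_epoch > 0 and train_epoch % epoch_interval == 0 and enough_data


if __name__ == "__main__":
    # Epoch 0 no longer triggers even though 0 % k == 0.
    print(should_reanalyze(0, 0.5, 10000, 10, 160, 0.75))   # False
    print(should_reanalyze(2, 0.5, 10000, 10, 160, 0.75))   # True
```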

lzero/entry/utils.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def initialize_zeros_batch(observation_shape: Union[int, List[int], Tuple[int]],
     elif isinstance(observation_shape, int):
         shape = [batch_size, observation_shape]
     else:
-        raise TypeError("observation_shape must be either an int or a list")
+        raise TypeError(f"observation_shape must be either an int, a list, or a tuple, but got {type(observation_shape).__name__}")
 
     return torch.zeros(shape).to(device)
 
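For context, here is a simplified re-implementation of the helper whose error message is improved above. Only the elif/else branches appear in the hunk, so the list/tuple branch below is an assumption about the surrounding code rather than a verbatim copy.

```python
# Simplified sketch of the zero-batch helper touched above; the real function
# lives in lzero/entry/utils.py and may differ outside the shown hunk.
from typing import List, Tuple, Union

import torch


def initialize_zeros_batch(
    observation_shape: Union[int, List[int], Tuple[int]],
    batch_size: int,
    device: str,
) -> torch.Tensor:
    if isinstance(observation_shape, (list, tuple)):   # assumed branch, not in the hunk
        shape = [batch_size, *observation_shape]
    elif isinstance(observation_shape, int):
        shape = [batch_size, observation_shape]
    else:
        raise TypeError(
            f"observation_shape must be either an int, a list, or a tuple, "
            f"but got {type(observation_shape).__name__}"
        )
    return torch.zeros(shape).to(device)


if __name__ == "__main__":
    print(initialize_zeros_batch((3, 64, 64), batch_size=8, device="cpu").shape)  # torch.Size([8, 3, 64, 64])
    print(initialize_zeros_batch(4, batch_size=8, device="cpu").shape)            # torch.Size([8, 4])
```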

lzero/mcts/buffer/game_buffer_sampled_unizero.py

Lines changed: 21 additions & 11 deletions
@@ -197,11 +197,11 @@ def sample(
         )
 
         # current_batch = [
-        #     obs_list, action_list, root_sampled_actions_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list
+        #     obs_list, action_list, root_sampled_actions_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list
         # ]
         # target reward, target value
         batch_rewards, batch_target_values = self._compute_target_reward_value(
-            reward_value_context, policy._target_model, current_batch[3]  # current_batch[3] is batch_target_action
+            reward_value_context, policy._target_model, current_batch[3], current_batch[-1]  # current_batch[3] is batch_target_action
         )
 
         batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed(
@@ -250,6 +250,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
         batch_size = len(batch_index_list)
         obs_list, action_list, mask_list = [], [], []
         root_sampled_actions_list = []
+        timestep_list = []
         bootstrap_action_list = []
 
         # prepare the inputs of a batch
@@ -259,7 +260,8 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
 
             actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment +
                                               self._cfg.num_unroll_steps].tolist()
-
+            timestep_tmp = game.timestep_segment[pos_in_game_segment:pos_in_game_segment +
+                                                 self._cfg.num_unroll_steps].tolist()
             # NOTE: self._cfg.num_unroll_steps + 1
             root_sampled_actions_tmp = game.root_sampled_actions[pos_in_game_segment:pos_in_game_segment +
                                                                  self._cfg.num_unroll_steps + 1]
@@ -297,6 +299,10 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
                 reshape=reshape
             )
 
+            # TODO: check the effect
+            timestep_tmp += [
+                0 for _ in range(self._cfg.num_unroll_steps - len(timestep_tmp))
+            ]
             # obtain the input observations
             # pad if length of obs in game_segment is less than stack+num_unroll_steps
             # e.g. stack+num_unroll_steps = 4+5
@@ -309,6 +315,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
             root_sampled_actions_list.append(root_sampled_actions_tmp)
 
             mask_list.append(mask_tmp)
+            timestep_list.append(timestep_tmp)
 
             # NOTE: for unizero
             bootstrap_action_tmp = game.action_segment[pos_in_game_segment+self._cfg.td_steps:pos_in_game_segment +
@@ -329,7 +336,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
         # ==============================================================
         # formalize the inputs of a batch
         current_batch = [
-            obs_list, action_list, root_sampled_actions_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list
+            obs_list, action_list, root_sampled_actions_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list
         ]
         for i in range(len(current_batch)):
             current_batch[i] = np.asarray(current_batch[i])
@@ -397,13 +404,16 @@ def _prepare_policy_reanalyzed_context(
 
         # for board games
         action_mask_segment, to_play_segment = [], []
+        timestep_segment = []
+
        for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list):
             game_segment_len = len(game_segment)
             game_segment_lens.append(game_segment_len)
             rewards.append(game_segment.reward_segment)
             # for board games
             action_mask_segment.append(game_segment.action_mask_segment)
             to_play_segment.append(game_segment.to_play_segment)
+            timestep_segment.append(game_segment.timestep_segment)
             child_visits.append(game_segment.child_visit_segment)
             root_sampled_actions.append(game_segment.root_sampled_actions)
 
@@ -425,11 +435,11 @@ def _prepare_policy_reanalyzed_context(
 
         policy_re_context = [
             policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, root_sampled_actions, root_values, game_segment_lens,
-            action_mask_segment, to_play_segment
+            action_mask_segment, to_play_segment, timestep_segment
         ]
         return policy_re_context
 
-    def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any, action_batch) -> np.ndarray:
+    def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any, batch_action) -> np.ndarray:
         """
         Overview:
             prepare policy targets from the reanalyzed context of policies
@@ -444,7 +454,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
 
         # for board games
         policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, root_sampled_actions, root_values, game_segment_lens, action_mask_segment, \
-            to_play_segment = policy_re_context  # noqa
+            to_play_segment, timestep_segment = policy_re_context  # noqa
         transition_batch_size = len(policy_obs_list)
         game_segment_batch_size = len(pos_in_game_segment_list)
 
@@ -474,9 +484,9 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
 
         # =============== NOTE: The key difference with MuZero =================
         # calculate the target value
-        # action_batch.shape (32, 10)
+        # batch_action.shape (32, 10)
         # batch_obs.shape torch.Size([352, 3, 64, 64]) 32*11=352
-        m_output = model.initial_inference(batch_obs, action_batch[:self.reanalyze_num])  # NOTE: :self.reanalyze_num
+        m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num])  # NOTE: :self.reanalyze_num
         # =======================================================================
 
         # if not in training, obtain the scalars of the value/reward
@@ -592,7 +602,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
         return batch_target_policies_re, root_sampled_actions
 
 
-    def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, action_batch) -> Tuple[
+    def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, batch_action, batch_timestep) -> Tuple[
             Any, Any]:
         """
         Overview:
@@ -634,7 +644,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
         # =============== NOTE: The key difference with MuZero =================
         # calculate the target value
         # batch_obs.shape torch.Size([352, 3, 64, 64]) 32*11 = 352
-        m_output = model.initial_inference(batch_obs, action_batch)
+        m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep)
         # ======================================================================
 
         # if not in training, obtain the scalars of the value/reward
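The net effect of the buffer changes is a new per-sample timestep array that travels with the batch (timestep_list, i.e. current_batch[-1]) and reaches the world model as start_pos, so RoPE can rotate by absolute timestep during target computation. The sketch below only illustrates that data flow with a stub model; the stub's initial_inference is a placeholder, not LightZero's actual implementation.

```python
# Illustrative sketch (stub model, not LightZero's API) of the data flow added
# in this file: true timesteps collected alongside actions are handed to the
# world model as `start_pos` for reward/value target computation.
import numpy as np
import torch


class StubWorldModel:
    def initial_inference(self, batch_obs: torch.Tensor, batch_action: np.ndarray,
                          start_pos: np.ndarray = None):
        # A real model would rotate queries/keys with angles taken at
        # start_pos + offset; here we only echo the shapes back.
        return {
            "obs": tuple(batch_obs.shape),
            "actions": batch_action.shape,
            "start_pos": None if start_pos is None else start_pos.shape,
        }


if __name__ == "__main__":
    batch_size, num_unroll_steps = 4, 5
    # timestep_list as built in _make_batch: one row of true env timesteps per
    # sample, zero-padded when the sampled segment is shorter than num_unroll_steps.
    timestep_batch = np.array([[7, 8, 9, 0, 0],
                               [0, 1, 2, 3, 4],
                               [42, 43, 44, 45, 46],
                               [3, 4, 5, 6, 7]])
    action_batch = np.zeros((batch_size, num_unroll_steps), dtype=np.int64)
    batch_obs = torch.zeros(batch_size, 3, 64, 64)
    model = StubWorldModel()
    print(model.initial_inference(batch_obs, action_batch, start_pos=timestep_batch))
```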
