51 commits
da4cdba  feature(pu): add unizero/muzero multitask pipeline and net plasticity… (Apr 25, 2025)
a6eed25  fix(pu): fix some adaptation bug (Apr 25, 2025)
67a0e9a  feature(pu): add unizero multitask balance pipeline for atari and dmc (Apr 29, 2025)
f083096  fix(pu): fix some adaptation bug (Apr 29, 2025)
37eb118  feature(pu): add vit encoder for unizero (Apr 29, 2025)
f32d63e  polish(pu): polish moe layer in transformer (May 1, 2025)
c0aa747  feature(pu): add eval norm mean/medium for atari (May 5, 2025)
8b3cff6  fix(pu): fix atari norm mean/median, fix collect in balance pipeline (May 7, 2025)
f2c158b  polish(pu): polish config (May 7, 2025)
20b42f7  fix(pu): fix dmc multitask to be compatiable with timestep (which is … (May 7, 2025)
39ee55e  polish(pu): polish config (May 13, 2025)
e85c449  fix(pu): fix task_id bug in balance pipeline, and polish benchmark_na… (May 14, 2025)
c16d564  fix(pu): fix benchmark_name option (May 14, 2025)
474b81c  polish(pu): fix norm score computation, adapt config to aliyun (May 21, 2025)
50e367e  polish(pu): polish unizero_mt balance pipeline use CurriculumControll… (May 23, 2025)
9171c3e  tmp (May 30, 2025)
bc5003a  Merge branch 'dev-multitask-balance-clean' of https://github.com/open… (May 30, 2025)
158e4a0  tmp (Jun 1, 2025)
d66b986  tmp (Jun 4, 2025)
0d5ede0  test(pu): add vit moe test (Jun 5, 2025)
ca6ddb6  polish(pu): add adapter_scales to tb (Jun 11, 2025)
7dd6c04  feature(pu): add atari uz balance config (Jun 12, 2025)
c8e7cb8  polish(pu): add stable_adaptor_scale (Jun 19, 2025)
0313335  tmp (Jun 23, 2025)
ef170fd  sync code (Jun 25, 2025)
bbec353  polish(pu): use freeze_non_lora_parameters in transformer, not use Le… (zjowowen, Jul 30, 2025)
20648d5  feature(pu): add vit-encoder lora in balance pipeline (zjowowen, Jul 30, 2025)
db6032a  polish(pu): fix reanalyze index bug, fix global_solved bug, add apply… (Aug 5, 2025)
f63b544  polish(pu): add collect/eval_num_simulations option (Aug 5, 2025)
bbbe505  polish(pu): polish comments and style in entry of scalezero (puyuan1996, Sep 28, 2025)
bf9f965  polish(pu): polish comments and style of ctree/tree_search/buffer/com… (puyuan1996, Sep 28, 2025)
fb04c7a  polish(pu): polish comments and style of files in lzero.model (puyuan1996, Sep 28, 2025)
06148e7  polish(pu): polish comments and style of files in lzero.model.unizero… (puyuan1996, Sep 28, 2025)
471ae6a  polish(pu): polish comments and style of unizero_world_models (puyuan1996, Sep 28, 2025)
07933a5  polish(pu): polish comments and style of files in policy/ (puyuan1996, Sep 28, 2025)
df3b644  polish(pu): polish comments and style of files in worker (puyuan1996, Sep 28, 2025)
4f89dcc  polish(pu): polish comments and style of files in configs (puyuan1996, Sep 28, 2025)
e7a8796  Merge remote-tracking branch 'origin/main' into dev-multitask-balance… (puyuan1996, Sep 28, 2025)
ab746d1  fix(pu): fix some merge typo (tAnGjIa520, Sep 28, 2025)
0476aca  fix(pu): fix ln norm_type, fix kv_cache rewrite bug, add value_priori… (tAnGjIa520, Sep 28, 2025)
2c0a965  fix(pu): fix unizero_mt (tAnGjIa520, Sep 28, 2025)
84e6094  polish(pu): add LN in head, polish init_weight, polish adamw (tAnGjIa520, Sep 29, 2025)
05da638  fix(pu): fix configure_optimizer_unizero in unizero_mt (tAnGjIa520, Oct 2, 2025)
06ad080  feature(pu): add encoder-clip, label smooth, analyze_latent_represent… (tAnGjIa520, Oct 9, 2025)
9f69f5a  feature(pu): add encoder-clip, label smooth option in unizero_multit… (tAnGjIa520, Oct 9, 2025)
af99278  fix(pu): fix tb log when gpu_num<task_num, fix total_loss += bug, polish (tAnGjIa520, Oct 9, 2025)
bf91ca2  polish(pu):polish config (tAnGjIa520, Oct 9, 2025)
b18f892  fix(pu): fix encoder-clip bug and num_channel/res bug (tAnGjIa520, Oct 11, 2025)
bf3cd12  polish(pu): polish scale_factor in DPS (tAnGjIa520, Oct 12, 2025)
b1efa60  tmp (tAnGjIa520, Oct 18, 2025)
c2f9817  feature(pu): add some analysis metrics in tensorboard for unizero and… (tAnGjIa520, Oct 23, 2025)
2 changes: 1 addition & 1 deletion .gitignore
@@ -1450,4 +1450,4 @@ events.*
!/assets/pooltool/**
lzero/mcts/ctree/ctree_alphazero/pybind11

zoo/jericho/envs/z-machine-games-master
zoo/jericho/envs/z-machine-games-master
6 changes: 5 additions & 1 deletion lzero/entry/__init__.py
@@ -10,4 +10,8 @@
from .train_rezero import train_rezero
from .train_unizero import train_unizero
from .train_unizero_segment import train_unizero_segment
from .utils import *
from .train_muzero_multitask_segment_ddp import train_muzero_multitask_segment_ddp
from .train_unizero_multitask_segment_ddp import train_unizero_multitask_segment_ddp
from .train_unizero_multitask_segment_eval import train_unizero_multitask_segment_eval
from .train_unizero_multitask_balance_segment_ddp import train_unizero_multitask_balance_segment_ddp
from .utils import *
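
The hunk above re-exports the new multitask entry points alongside the existing single-task ones, so they can be imported directly from lzero.entry. A minimal sketch of the resulting imports (only the names are taken from this diff; the functions' signatures are not shown here, so no call arguments are assumed):

# Sketch: the multitask entry points added in this PR, importable from the
# package root after the __init__.py change above.
from lzero.entry import (
    train_muzero_multitask_segment_ddp,
    train_unizero_multitask_segment_ddp,
    train_unizero_multitask_segment_eval,
    train_unizero_multitask_balance_segment_ddp,
)
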
563 changes: 563 additions & 0 deletions lzero/entry/train_muzero_multitask_segment_ddp.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions lzero/entry/train_unizero.py
@@ -136,6 +136,9 @@ def train_unizero(
else:
world_size = 1
rank = 0
# TODO: for visualize
# stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
# import sys; sys.exit(0)

while True:
# Log memory usage of the replay buffer
548 changes: 548 additions & 0 deletions lzero/entry/train_unizero_multitask_balance_segment_ddp.py

Large diffs are not rendered by default.

890 changes: 890 additions & 0 deletions lzero/entry/train_unizero_multitask_segment_ddp.py

Large diffs are not rendered by default.

408 changes: 408 additions & 0 deletions lzero/entry/train_unizero_multitask_segment_eval.py

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion lzero/entry/train_unizero_segment.py
@@ -154,7 +154,9 @@ def train_unizero_segment(
collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)

# Evaluate policy performance
if evaluator.should_eval(learner.train_iter):
# if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter):
if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter):

stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
845 changes: 706 additions & 139 deletions lzero/entry/utils.py

Large diffs are not rendered by default.

293 changes: 170 additions & 123 deletions lzero/mcts/buffer/game_buffer.py

Large diffs are not rendered by default.

90 changes: 59 additions & 31 deletions lzero/mcts/buffer/game_buffer_muzero.py
@@ -61,6 +61,18 @@ def __init__(self, cfg: dict):
self.sample_times = 0
self.active_root_num = 0

if hasattr(self._cfg, 'task_id'):
self.task_id = self._cfg.task_id
print(f"Task ID is set to {self.task_id}.")
try:
self.action_space_size = self._cfg.model.action_space_size_list[self.task_id]
except Exception as e:
self.action_space_size = self._cfg.model.action_space_size

else:
self.task_id = None
print("No task_id found in configuration. Task ID is set to None.")
self.action_space_size = self._cfg.model.action_space_size
self.value_support = DiscreteSupport(*self._cfg.model.value_support_range)
self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range)
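
The added __init__ block above resolves a per-task action space when the config carries a task_id, and falls back to the shared scalar otherwise. A compact sketch of that lookup under the same assumptions (the helper name and the getattr-based fallback are illustrative; the PR itself uses a try/except at this point):

# Sketch of the per-task action-space resolution in the hunk above.
def resolve_action_space_size(model_cfg, task_id=None):
    # Prefer the per-task list when a task_id and the list are both present.
    sizes = getattr(model_cfg, 'action_space_size_list', None)
    if task_id is not None and sizes is not None:
        return sizes[task_id]
    # Otherwise fall back to the single shared action space size.
    return model_cfg.action_space_size
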

@@ -149,7 +161,7 @@ def sample(
self.compute_target_re_time += self._compute_target_timer.value

batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed(
policy_non_re_context, self._cfg.model.action_space_size
policy_non_re_context, self.action_space_size
)

# fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies
@@ -469,17 +481,21 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
end_index = self._cfg.mini_infer_size * (i + 1)
m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device)
# calculate the target value
m_output = model.initial_inference(m_obs)

if not model.training:
# if not in training, obtain the scalars of the value/reward
[m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
[
m_output.latent_state,
inverse_scalar_transform(m_output.value, self.value_support),
m_output.policy_logits
]
)
if self.task_id is not None:
m_output = model.initial_inference(m_obs, task_id=self.task_id)
else:
m_output = model.initial_inference(m_obs)


# if not model.training:
# if not in training, obtain the scalars of the value/reward
[m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
[
m_output.latent_state,
inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
m_output.policy_logits
]
)

network_output.append(m_output)
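
This hunk and the policy-reanalyze hunk below both branch on task_id before calling initial_inference. A small sketch of that dispatch written once as a helper (the helper itself is an assumption; the PR uses an explicit if/else at each call site):

# Sketch: forward task_id to initial_inference only when it is set.
def run_initial_inference(model, obs, task_id=None):
    if task_id is not None:
        return model.initial_inference(obs, task_id=task_id)
    return model.initial_inference(obs)
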

@@ -594,25 +610,28 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
beg_index = self._cfg.mini_infer_size * i
end_index = self._cfg.mini_infer_size * (i + 1)
m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device)
m_output = model.initial_inference(m_obs)

if not model.training:
# if not in training, obtain the scalars of the value/reward
[m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
[
m_output.latent_state,
inverse_scalar_transform(m_output.value, self.value_support),
m_output.policy_logits
]
)
if self.task_id is not None:
m_output = model.initial_inference(m_obs, task_id=self.task_id)
else:
m_output = model.initial_inference(m_obs)

# if not model.training:
# if not in training, obtain the scalars of the value/reward
[m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
[
m_output.latent_state,
inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
m_output.policy_logits
]
)

network_output.append(m_output)

_, reward_pool, policy_logits_pool, latent_state_roots = concat_output(network_output, data_type='muzero')
reward_pool = reward_pool.squeeze().tolist()
policy_logits_pool = policy_logits_pool.tolist()
noises = [
np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size
np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self.action_space_size
).astype(np.float32).tolist() for _ in range(transition_batch_size)
]
if self._cfg.mcts_ctree:
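
The noise generation above now uses the per-task action space. As a worked illustration (alpha = 0.3 and a 6-action space are hypothetical values, not taken from this diff), each root receives an independent Dirichlet sample that is later mixed into its prior policy:

import numpy as np

# Sketch: root exploration noise for two roots of a hypothetical 6-action task.
alpha, action_space_size, num_roots = 0.3, 6, 2
noises = [
    np.random.dirichlet([alpha] * action_space_size).astype(np.float32).tolist()
    for _ in range(num_roots)
]
# Each sample is a length-6 probability vector (sums to 1).
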
@@ -624,7 +643,11 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play)
# do MCTS for a new policy with the recent target model
with self._origin_search_timer:
MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play)
if self.task_id is not None:
MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id)
else:
MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play)

self.origin_search_time += self._origin_search_timer.value
else:
# python mcts_tree
@@ -634,7 +657,11 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
else:
roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play)
# do MCTS for a new policy with the recent target model
MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play)
if self.task_id is not None:
MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id)
else:
MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play)


roots_legal_actions_list = legal_actions
roots_distributions = roots.get_distributions()
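
The ctree and ptree branches above repeat the same task_id check around the MCTSCtree/MCTSPtree search calls. An alternative sketch that forwards the optional keyword from one place (a suggestion, not the PR's code; it assumes both search implementations accept the same optional task_id keyword, as the calls above indicate):

# Sketch: generic wrapper so both ctree and ptree call sites share one path.
def search_with_optional_task_id(searcher, roots, model, latent_roots, to_play, task_id=None):
    kwargs = {'task_id': task_id} if task_id is not None else {}
    return searcher.search(roots, model, latent_roots, to_play, **kwargs)
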
@@ -650,7 +677,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:

if policy_mask[policy_index] == 0:
# NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0
target_policies.append([0 for _ in range(self._cfg.model.action_space_size)])
target_policies.append([0 for _ in range(self.action_space_size)])
else:
# NOTE: It is very important to use the latest MCTS visit count distribution.
sum_visits = sum(distributions)
@@ -659,7 +686,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
if distributions is None:
# if at some obs, the legal_action is None, add the fake target_policy
target_policies.append(
list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size)
list(np.ones(self.action_space_size) / self.action_space_size)
)
else:
# Update the data in game segment:
@@ -676,7 +703,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
target_policies.append(policy)
else:
# for board games that have two players and legal_actions is dy
policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)]
policy_tmp = [0 for _ in range(self.action_space_size)]
# to make sure target_policies have the same dimension
sum_visits = sum(distributions)
policy = [visit_count / sum_visits for visit_count in distributions]
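
The branches above produce three kinds of target policies: an all-zero vector for invalid padded positions (so the corresponding cross-entropy loss is 0), a uniform distribution when MCTS visit counts are unavailable, and the normalized visit counts otherwise, all sized by the per-task action space. A compact sketch of those three cases (the helper is illustrative, not the PR's code):

import numpy as np

# Sketch of the three target-policy cases handled in the hunk above.
def make_target_policy(distributions, action_space_size, valid):
    if not valid:
        # Padded position: zeros make the cross-entropy term vanish.
        return [0.0] * action_space_size
    if distributions is None:
        # No MCTS result: fall back to a uniform policy.
        return list(np.ones(action_space_size) / action_space_size)
    # Normal case: normalize the visit counts from the reanalyzed search.
    sum_visits = sum(distributions)
    return [count / sum_visits for count in distributions]
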
@@ -705,7 +732,7 @@ def _compute_target_policy_non_reanalyzed(
- game_segment_lens
- action_mask_segment
- to_play_segment
- policy_shape: self._cfg.model.action_space_size
- policy_shape: self.action_space_size
Returns:
- batch_target_policies_non_re
"""
@@ -728,7 +755,7 @@ def _compute_target_policy_non_reanalyzed(
]
# NOTE: in continuous action space env: we set all legal_actions as -1
legal_actions = [
[-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
[-1 for _ in range(self.action_space_size)] for _ in range(transition_batch_size)
]
else:
legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
@@ -778,6 +805,7 @@ def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -
NOTE:
train_data = [current_batch, target_batch]
current_batch = [obs_list, action_list, improved_policy_list(only in Gumbel MuZero), mask_list, batch_index_list, weights, make_time_list]
target_batch = [batch_rewards, batch_target_values, batch_target_policies]
"""
indices = train_data[0][-3]
metas = {'make_time': train_data[0][-1], 'batch_priorities': batch_priorities}
Expand Down