37 commits
da4cdba
feature(pu): add unizero/muzero multitask pipeline and net plasticity…
puyuan1996 Apr 25, 2025
a6eed25
fix(pu): fix some adaptation bug
puyuan1996 Apr 25, 2025
67a0e9a
feature(pu): add unizero multitask balance pipeline for atari and dmc
puyuan1996 Apr 29, 2025
f083096
fix(pu): fix some adaptation bug
puyuan1996 Apr 29, 2025
37eb118
feature(pu): add vit encoder for unizero
puyuan1996 Apr 29, 2025
f32d63e
polish(pu): polish moe layer in transformer
puyuan1996 May 1, 2025
d002719
feature(xjy): add multi-task learning pipeline in jericho environment
xiongjyu May 4, 2025
c0aa747
feature(pu): add eval norm mean/medium for atari
puyuan1996 May 5, 2025
8b3cff6
fix(pu): fix atari norm mean/median, fix collect in balance pipeline
puyuan1996 May 7, 2025
f2c158b
polish(pu): polish config
puyuan1996 May 7, 2025
20b42f7
fix(pu): fix dmc multitask to be compatiable with timestep (which is …
puyuan1996 May 7, 2025
39ee55e
polish(pu): polish config
puyuan1996 May 13, 2025
e85c449
fix(pu): fix task_id bug in balance pipeline, and polish benchmark_na…
puyuan1996 May 14, 2025
c16d564
fix(pu): fix benchmark_name option
puyuan1996 May 14, 2025
4adb3dc
Standardized the format and added the ability to use moe in unizero
xiongjyu May 21, 2025
474b81c
polish(pu): fix norm score computation, adapt config to aliyun
puyuan1996 May 21, 2025
50e367e
polish(pu): polish unizero_mt balance pipeline use CurriculumControll…
puyuan1996 May 23, 2025
9171c3e
tmp
puyuan1996 May 30, 2025
bc5003a
Merge branch 'dev-multitask-balance-clean' of https://github.com/open…
puyuan1996 May 30, 2025
158e4a0
tmp
puyuan1996 Jun 1, 2025
285cd77
Merge remote-tracking branch 'origin/dev-multitask-balance-clean' int…
xiongjyu Jun 3, 2025
d66b986
tmp
puyuan1996 Jun 4, 2025
0d5ede0
test(pu): add vit moe test
puyuan1996 Jun 5, 2025
a820eca
fixed a bug in calculating dormant
xiongjyu Jun 9, 2025
ca6ddb6
polish(pu): add adapter_scales to tb
puyuan1996 Jun 11, 2025
7dd6c04
feature(pu): add atari uz balance config
puyuan1996 Jun 12, 2025
c8e7cb8
polish(pu): add stable_adaptor_scale
puyuan1996 Jun 19, 2025
0313335
tmp
puyuan1996 Jun 23, 2025
ef170fd
sync code
puyuan1996 Jun 25, 2025
bbec353
polish(pu): use freeze_non_lora_parameters in transformer, not use Le…
zjowowen Jul 30, 2025
20648d5
feature(pu): add vit-encoder lora in balance pipeline
zjowowen Jul 30, 2025
db6032a
polish(pu): fix reanalyze index bug, fix global_solved bug, add apply…
puyuan1996 Aug 5, 2025
f63b544
polish(pu): add collect/eval_num_simulations option
puyuan1996 Aug 5, 2025
201f5e3
modify the format as required
xiongjyu Sep 23, 2025
1b52f03
Merge remote-tracking branch 'upstream/dev-multitask-balance-clean' i…
xiongjyu Sep 23, 2025
e0a498e
add _log_model_parameters and polish LN
xiongjyu Nov 30, 2025
232679d
Merge remote-tracking branch 'upstream/main' into dev-multitask-clean-v3
xiongjyu Jan 13, 2026
2 changes: 1 addition & 1 deletion .gitignore
@@ -1450,4 +1450,4 @@ events.*
!/assets/pooltool/**
lzero/mcts/ctree/ctree_alphazero/pybind11

zoo/jericho/envs/z-machine-games-master
zoo/jericho/envs/z-machine-games-master
5 changes: 5 additions & 0 deletions lzero/entry/__init__.py
@@ -9,4 +9,9 @@
from .train_rezero import train_rezero
from .train_unizero import train_unizero
from .train_unizero_segment import train_unizero_segment
from .train_muzero_multitask_segment_ddp import train_muzero_multitask_segment_ddp
from .train_unizero_multitask_segment_ddp import train_unizero_multitask_segment_ddp
from .train_unizero_multitask_segment_eval import train_unizero_multitask_segment_eval
from .train_unizero_multitask_ddp import train_unizero_multitask_ddp
from .train_unizero_multitask import train_unizero_multitask
from .utils import *
80 changes: 80 additions & 0 deletions lzero/entry/compute_task_weight.py
@@ -0,0 +1,80 @@



import numpy as np
import torch


def symlog(x: torch.Tensor) -> torch.Tensor:
    """
    Symlog normalization, which reduces the magnitude differences between target values.
    symlog(x) = sign(x) * log(|x| + 1)
    """
    return torch.sign(x) * torch.log(torch.abs(x) + 1)


def inv_symlog(x: torch.Tensor) -> torch.Tensor:
    """
    Inverse of symlog, used to recover the original values.
    inv_symlog(x) = sign(x) * (exp(|x|) - 1)
    """
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)


def compute_task_weights(
    task_rewards: dict,
    epsilon: float = 1e-6,
    min_weight: float = 0.1,
    max_weight: float = 0.5,
    temperature: float = 1.0,
    use_symlog: bool = True,
) -> dict:
    """
    Improved task-weight computation with optional symlog correction and robustness safeguards.

    Args:
        task_rewards (dict): Mapping from task_id to its evaluation reward.
        epsilon (float): Small constant that prevents division by zero.
        min_weight (float): Minimum weight value used for clipping.
        max_weight (float): Maximum weight value used for clipping.
        temperature (float): Temperature coefficient that controls how peaked the weight distribution is.
        use_symlog (bool): Whether to correct task_rewards with symlog before computing weights.

    Returns:
        dict: Mapping from task_id to its normalized and clipped weight.
    """
    # Step 1: optionally correct the reward values with symlog.
    if use_symlog:
        rewards_tensor = torch.tensor(list(task_rewards.values()), dtype=torch.float32)
        corrected_rewards = symlog(rewards_tensor).numpy()  # apply symlog correction
        task_rewards = dict(zip(task_rewards.keys(), corrected_rewards))

    # Step 2: compute the raw weights (inversely proportional to the reward).
    raw_weights = {task_id: 1 / (reward + epsilon) for task_id, reward in task_rewards.items()}

    # Step 3: apply temperature scaling.
    scaled_weights = {task_id: weight ** (1 / temperature) for task_id, weight in raw_weights.items()}

    # Step 4: normalize the weights.
    total_weight = sum(scaled_weights.values())
    normalized_weights = {task_id: weight / total_weight for task_id, weight in scaled_weights.items()}

    # Step 5: clip the weights so that each stays within [min_weight, max_weight].
    clipped_weights = {task_id: np.clip(weight, min_weight, max_weight) for task_id, weight in normalized_weights.items()}

    return clipped_weights

if __name__ == "__main__":
    # Quick sanity check across a few reward scales; guarded so it does not run on import.
    task_rewards_list = [
        {"task1": 10, "task2": 100, "task3": 1000, "task4": 500, "task5": 300},
        {"task1": 1, "task2": 10, "task3": 100, "task4": 1000, "task5": 10000},
        {"task1": 0.1, "task2": 0.5, "task3": 0.9, "task4": 5, "task5": 10},
    ]

    for i, task_rewards in enumerate(task_rewards_list, start=1):
        print(f"Case {i}: Original Rewards: {task_rewards}")
        print("Original Weights:")
        print(compute_task_weights(task_rewards, use_symlog=False))
        print("Improved Weights with Symlog:")
        print(compute_task_weights(task_rewards, use_symlog=True))
        print()
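
A minimal usage sketch (not part of the PR diff): it shows how compute_task_weights behaves on a hypothetical two-task example under the default arguments; the task names are illustrative and the numbers approximate, assuming the repository root is on PYTHONPATH.

# Usage sketch; "task_a" and "task_b" are hypothetical task ids.
from lzero.entry.compute_task_weight import compute_task_weights

rewards = {"task_a": 10.0, "task_b": 100.0}
weights = compute_task_weights(rewards, use_symlog=True)
# symlog(10) ~= 2.40 and symlog(100) ~= 4.62, so the lower-reward task gets the larger weight:
# roughly {"task_a": 0.5 (clipped down from ~0.66), "task_b": ~0.34}.
# Note that clipping is applied after normalization, so the final weights need not sum to 1.
print(weights)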