8 changes: 6 additions & 2 deletions lzero/mcts/buffer/game_buffer.py
@@ -156,13 +156,17 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:
# For some environments (e.g., Jericho), the action space size may be different.
# To ensure we can always unroll `num_unroll_steps` steps starting from the sampled position (without exceeding segment length),
# we avoid sampling from the last `num_unroll_steps` steps of the game segment.
if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps:
pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps, 1).item()
if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps:
pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps, 1).item()
if pos_in_game_segment >= len(game_segment.action_segment) - 1:
pos_in_game_segment = np.random.choice(len(game_segment.action_segment) - 1, 1).item()
else:
# For environments with a fixed action space (e.g., Atari),
# we can safely sample from the entire game segment range.
if pos_in_game_segment >= self._cfg.game_segment_length:
pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item()
if pos_in_game_segment >= len(game_segment.action_segment) - 1:
pos_in_game_segment = np.random.choice(len(game_segment.action_segment) - 1, 1).item()

pos_in_game_segment_list.append(pos_in_game_segment)

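For clarity, here is a minimal, self-contained sketch of the resampling logic added above: if the sampled position leaves too little room to unroll `num_unroll_steps` steps plus `td_steps` bootstrap steps, or falls past the actions actually stored in the segment, the position is redrawn from the safe prefix. The config values and the `num_actions` argument are illustrative, not taken from the repository.

```python
import numpy as np

def clamp_sample_position(pos, game_segment_length, num_unroll_steps, td_steps, num_actions):
    # Keep enough headroom to unroll `num_unroll_steps` steps plus `td_steps`
    # bootstrap steps without running past the end of the segment.
    safe_limit = game_segment_length - num_unroll_steps - td_steps
    if pos >= safe_limit:
        pos = np.random.choice(safe_limit, 1).item()
    # Also stay within the actions actually recorded for this (possibly shorter) segment.
    if pos >= num_actions - 1:
        pos = np.random.choice(num_actions - 1, 1).item()
    return pos

print(clamp_sample_position(pos=398, game_segment_length=400,
                            num_unroll_steps=10, td_steps=5, num_actions=120))
```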
23 changes: 16 additions & 7 deletions lzero/mcts/buffer/game_buffer_unizero.py
@@ -64,15 +64,15 @@ def sample(
policy._target_model.eval()

# obtain the current_batch and prepare target context
reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch(
reward_value_context, policy_re_context, policy_non_re_context, current_batch, batch_manual_embeds = self._make_batch(
batch_size, self._cfg.reanalyze_ratio
)

# current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list]

# target reward, target value
batch_rewards, batch_target_values = self._compute_target_reward_value(
reward_value_context, policy._target_model, current_batch[2], current_batch[-1] # current_batch[2] is batch_target_action
reward_value_context, policy._target_model, current_batch[2], current_batch[-1], batch_manual_embeds=batch_manual_embeds # current_batch[2] is batch_target_action
)

# target policy
@@ -92,7 +92,7 @@ def sample(
target_batch = [batch_rewards, batch_target_values, batch_target_policies]

# a batch contains the current_batch and the target_batch
train_data = [current_batch, target_batch]
train_data = [current_batch, target_batch, batch_manual_embeds]
return train_data

def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
@@ -120,6 +120,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
obs_list, action_list, mask_list = [], [], []
timestep_list = []
bootstrap_action_list = []
manual_embeds_list = []

# prepare the inputs of a batch
for i in range(batch_size):
@@ -156,6 +157,13 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
)
)
manual_embeds_tmp = game.manual_embeds_segment[pos_in_game_segment:pos_in_game_segment + self._cfg.num_unroll_steps + 1]
manual_embeds_tmp2 = np.array([torch.zeros(self._cfg.model.world_model_cfg.manual_embed_dim) for _ in range(self._cfg.num_unroll_steps +1 - len(manual_embeds_tmp))])
if len(manual_embeds_tmp2) > 0:
manual_embeds_tmp = np.concatenate([manual_embeds_tmp, manual_embeds_tmp2], axis=0)
manual_embeds_list.append(manual_embeds_tmp)

action_list.append(actions_tmp)

mask_list.append(mask_tmp)
@@ -214,7 +222,8 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
else:
policy_non_re_context = None

context = reward_value_context, policy_re_context, policy_non_re_context, current_batch
manual_embeds_array = np.asarray(manual_embeds_list)
context = reward_value_context, policy_re_context, policy_non_re_context, current_batch, manual_embeds_array
return context

def reanalyze_buffer(
@@ -432,7 +441,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
# =============== NOTE: The key difference with MuZero =================
# To obtain the target policy from MCTS guided by the recent target model
# TODO: batch_obs (policy_obs_list) is at timestep t, batch_action is at timestep t
m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], start_pos=batch_timestep[:self.reanalyze_num]) # NOTE: :self.reanalyze_num
m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], start_pos=batch_timestep[:self.reanalyze_num], manual_embeds=self.manual_embeds) # NOTE: :self.reanalyze_num
# =======================================================================

if not model.training:
@@ -514,7 +523,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:

return batch_target_policies_re

def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, batch_action, batch_timestep) -> Tuple[
def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, batch_action, batch_timestep, batch_manual_embeds=None) -> Tuple[
Any, Any]:
"""
Overview:
@@ -540,7 +549,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
# =============== NOTE: The key difference with MuZero =================
# calculate the bootstrapped value and target value
# NOTE: batch_obs(value_obs_list) is at t+td_steps, batch_action is at timestep t+td_steps
m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep)
m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep, manual_embeds=batch_manual_embeds)
# ======================================================================

# if not in training, obtain the scalars of the value/reward
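As a reading aid, here is a minimal sketch of the manual-embedding slicing and zero-padding performed in `_make_batch` above, assuming one fixed-size embedding per step and a tail padded with zeros up to `num_unroll_steps + 1` entries (the 768-dim size is illustrative):

```python
import numpy as np

def slice_and_pad_manual_embeds(manual_embeds_segment, pos, num_unroll_steps, embed_dim):
    # Take the embeddings for the unrolled window starting at `pos`.
    window = np.asarray(manual_embeds_segment[pos:pos + num_unroll_steps + 1])
    # Pad with zeros if the window runs past the end of the segment.
    n_missing = num_unroll_steps + 1 - len(window)
    if n_missing > 0:
        pad = np.zeros((n_missing, embed_dim), dtype=np.float32)
        window = np.concatenate([window, pad], axis=0) if len(window) else pad
    return window  # shape: (num_unroll_steps + 1, embed_dim)

segment = [np.random.randn(768).astype(np.float32) for _ in range(8)]
print(slice_and_pad_manual_embeds(segment, pos=6, num_unroll_steps=5, embed_dim=768).shape)
```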
8 changes: 7 additions & 1 deletion lzero/mcts/buffer/game_segment.py
@@ -69,6 +69,7 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea
self.action_mask_segment = []
self.to_play_segment = []
self.timestep_segment = []
self.manual_embeds_segment = []

self.target_values = []
self.target_rewards = []
@@ -138,6 +139,7 @@ def append(
to_play: int = -1,
timestep: int = 0,
chance: int = 0,
manual_embeds = None,
) -> None:
"""
Overview:
@@ -150,6 +152,7 @@
self.action_mask_segment.append(action_mask)
self.to_play_segment.append(to_play)
self.timestep_segment.append(timestep)
self.manual_embeds_segment.append(manual_embeds)

if self.use_ture_chance_label_in_chance_encoder:
self.chance_segment.append(chance)
@@ -285,6 +288,7 @@ def game_segment_to_array(self) -> None:
self.obs_segment = np.array(self.obs_segment)
self.action_segment = np.array(self.action_segment)
self.reward_segment = np.array(self.reward_segment)
self.manual_embeds_segment = np.array(self.manual_embeds_segment)

# Check if all elements in self.child_visit_segment have the same length
if all(len(x) == len(self.child_visit_segment[0]) for x in self.child_visit_segment):
@@ -305,7 +309,7 @@
if self.use_ture_chance_label_in_chance_encoder:
self.chance_segment = np.array(self.chance_segment)

def reset(self, init_observations: np.ndarray) -> None:
def reset(self, init_observations: np.ndarray, init_manual_embeds = None) -> None:
"""
Overview:
Initialize the game segment using ``init_observations``,
@@ -323,6 +327,7 @@
self.action_mask_segment = []
self.to_play_segment = []
self.timestep_segment = []
self.manual_embeds_segment = []

if self.use_ture_chance_label_in_chance_encoder:
self.chance_segment = []
@@ -331,6 +336,7 @@

for observation in init_observations:
self.obs_segment.append(copy.deepcopy(observation))
self.manual_embeds_segment.append(init_manual_embeds)

def is_full(self) -> bool:
"""
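A hedged usage sketch of the extended `GameSegment` bookkeeping: the manual embedding is supplied once at `reset` and then per step via `append`. The stand-in class below only mirrors the storage pattern shown in the diff; the real class, its constructor arguments, and the 768-dim embedding are not reproduced here.

```python
import numpy as np

class MiniSegment:
    """Illustrative stand-in for GameSegment's manual-embedding bookkeeping."""
    def __init__(self):
        self.obs_segment, self.action_segment, self.reward_segment = [], [], []
        self.manual_embeds_segment = []

    def reset(self, init_observations, init_manual_embeds=None):
        for obs in init_observations:
            self.obs_segment.append(np.copy(obs))
            # The same manual embedding is attached to every stacked initial frame.
            self.manual_embeds_segment.append(init_manual_embeds)

    def append(self, action, obs, reward, manual_embeds=None):
        self.action_segment.append(action)
        self.obs_segment.append(obs)
        self.reward_segment.append(reward)
        self.manual_embeds_segment.append(manual_embeds)

seg = MiniSegment()
manual = np.zeros(768, dtype=np.float32)  # e.g. a frozen text-encoder output
seg.reset(init_observations=[np.zeros((3, 64, 64))] * 4, init_manual_embeds=manual)
seg.append(action=1, obs=np.zeros((3, 64, 64)), reward=0.0, manual_embeds=manual)
print(len(seg.manual_embeds_segment))  # 5
```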
10 changes: 9 additions & 1 deletion lzero/model/common.py
@@ -470,6 +470,8 @@ def __init__(
embedding_dim: int = 256,
group_size: int = 8,
final_norm_option_in_encoder: str = 'LayerNorm', # TODO
use_manual: bool = False,
manual_dim: int = 768
) -> None:
"""
Overview:
@@ -496,8 +498,10 @@ def __init__(
logging.info(f"Using norm type: {norm_type}")
logging.info(f"Using activation type: {activation}")

self.observation_shape = observation_shape
self.observation_shape = observation_shape
self.downsample = downsample
self.use_manual = use_manual

if self.downsample:
self.downsample_net = DownSample(
observation_shape,
@@ -533,6 +537,8 @@ def __init__(

elif self.observation_shape[1] in [84, 96]:
self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False)
elif self.observation_shape[1] == 10:
self.last_linear = nn.Linear(64 * 10 * 10, self.embedding_dim, bias=False)

self.final_norm_option_in_encoder = final_norm_option_in_encoder
if self.final_norm_option_in_encoder == 'LayerNorm':
@@ -542,6 +548,8 @@
else:
raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}")

if use_manual:
self.feature_merge_linear = nn.Linear(self.embedding_dim + manual_dim, self.embedding_dim)

Collaborator review comment:
Shouldn't `self.feature_merge_linear` be followed by the same normalization that is applied to the original `obs_embeddings`?

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
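To make the reviewer's suggestion above concrete, here is a minimal sketch in which the merged features pass through the same kind of final norm as the un-fused observation embeddings. The module and the 256/768 dimensions are illustrative assumptions, not the repository's implementation, and the encoder's forward pass is not shown in this diff.

```python
import torch
import torch.nn as nn

class ManualFusionHead(nn.Module):
    def __init__(self, embedding_dim: int = 256, manual_dim: int = 768):
        super().__init__()
        self.feature_merge_linear = nn.Linear(embedding_dim + manual_dim, embedding_dim)
        # Re-use the same kind of final norm as the plain observation-embedding path.
        self.final_norm = nn.LayerNorm(embedding_dim, eps=1e-5)

    def forward(self, obs_embed: torch.Tensor, manual_embed: torch.Tensor) -> torch.Tensor:
        merged = self.feature_merge_linear(torch.cat([obs_embed, manual_embed], dim=-1))
        return self.final_norm(merged)

head = ManualFusionHead()
print(head(torch.randn(4, 256), torch.randn(4, 768)).shape)  # torch.Size([4, 256])
```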
9 changes: 5 additions & 4 deletions lzero/model/unizero_model.py
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, List

import torch
import torch.nn as nn
@@ -127,7 +127,7 @@ def __init__(
norm_type=norm_type,
embedding_dim=world_model_cfg.embed_dim,
group_size=world_model_cfg.group_size,
final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder
final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder,
)

# ====== for analysis ======
@@ -177,7 +177,7 @@ def __init__(
print('==' * 20)

def initial_inference(self, obs_batch: torch.Tensor, action_batch: Optional[torch.Tensor] = None,
current_obs_batch: Optional[torch.Tensor] = None, start_pos: int = 0) -> MZNetworkOutput:
current_obs_batch: Optional[torch.Tensor] = None, start_pos: int = 0, manual_embeds: List[torch.Tensor] = None) -> MZNetworkOutput:
"""
Overview:
Initial inference of the UniZero model, which is the first step of the UniZero model.
@@ -205,7 +205,8 @@ def initial_inference(self, obs_batch: torch.Tensor, action_batch: Optional[torc
obs_act_dict = {
'obs': obs_batch,
'action': action_batch,
'current_obs': current_obs_batch
'current_obs': current_obs_batch,
'manual_embeds': manual_embeds
}

# Perform initial inference using the world model
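A hedged sketch of how the new `manual_embeds` argument is threaded through `initial_inference`: it simply rides along in `obs_act_dict` and is consumed downstream by the world model. Shapes, the batch of 8 environments, and the 768-dim embedding are illustrative.

```python
import torch

obs_batch = torch.randn(8, 3, 64, 64)       # observations at timestep t
action_batch = torch.randint(0, 6, (8, 1))  # actions at timestep t
manual_embeds = [torch.zeros(768) for _ in range(8)]  # one text embedding per env

obs_act_dict = {
    'obs': obs_batch,
    'action': action_batch,
    'current_obs': None,
    'manual_embeds': manual_embeds,
}
print(sorted(obs_act_dict.keys()))
# In the repo, this dict is what `initial_inference` forwards to
# `world_model.reset_for_initial_inference(obs_act_dict, start_pos=...)`.
```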
56 changes: 53 additions & 3 deletions lzero/model/unizero_world_models/world_model.py
@@ -1,5 +1,5 @@
import logging
from typing import Dict, Union, Optional, List, Tuple, Any
from typing import Dict, Union, Optional, List, Tuple, Any, Set

import numpy as np
import torch
@@ -65,7 +65,14 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
self.precompute_pos_emb_diff_kv()
print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}")


self.continuous_action_space = self.config.continuous_action_space
if self.use_manual:
self.manual_fuse_proj = nn.Linear(self.embed_dim + self.manual_embed_dim, self.embed_dim, bias=False)
# self.manual_fuse_proj = nn.Sequential(
# nn.Linear(self.embed_dim + self.manual_embed_dim, self.embed_dim, bias=False),
# nn.GELU(approximate='tanh')
# )

# Initialize action embedding table
if self.continuous_action_space:
@@ -100,7 +107,7 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
skip_modules = set()
if hasattr(self.tokenizer.encoder, 'pretrained_model'):
skip_modules.update(self.tokenizer.encoder.pretrained_model.modules())
if hasattr(self.tokenizer, 'decoder_network'):
if hasattr(self.tokenizer, 'decoder_network') and self.tokenizer.decoder_network is not None:
skip_modules.update(self.tokenizer.decoder_network.modules())

def custom_init(module):
@@ -292,6 +299,11 @@ def _initialize_config_parameters(self) -> None:
self.obs_per_embdding_dim = self.config.embed_dim
self.sim_norm = SimNorm(simnorm_dim=self.group_size)

# ====== [NEW] manual fusion switch and layers ======
self.use_manual = self.config.use_manual
self.manual_embed_dim = self.config.manual_embed_dim


def _initialize_patterns(self) -> None:
"""Initialize patterns for block masks."""
self.all_but_last_latent_state_pattern = torch.ones(self.config.tokens_per_block)
@@ -750,6 +762,34 @@ def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, va
else:
return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths, start_pos=start_pos)

def manual_fuse(self, obs_embeddings: torch.Tensor, manual_embeds: List[torch.Tensor] = None):
"""
Fuse manual embeddings with observation embeddings.

Arguments:
- obs_embeddings (:obj:`torch.Tensor`): Observation embeddings.
- manual_embeds (:obj:`List[torch.Tensor]` or :obj:`np.ndarray`, optional): Manual (instruction-text) embeddings to fuse with the observations.
Returns:
- torch.Tensor: Fused embeddings.
"""
b, s, _ = obs_embeddings.shape
if manual_embeds is not None:
if isinstance(manual_embeds, list):
manual_embeds_array = manual_embeds[0]
manual_embeds_expanded = torch.from_numpy(manual_embeds_array).view(1, 1, -1)
manual_embeds_expanded = manual_embeds_expanded.expand(b, s, manual_embeds_expanded.shape[-1]).to(obs_embeddings.device)

elif isinstance(manual_embeds, np.ndarray):
manual_embeds_expanded = torch.from_numpy(manual_embeds).reshape(b, s, -1).to(obs_embeddings.device)

new_obs_embeddings = torch.cat([manual_embeds_expanded, obs_embeddings], dim=-1)
return self.manual_fuse_proj(new_obs_embeddings)


@torch.no_grad()
def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor, start_pos: int = 0) -> torch.FloatTensor:
"""
@@ -765,14 +805,20 @@ def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor, start_pos
batch_obs = obs_act_dict['obs'] # obs_act_dict['obs'] is at timestep t
batch_action = obs_act_dict['action'] # obs_act_dict['action'] is at timestep t
batch_current_obs = obs_act_dict['current_obs'] # obs_act_dict['current_obs'] is at timestep t+1
manual_embeds = obs_act_dict['manual_embeds'] if 'manual_embeds' in obs_act_dict else None


# Encode observations to latent embeddings.
obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_obs)
if self.use_manual:
obs_embeddings = self.manual_fuse(obs_embeddings, manual_embeds)

if batch_current_obs is not None:
# ================ Collect and Evaluation Phase ================
# Encode current observations to latent embeddings
current_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_current_obs)
if self.use_manual:
current_obs_embeddings = self.manual_fuse(current_obs_embeddings, manual_embeds)
# print(f"current_obs_embeddings.device: {current_obs_embeddings.device}")
self.latent_state = current_obs_embeddings
outputs_wm = self.wm_forward_for_initial_infererence(obs_embeddings, batch_action,
@@ -823,7 +869,7 @@ def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTens
ready_env_num = current_obs_embeddings.shape[0]
self.keys_values_wm_list = []
self.keys_values_wm_size_list = []

assert len(last_obs_embeddings) == len(current_obs_embeddings)
for i in range(ready_env_num):
# Retrieve latent state for a single environment
# TODO: len(last_obs_embeddings) may be smaller than len(current_obs_embeddings), because some environments may already be done
@@ -1294,6 +1340,8 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
start_pos = batch['timestep']
# Encode observations into latent state representations
obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'])
if self.use_manual:
obs_embeddings = self.manual_fuse(obs_embeddings, manual_embeds=batch['manual_embeds'])

# ========= for visual analysis =========
# Uncomment the lines below for visual analysis in Pong
@@ -1437,6 +1485,8 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
# For training stability, use target_tokenizer to compute the true next latent state representations
with torch.no_grad():
target_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'])
if self.use_manual:
target_obs_embeddings = self.manual_fuse(target_obs_embeddings, manual_embeds=batch['manual_embeds'])

# Compute labels for observations, rewards, and ends
labels_observations, labels_rewards, _ = self.compute_labels_world_model(target_obs_embeddings,
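Finally, a standalone sketch that mirrors `manual_fuse` above without the module state, to show the two input cases (a per-environment list at collect/eval time, a batched array at training time) and the shared projection back to `embed_dim`. The sizes are illustrative and the function is a stand-in, not the repository code; like the diff, the list branch broadcasts a single manual embedding across the batch and sequence positions.

```python
import numpy as np
import torch
import torch.nn as nn

embed_dim, manual_embed_dim = 256, 768  # illustrative sizes
manual_fuse_proj = nn.Linear(embed_dim + manual_embed_dim, embed_dim, bias=False)

def manual_fuse(obs_embeddings: torch.Tensor, manual_embeds=None) -> torch.Tensor:
    """Concatenate a manual embedding onto each obs token, then project back to embed_dim."""
    b, s, _ = obs_embeddings.shape
    if manual_embeds is None:
        return obs_embeddings
    if isinstance(manual_embeds, list):
        # Collect/eval path: one embedding broadcast over the (b, s) positions.
        m = torch.as_tensor(np.asarray(manual_embeds[0]), dtype=obs_embeddings.dtype)
        m = m.view(1, 1, -1).expand(b, s, -1).to(obs_embeddings.device)
    else:
        # Training path: a (b, s, manual_embed_dim) array sampled from the buffer.
        m = torch.as_tensor(manual_embeds, dtype=obs_embeddings.dtype).reshape(b, s, -1).to(obs_embeddings.device)
    return manual_fuse_proj(torch.cat([m, obs_embeddings], dim=-1))

print(manual_fuse(torch.randn(2, 6, embed_dim),
                  np.random.randn(2, 6, manual_embed_dim).astype(np.float32)).shape)
```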