diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index 61ba751a9..1f7232db2 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -156,13 +156,17 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: # For some environments (e.g., Jericho), the action space size may be different. # To ensure we can always unroll `num_unroll_steps` steps starting from the sampled position (without exceeding segment length), # we avoid sampling from the last `num_unroll_steps` steps of the game segment. - if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps: - pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps, 1).item() + if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps: + pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps, 1).item() + if pos_in_game_segment >= len(game_segment.action_segment) - 1: + pos_in_game_segment = np.random.choice(len(game_segment.action_segment) - 1, 1).item() else: # For environments with a fixed action space (e.g., Atari), # we can safely sample from the entire game segment range. if pos_in_game_segment >= self._cfg.game_segment_length: pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item() + if pos_in_game_segment >= len(game_segment.action_segment) - 1: + pos_in_game_segment = np.random.choice(len(game_segment.action_segment) - 1, 1).item() pos_in_game_segment_list.append(pos_in_game_segment) diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index 1e4c9d698..0ea0a5d1f 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -279,7 +279,10 @@ def _prepare_reward_value_context( td_steps_list, action_mask_segment, to_play_segment """ zero_obs = game_segment_list[0].zero_obs() + zero_manual = game_segment_list[0].zero_manual() + value_obs_list = [] + value_manual_embeds_list = [] # the value is valid or not (out of game_segment) value_mask = [] rewards_list = [] @@ -300,6 +303,7 @@ def _prepare_reward_value_context( # o[t+ td_steps, t + td_steps + stack frames + num_unroll_steps] # t=2+3 -> o[2+3, 2+3+4+5] -> o[5, 14] game_obs = game_segment.get_unroll_obs(state_index + td_steps, self._cfg.num_unroll_steps) + game_manual_embeds = game_segment.get_unroll_manual(state_index + td_steps, self._cfg.num_unroll_steps) rewards_list.append(game_segment.reward_segment) @@ -321,15 +325,18 @@ def _prepare_reward_value_context( end_index = beg_index + self._cfg.model.frame_stack_num # the stacked obs in time t obs = game_obs[beg_index:end_index] + manual_embeds = game_manual_embeds[beg_index:end_index] else: value_mask.append(0) obs = zero_obs + manual_embeds = zero_manual value_obs_list.append(obs) + value_manual_embeds_list.append(manual_embeds) reward_value_context = [ value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, root_values, game_segment_lens, td_steps_list, - action_mask_segment, to_play_segment + action_mask_segment, to_play_segment, value_manual_embeds_list ] return reward_value_context diff --git a/lzero/mcts/buffer/game_buffer_unizero.py b/lzero/mcts/buffer/game_buffer_unizero.py index 6208ce24a..d5f419cba 100644 --- a/lzero/mcts/buffer/game_buffer_unizero.py +++ b/lzero/mcts/buffer/game_buffer_unizero.py @@ -64,7 +64,7 @@ def sample( policy._target_model.eval() # obtain the 
current_batch and prepare target context - reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch( + reward_value_context, policy_re_context, policy_non_re_context, current_batch, batch_manual_embeds = self._make_batch( batch_size, self._cfg.reanalyze_ratio ) @@ -72,7 +72,7 @@ def sample( # target reward, target value batch_rewards, batch_target_values = self._compute_target_reward_value( - reward_value_context, policy._target_model, current_batch[2], current_batch[-1] # current_batch[2] is batch_target_action + reward_value_context, policy._target_model, current_batch[2], current_batch[-1] # current_batch[2] is batch_target_action ) # target policy @@ -92,7 +92,7 @@ def sample( target_batch = [batch_rewards, batch_target_values, batch_target_policies] # a batch contains the current_batch and the target_batch - train_data = [current_batch, target_batch] + train_data = [current_batch, target_batch, batch_manual_embeds] return train_data def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: @@ -120,6 +120,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: obs_list, action_list, mask_list = [], [], [] timestep_list = [] bootstrap_action_list = [] + manual_embeds_list = [] # prepare the inputs of a batch for i in range(batch_size): @@ -156,6 +157,12 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True ) ) + manual_embeds_list.append( + game_segment_list[i].get_unroll_manual( + pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True + ) + ) + action_list.append(actions_tmp) mask_list.append(mask_tmp) @@ -214,7 +221,8 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: else: policy_non_re_context = None - context = reward_value_context, policy_re_context, policy_non_re_context, current_batch + manual_embeds_array = np.asarray(manual_embeds_list) + context = reward_value_context, policy_re_context, policy_non_re_context, current_batch, manual_embeds_array return context def reanalyze_buffer( @@ -527,20 +535,21 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A - batch_target_values (:obj:'np.ndarray): batch of value estimation """ value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, root_values, game_segment_lens, td_steps_list, action_mask_segment, \ - to_play_segment = reward_value_context # noqa + to_play_segment, value_manual_embeds_list = reward_value_context # noqa # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1) transition_batch_size = len(value_obs_list) batch_target_values, batch_rewards = [], [] with torch.no_grad(): value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) + batch_manual = torch.from_numpy(np.array(value_manual_embeds_list)) network_output = [] batch_obs = torch.from_numpy(value_obs_list).to(self._cfg.device) # =============== NOTE: The key difference with MuZero ================= # calculate the bootstrapped value and target value # NOTE: batch_obs(value_obs_list) is at t+td_steps, batch_action is at timestep t+td_steps - m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep) + m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep, manual_embeds=batch_manual) # ====================================================================== # if not in 
training, obtain the scalars of the value/reward diff --git a/lzero/mcts/buffer/game_segment.py b/lzero/mcts/buffer/game_segment.py index ad216d196..13e8e3434 100644 --- a/lzero/mcts/buffer/game_segment.py +++ b/lzero/mcts/buffer/game_segment.py @@ -59,6 +59,7 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea # image obs input, e.g. atari environments self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1]) + self.manual_embed_dim = config.model.world_model_cfg.manual_embed_dim self.obs_segment = [] self.action_segment = [] self.reward_segment = [] @@ -69,6 +70,7 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea self.action_mask_segment = [] self.to_play_segment = [] self.timestep_segment = [] + self.manual_embeds_segment = [] self.target_values = [] self.target_rewards = [] @@ -102,6 +104,23 @@ def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs] return stacked_obs + def get_unroll_manual(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray: + """ + Overview: + Get the manual embeddings aligned with the unrolled observations, i.e. m[t, t + stack frames + num_unroll_steps]. + Arguments: + - timestep (int): The time step. + - num_unroll_steps (int): The extra length of the unrolled frames. + - padding (bool): If True, pad frames if (t + stack frames) is outside of the trajectory. + """ + stacked_manual_embeds = self.manual_embeds_segment[timestep:timestep + self.frame_stack_num + num_unroll_steps] + if padding: + pad_len = self.frame_stack_num + num_unroll_steps - len(stacked_manual_embeds) + if pad_len > 0: + pad_frames = np.array([stacked_manual_embeds[-1] for _ in range(pad_len)]) + stacked_manual_embeds = np.concatenate((stacked_manual_embeds, pad_frames)) + return stacked_manual_embeds + def zero_obs(self) -> List: """ Overview: @@ -111,6 +130,15 @@ def zero_obs(self) -> List: """ return [np.zeros(self.zero_obs_shape, dtype=np.float32) for _ in range(self.frame_stack_num)] + def zero_manual(self) -> List: + """ + Overview: + Return a list of manual-embedding frames filled with zeros. + Returns: + - list: A list of zero-filled manual embedding arrays.
+ """ + return [np.zeros((self.manual_embed_dim, ), dtype=np.float32) for _ in range(self.frame_stack_num)] + def get_obs(self) -> List: """ Overview: @@ -138,6 +166,7 @@ def append( to_play: int = -1, timestep: int = 0, chance: int = 0, + manual_embeds = None, ) -> None: """ Overview: @@ -150,6 +179,7 @@ def append( self.action_mask_segment.append(action_mask) self.to_play_segment.append(to_play) self.timestep_segment.append(timestep) + self.manual_embeds_segment.append(manual_embeds) if self.use_ture_chance_label_in_chance_encoder: self.chance_segment.append(chance) @@ -285,6 +315,7 @@ def game_segment_to_array(self) -> None: self.obs_segment = np.array(self.obs_segment) self.action_segment = np.array(self.action_segment) self.reward_segment = np.array(self.reward_segment) + self.manual_embeds_segment = np.array(self.manual_embeds_segment) # Check if all elements in self.child_visit_segment have the same length if all(len(x) == len(self.child_visit_segment[0]) for x in self.child_visit_segment): @@ -305,7 +336,7 @@ def game_segment_to_array(self) -> None: if self.use_ture_chance_label_in_chance_encoder: self.chance_segment = np.array(self.chance_segment) - def reset(self, init_observations: np.ndarray) -> None: + def reset(self, init_observations: np.ndarray, init_manual_embeds = None) -> None: """ Overview: Initialize the game segment using ``init_observations``, @@ -323,6 +354,7 @@ def reset(self, init_observations: np.ndarray) -> None: self.action_mask_segment = [] self.to_play_segment = [] self.timestep_segment = [] + self.manual_embeds_segment = [] if self.use_ture_chance_label_in_chance_encoder: self.chance_segment = [] @@ -331,6 +363,7 @@ def reset(self, init_observations: np.ndarray) -> None: for observation in init_observations: self.obs_segment.append(copy.deepcopy(observation)) + self.manual_embeds_segment.append(copy.deepcopy(init_manual_embeds)) def is_full(self) -> bool: """ diff --git a/lzero/model/common.py b/lzero/model/common.py index 795eb72a3..258606fdf 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -470,6 +470,8 @@ def __init__( embedding_dim: int = 256, group_size: int = 8, final_norm_option_in_encoder: str = 'LayerNorm', # TODO + use_manual: bool = False, + manual_dim: int = 768 ) -> None: """ Overview: @@ -496,8 +498,10 @@ def __init__( logging.info(f"Using norm type: {norm_type}") logging.info(f"Using activation type: {activation}") - self.observation_shape = observation_shape + self.observation_shape = observation_shape self.downsample = downsample + self.use_manual = use_manual + if self.downsample: self.downsample_net = DownSample( observation_shape, @@ -533,6 +537,8 @@ def __init__( elif self.observation_shape[1] in [84, 96]: self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False) + elif self.observation_shape[1] == 10: + self.last_linear = nn.Linear(64 * 10 * 10, self.embedding_dim, bias=False) self.final_norm_option_in_encoder = final_norm_option_in_encoder if self.final_norm_option_in_encoder == 'LayerNorm': @@ -542,6 +548,8 @@ def __init__( else: raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}") + if use_manual: + self.feature_merge_linear = nn.Linear(self.embedding_dim + manual_dim, self.embedding_dim) def forward(self, x: torch.Tensor) -> torch.Tensor: """ diff --git a/lzero/model/unizero_model.py b/lzero/model/unizero_model.py index 4ea6500f3..77e2f7482 100644 --- a/lzero/model/unizero_model.py +++ b/lzero/model/unizero_model.py @@ -1,4 +1,4 @@ -from typing import 
Optional +from typing import Optional, List import torch import torch.nn as nn @@ -127,7 +127,7 @@ def __init__( norm_type=norm_type, embedding_dim=world_model_cfg.embed_dim, group_size=world_model_cfg.group_size, - final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, ) # ====== for analysis ====== @@ -177,7 +177,7 @@ def __init__( print('==' * 20) def initial_inference(self, obs_batch: torch.Tensor, action_batch: Optional[torch.Tensor] = None, - current_obs_batch: Optional[torch.Tensor] = None, start_pos: int = 0) -> MZNetworkOutput: + current_obs_batch: Optional[torch.Tensor] = None, start_pos: int = 0, manual_embeds: List[torch.Tensor] = None) -> MZNetworkOutput: """ Overview: Initial inference of the UniZero model, which is the first step of the UniZero model. @@ -205,7 +205,8 @@ def initial_inference(self, obs_batch: torch.Tensor, action_batch: Optional[torc obs_act_dict = { 'obs': obs_batch, 'action': action_batch, - 'current_obs': current_obs_batch + 'current_obs': current_obs_batch, + 'manual_embeds': manual_embeds } # Perform initial inference using the world model diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py index e8df2a6e0..8ed7ddf12 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, Union, Optional, List, Tuple, Any +from typing import Dict, Union, Optional, List, Tuple, Any, Set import numpy as np import torch @@ -65,7 +65,21 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.precompute_pos_emb_diff_kv() print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}") + self.continuous_action_space = self.config.continuous_action_space + if self.use_manual: + self.manual_fuse_proj = nn.Linear(self.embed_dim + self.manual_embed_dim, self.embed_dim, bias=False) + if self.final_norm_option_in_encoder == 'LayerNorm': + self.manual_embeds_norm = nn.LayerNorm(self.config.embed_dim, eps=1e-5) + elif self.final_norm_option_in_encoder == 'SimNorm': + self.manual_embeds_norm = SimNorm(simnorm_dim=self.group_size) + else: + raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}") + + # self.manual_fuse_proj = nn.Sequential( + # nn.Linear(self.embed_dim + self.manual_embed_dim, self.embed_dim, bias=False), + # nn.GELU(approximate='tanh') + # ) # Initialize action embedding table if self.continuous_action_space: @@ -100,7 +114,7 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: skip_modules = set() if hasattr(self.tokenizer.encoder, 'pretrained_model'): skip_modules.update(self.tokenizer.encoder.pretrained_model.modules()) - if hasattr(self.tokenizer, 'decoder_network'): + if hasattr(self.tokenizer, 'decoder_network') and self.tokenizer.decoder_network is not None: skip_modules.update(self.tokenizer.decoder_network.modules()) def custom_init(module): @@ -292,6 +306,12 @@ def _initialize_config_parameters(self) -> None: self.obs_per_embdding_dim = self.config.embed_dim self.sim_norm = SimNorm(simnorm_dim=self.group_size) + # ====== [NEW] manual fusion switch and layers ====== + self.use_manual = self.config.use_manual + self.manual_embed_dim = self.config.manual_embed_dim + self.final_norm_option_in_encoder = self.config.final_norm_option_in_encoder + + def _initialize_patterns(self) -> None: """Initialize patterns for block masks."""
self.all_but_last_latent_state_pattern = torch.ones(self.config.tokens_per_block) @@ -750,6 +770,33 @@ def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, va else: return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths, start_pos=start_pos) + def manual_fuse(self, obs_embeddings: torch.Tensor, manual_embeds: List[torch.Tensor] = None): + """ + Fuse manual embeddings with observation embeddings. + + Arguments: + - obs_embeddings (:obj:`torch.Tensor`): Observation embeddings. + - manual_embeds (:obj:`Union[List, np.ndarray, torch.Tensor]`, optional): Manual embeddings to fuse, given as a list of per-env arrays, a numpy array, or a tensor. + Returns: + - torch.Tensor: Fused embeddings. + """ + b, s, _ = obs_embeddings.shape + if manual_embeds is not None: + if isinstance(manual_embeds, list): + manual_embeds_array = manual_embeds[0] + manual_embeds_expanded = torch.from_numpy(manual_embeds_array).view(1, 1, -1) + manual_embeds_expanded = manual_embeds_expanded.expand(b, s, manual_embeds_expanded.shape[-1]).to(obs_embeddings.device) + + elif isinstance(manual_embeds, np.ndarray): + manual_embeds_expanded = torch.from_numpy(manual_embeds).reshape(b, s, -1).to(obs_embeddings.device) + elif isinstance(manual_embeds, torch.Tensor): + manual_embeds_expanded = manual_embeds.reshape(b, s, -1).to(obs_embeddings.device) + + manual_embeds_expanded = self.manual_embeds_norm(manual_embeds_expanded) + new_obs_embeddings = torch.cat([manual_embeds_expanded, obs_embeddings], dim=-1) + return self.manual_fuse_proj(new_obs_embeddings) + + @torch.no_grad() def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor, start_pos: int = 0) -> torch.FloatTensor: """ @@ -765,14 +812,20 @@ def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor, start_pos batch_obs = obs_act_dict['obs'] # obs_act_dict['obs'] is at timestep t batch_action = obs_act_dict['action'] # obs_act_dict['action'] is at timestep t batch_current_obs = obs_act_dict['current_obs'] # obs_act_dict['current_obs'] is at timestep t+1 + manual_embeds = obs_act_dict['manual_embeds'] if 'manual_embeds' in obs_act_dict else None + + # Encode observations to latent embeddings.
obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_obs) + if self.use_manual: + obs_embeddings = self.manual_fuse(obs_embeddings, manual_embeds) if batch_current_obs is not None: # ================ Collect and Evaluation Phase ================ # Encode current observations to latent embeddings current_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_current_obs) + if self.use_manual: + current_obs_embeddings = self.manual_fuse(current_obs_embeddings, manual_embeds) # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}") self.latent_state = current_obs_embeddings outputs_wm = self.wm_forward_for_initial_infererence(obs_embeddings, batch_action, @@ -823,7 +876,7 @@ def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTens ready_env_num = current_obs_embeddings.shape[0] self.keys_values_wm_list = [] self.keys_values_wm_size_list = [] - + assert len(last_obs_embeddings) == len(current_obs_embeddings) for i in range(ready_env_num): # Retrieve latent state for a single environment # TODO: len(last_obs_embeddings) may smaller than len(current_obs_embeddings), because some environments may have done @@ -1289,11 +1342,13 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, return self.keys_values_wm_size_list - def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar_transform_handle=None, + def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar_transform_handle=None, target_model: nn.Module = None, **kwargs: Any) -> LossWithIntermediateLosses: start_pos = batch['timestep'] # Encode observations into latent state representations obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations']) + if self.use_manual: + obs_embeddings = self.manual_fuse(obs_embeddings, manual_embeds=batch['manual_embeds']) # ========= for visual analysis ========= # Uncomment the lines below for visual analysis in Pong @@ -1437,6 +1492,8 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # For training stability, use target_tokenizer to compute the true next latent state representations with torch.no_grad(): target_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations']) + if self.use_manual: + target_obs_embeddings = target_model.manual_fuse(target_obs_embeddings, manual_embeds=batch['manual_embeds']) # Compute labels for observations, rewards, and ends labels_observations, labels_rewards, _ = self.compute_labels_world_model(target_obs_embeddings, diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py index 9ff2c1333..df3502f59 100644 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -372,7 +372,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in self._learn_model.train() self._target_model.train() - current_batch, target_batch, train_iter = data + current_batch, target_batch, batch_manual_embeds, train_iter = data obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch target_reward, target_value, target_policy = target_batch @@ -423,6 +423,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in batch_for_gpt['mask_padding'] = mask_batch == 1.0 # 0 means invalid padding data batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] + + batch_for_gpt['manual_embeds'] = 
batch_manual_embeds[:, :-1] if batch_manual_embeds is not None else None batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, device=self._cfg.device) batch_for_gpt['target_value'] = target_value_categorical[:, :-1] @@ -435,7 +437,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # Update world model losses = self._learn_model.world_model.compute_loss( - batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle + batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle, self._target_model.world_model ) weighted_total_loss = losses.loss_total @@ -589,7 +591,8 @@ def _init_collect(self) -> None: self._collect_epsilon = 0.0 self.collector_env_num = self._cfg.collector_env_num if self._cfg.model.model_type == 'conv': - self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + # self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]]).to(self._cfg.device) self.last_batch_action = [-1 for i in range(self.collector_env_num)] elif self._cfg.model.model_type == 'mlp': self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) @@ -604,7 +607,8 @@ def _forward_collect( to_play: List = [-1], epsilon: float = 0.25, ready_env_id: np.ndarray = None, - timestep: List = [0] + timestep: List = [0], + manual_embeds: List[torch.Tensor] = None, ) -> Dict: """ Overview: @@ -641,7 +645,7 @@ def _forward_collect( output = {i: None for i in ready_env_id} with torch.no_grad(): - network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep) + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep, manual_embeds=manual_embeds) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() @@ -745,14 +749,16 @@ def _init_eval(self) -> None: self.evaluator_env_num = self._cfg.evaluator_env_num if self._cfg.model.model_type == 'conv': - self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + # self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0], self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]]).to(self._cfg.device) + + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] elif self._cfg.model.model_type == 'mlp': self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [-1], - ready_env_id: np.array = None, timestep: List = [0]) -> Dict: + ready_env_id: np.array = None, timestep: List = [0], manual_embeds: List[torch.Tensor] = None,) -> Dict: """ Overview: The forward function for
evaluating the current policy in eval mode. Use model to execute MCTS search. @@ -783,7 +789,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ ready_env_id = np.arange(active_eval_env_num) output = {i: None for i in ready_env_id} with torch.no_grad(): - network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep) + network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep, manual_embeds=manual_embeds) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) # if not in training, obtain the scalars of the value/reward diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index c3b9bbd27..f10766d03 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -364,6 +364,8 @@ def collect(self, action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} timestep_dict = {} + manual_embeds_dict = {i: to_ndarray(init_obs[i].get('manual_embeds', None)) for i in range(env_nums)} + for i in range(env_nums): if 'timestep' not in init_obs[i]: print(f"Warning: 'timestep' key is missing in init_obs[{i}], assigning value -1") @@ -386,7 +388,7 @@ def collect(self, [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)], maxlen=self.policy_config.model.frame_stack_num ) - game_segments[env_id].reset(observation_window_stack[env_id]) + game_segments[env_id].reset(observation_window_stack[env_id], init_manual_embeds=init_obs[env_id].get('manual_embeds', None)) dones = np.array([False for _ in range(env_nums)]) last_game_segments = [None for _ in range(env_nums)] @@ -430,10 +432,12 @@ def collect(self, action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} timestep_dict = {env_id: timestep_dict[env_id] for env_id in ready_env_id} + manual_embeds_dict = {env_id: manual_embeds_dict[env_id] for env_id in ready_env_id} action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] to_play = [to_play_dict[env_id] for env_id in ready_env_id] timestep = [timestep_dict[env_id] for env_id in ready_env_id] + manual_embeds = [manual_embeds_dict[env_id] for env_id in ready_env_id] if self.policy_config.use_ture_chance_label_in_chance_encoder: chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id} @@ -447,7 +451,7 @@ def collect(self, # Key policy forward step # ============================================================== # print(f'ready_env_id:{ready_env_id}') - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, timestep=timestep) + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, timestep=timestep, manual_embeds=manual_embeds) pred_next_text_with_env_id = {k: v['predicted_next_text'] for k, v in policy_output.items()} @@ -566,12 +570,12 @@ def collect(self, if self.policy_config.use_ture_chance_label_in_chance_encoder: game_segments[env_id].append( actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id], chance_dict[env_id], timestep_dict[env_id] + to_play_dict[env_id], timestep_dict[env_id],chance_dict[env_id], 
manual_embeds=obs.get('manual_embeds', None) ) else: game_segments[env_id].append( actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id], timestep_dict[env_id] + to_play_dict[env_id], timestep_dict[env_id], manual_embeds=obs.get('manual_embeds', None) ) # NOTE: the position of code snippet is very important. @@ -579,6 +583,7 @@ def collect(self, action_mask_dict[env_id] = to_ndarray(obs['action_mask']) to_play_dict[env_id] = to_ndarray(obs['to_play']) timestep_dict[env_id] = to_ndarray(obs.get('timestep', -1)) + if self.policy_config.use_ture_chance_label_in_chance_encoder: chance_dict[env_id] = to_ndarray(obs['chance']) @@ -638,7 +643,7 @@ def collect(self, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config ) - game_segments[env_id].reset(observation_window_stack[env_id]) + game_segments[env_id].reset(observation_window_stack[env_id], init_manual_embeds=obs.get('manual_embeds', None)) self._env_info[env_id]['step'] += 1 if self.policy_config.model.world_model_cfg.obs_type == 'text': @@ -729,7 +734,7 @@ def collect(self, [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)], maxlen=self.policy_config.model.frame_stack_num ) - game_segments[env_id].reset(observation_window_stack[env_id]) + game_segments[env_id].reset(observation_window_stack[env_id], init_manual_embeds=init_obs[env_id].get('manual_embeds', None)) last_game_segments[env_id] = None last_game_priorities[env_id] = None diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index 6ca7bcc71..0d6670ea1 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -246,6 +246,7 @@ def eval( to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} timestep_dict = {} + manual_embeds_dict = {i: to_ndarray(init_obs[i].get('manual_embeds', None)) for i in range(env_nums)} for i in range(env_nums): if 'timestep' not in init_obs[i]: print(f"Warning: 'timestep' key is missing in init_obs[{i}], assigning value -1") @@ -287,6 +288,9 @@ def eval( to_play = [to_play_dict[env_id] for env_id in ready_env_id] timestep = [timestep_dict[env_id] for env_id in ready_env_id] + manual_embeds_dict = {env_id: manual_embeds_dict[env_id] for env_id in ready_env_id} + manual_embeds = [manual_embeds_dict[env_id] for env_id in ready_env_id] + stack_obs = to_ndarray(stack_obs) stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() @@ -294,7 +298,7 @@ def eval( # ============================================================== # policy forward # ============================================================== - policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id, timestep=timestep) + policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id, timestep=timestep, manual_embeds=manual_embeds) actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()} diff --git a/zoo/messenger/__init__.py b/zoo/messenger/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/messenger/configs/messenger_unizero_config.py b/zoo/messenger/configs/messenger_unizero_config.py new file mode 100644 index 000000000..a5fae0d1f --- /dev/null +++ 
b/zoo/messenger/configs/messenger_unizero_config.py @@ -0,0 +1,172 @@ +import os +import argparse +from typing import Any, Dict + +from easydict import EasyDict + + +def main(env_id: str = 'messenger', seed: int = 0, max_env_step: int = int(1e6)) -> None: + """ + Main entry point for setting up environment configurations and launching training. + + Args: + env_id (str): Identifier of the environment, e.g., 'detective.z5'. + seed (int): Random seed used for reproducibility. + + Returns: + None + """ + collector_env_num: int = 1 # Number of collector environments + n_episode: int = collector_env_num + batch_size: int = 64 + env_id: str = 'messenger' + action_space_size: int = 5 + max_steps: int = 100 + use_manual: bool = True + task: str ='s1' + max_seq_len: int = 256 + + + # ============================================================== + # begin of the most frequently changed config specified by the user + # ============================================================== + evaluator_env_num: int = 1 # Number of evaluator environments + num_simulations: int = 50 # Number of simulations + + # Project training parameters + num_unroll_steps: int = 10 # Number of unroll steps (for rollout sequence expansion) + infer_context_length: int = 4 # Inference context length + + num_layers: int = 2 # Number of layers in the model + replay_ratio: float = 0.1 # Replay ratio for experience replay + embed_dim: int = 768 # Embedding dimension + + buffer_reanalyze_freq: float = 1 / 100000 + # reanalyze_batch_size: Number of sequences to reanalyze per reanalysis process + reanalyze_batch_size: int = 160 + # reanalyze_partition: Partition ratio from the replay buffer to use during reanalysis + reanalyze_partition: float = 0.75 + + model_name: str = 'BAAI/bge-base-en-v1.5' + + messenger_unizero_config = dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(17, 10, 10), + max_steps=max_steps, + max_action_num=5, + n_entities=17, + mode='train', + task=task, + max_seq_len=max_seq_len, + model_path=model_name, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + multi_gpu=False, + use_wandb=False, + accumulation_steps=1, + model=dict( + observation_shape=(17, 10, 10), + action_space_size=action_space_size, + downsample=False, + continuous_action_space=False, + image_channel=17, + world_model_cfg=dict( + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', + policy_entropy_weight=5e-2, + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=8, + embed_dim=embed_dim, + obs_type='image', + env_num=max(collector_env_num, evaluator_env_num), + use_manual=use_manual, + manual_embed_dim=768, + ), + ), + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. 
+ model_path=None, + num_unroll_steps=num_unroll_steps, + update_per_collect=int(collector_env_num*max_steps*replay_ratio ), + replay_ratio=replay_ratio, + batch_size=batch_size, + optim_type='AdamW', + num_simulations=num_simulations, + n_episode=n_episode, + replay_buffer_size=int(1e6), + eval_freq=int(5e3), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + ) + messenger_unizero_config = EasyDict(messenger_unizero_config) + main_config = messenger_unizero_config + + messenger_unizero_create_config = dict( + env=dict( + type='messenger', + import_names=['zoo.messenger.envs.messenger_env'], + ), + env_manager=dict(type='base'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), + ) + messenger_unizero_create_config = EasyDict(messenger_unizero_create_config) + create_config = messenger_unizero_create_config + + main_config.exp_name = ( + f"./data_lz/data_unizero_messenger/{env_id}_use_manual_{use_manual}/uz_cen{collector_env_num}_rr{replay_ratio}_ftemp025_{env_id[:8]}_ms{max_steps}_ass-{action_space_size}_" + f"nlayer{num_layers}_embed{embed_dim}_Htrain{num_unroll_steps}-" + f"Hinfer{infer_context_length}_bs{batch_size}_seed{seed}" + ) + from lzero.entry import train_unizero + # Launch the training process + train_unizero( + [main_config, create_config], + seed=seed, + model_path=main_config.policy.model_path, + max_env_step=max_env_step, + ) + +if __name__ == "__main__": + """ + Overview: + This script runs single-process training on a single GPU. + Run the following command to launch the script: + python ./zoo/messenger/configs/messenger_unizero_config.py + """ + + parser = argparse.ArgumentParser(description='Process environment configuration and launch training.') + parser.add_argument( + '--env', + type=str, + help='Identifier of the environment', + default='messenger' + ) + parser.add_argument( + '--seed', + type=int, + help='Random seed for reproducibility', + default=0 + ) + args = parser.parse_args() + + # Start the main process with the provided arguments + main(args.env, args.seed) \ No newline at end of file diff --git a/zoo/messenger/configs/messenger_unizero_ddp_config.py b/zoo/messenger/configs/messenger_unizero_ddp_config.py new file mode 100644 index 000000000..614610b03 --- /dev/null +++ b/zoo/messenger/configs/messenger_unizero_ddp_config.py @@ -0,0 +1,177 @@ +import os +import argparse +from typing import Any, Dict + +from easydict import EasyDict + + +def main(env_id: str = 'messenger', seed: int = 0, max_env_step: int = int(1e6)) -> None: + """ + Main entry point for setting up environment configurations and launching training. + + Args: + env_id (str): Identifier of the environment, e.g., 'messenger'. + seed (int): Random seed used for reproducibility.
+ + Returns: + None + """ + gpu_num = 4 + collector_env_num: int = 1 # Number of collector environments + n_episode = int(collector_env_num*gpu_num) + batch_size = int(64*gpu_num) + env_id = 'messenger' + action_space_size = 5 + max_steps = 100 + use_manual=True + task='s1' + max_seq_len=256 + + + # ============================================================== + # begin of the most frequently changed config specified by the user + # ============================================================== + evaluator_env_num: int = 1 # Number of evaluator environments + num_simulations: int = 50 # Number of simulations + + # Project training parameters + num_unroll_steps: int = 10 # Number of unroll steps (for rollout sequence expansion) + infer_context_length: int = 4 # Inference context length + + num_layers: int = 2 # Number of layers in the model + replay_ratio: float = 0.1 # Replay ratio for experience replay + embed_dim: int = 768 # Embedding dimension + + buffer_reanalyze_freq: float = 1 / 100000 + # reanalyze_batch_size: Number of sequences to reanalyze per reanalysis process + reanalyze_batch_size: int = 160 + # reanalyze_partition: Partition ratio from the replay buffer to use during reanalysis + reanalyze_partition: float = 0.75 + + model_name: str = 'BAAI/bge-base-en-v1.5' + + messenger_unizero_config = dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(17, 10, 10), + max_steps=max_steps, + max_action_num=5, + n_entities=17, + mode='train', + task=task, + max_seq_len=max_seq_len, + model_path=model_name, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + multi_gpu=True, + use_wandb=False, + accumulation_steps=1, + model=dict( + observation_shape=(17, 10, 10), + action_space_size=action_space_size, + downsample=False, + continuous_action_space=False, + image_channel=17, + world_model_cfg=dict( + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', + policy_entropy_weight=5e-2, + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=8, + embed_dim=embed_dim, + obs_type='image', + env_num=max(collector_env_num, evaluator_env_num), + use_manual=use_manual, + manual_embed_dim=768, + ), + ), + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. 
+ model_path=None, + num_unroll_steps=num_unroll_steps, + update_per_collect=int(collector_env_num*max_steps*replay_ratio ), + replay_ratio=replay_ratio, + batch_size=batch_size, + optim_type='AdamW', + num_simulations=num_simulations, + n_episode=n_episode, + replay_buffer_size=int(1e6), + eval_freq=int(5e3), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + ) + messenger_unizero_config = EasyDict(messenger_unizero_config) + main_config = messenger_unizero_config + + messenger_unizero_create_config = dict( + env=dict( + type='messenger', + import_names=['zoo.messenger.envs.messenger_env'], + ), + env_manager=dict(type='base'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), + ) + messenger_unizero_create_config = EasyDict(messenger_unizero_create_config) + create_config = messenger_unizero_create_config + from ding.utils import DDPContext + from lzero.config.utils import lz_to_ddp_config + with DDPContext(): + main_config = lz_to_ddp_config(main_config) + # Construct experiment name containing key parameters + main_config.exp_name = ( + f"data_lz/data_unizero_messenger/{env_id}_use_manual_{use_manual}/uz_ddp-{gpu_num}gpu_cen{collector_env_num}_rr{replay_ratio}_ftemp025_{env_id[:8]}_ms{max_steps}_ass-{action_space_size}_" + f"nlayer{num_layers}_embed{embed_dim}_Htrain{num_unroll_steps}-" + f"Hinfer{infer_context_length}_bs{batch_size}_seed{seed}" + ) + from lzero.entry import train_unizero + # Launch the training process + train_unizero( + [main_config, create_config], + seed=seed, + model_path=main_config.policy.model_path, + max_env_step=max_env_step, + ) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. 
+ Run the following command to launch the script: + torchrun --nproc_per_node=4 ./zoo/messenger/configs/messenger_unizero_ddp_config.py + """ + + parser = argparse.ArgumentParser(description='Process environment configuration and launch training.') + parser.add_argument( + '--env', + type=str, + help='Identifier of the environment', + default='messenger' + ) + parser.add_argument( + '--seed', + type=int, + help='Random seed for reproducibility', + default=0 + ) + args = parser.parse_args() + + # Start the main process with the provided arguments + main(args.env, args.seed) \ No newline at end of file diff --git a/zoo/messenger/envs/messenger_env.py b/zoo/messenger/envs/messenger_env.py new file mode 100644 index 000000000..c8c6e6410 --- /dev/null +++ b/zoo/messenger/envs/messenger_env.py @@ -0,0 +1,347 @@ +import logging +import copy +import os +import json +from datetime import datetime +from typing import Any, Dict, List, Optional, Union + +import numpy as np +from gym import spaces +import os +import sys +# import embodied +import pickle +from PIL import Image, ImageDraw, ImageFont +from transformers import AutoTokenizer, AutoModel +import torch + +from ding.utils import ENV_REGISTRY, get_rank, get_world_size +from ding.torch_utils import to_ndarray +from ding.envs import BaseEnv, BaseEnvTimestep +from easydict import EasyDict +from messenger.envs.stage_one import StageOne +from messenger.envs.stage_two import StageTwo +from messenger.envs.stage_three import StageThree +from messenger.envs.wrappers import TwoEnvWrapper +from messenger.envs.config import STATE_HEIGHT, STATE_WIDTH +# import from_gym + + +@ENV_REGISTRY.register('messenger') +class Messenger(BaseEnv): + + + tokenizer: Optional[AutoTokenizer] = None + manual_encoder: Optional[AutoModel] = None + manual_embeds: Optional[torch.Tensor] = None + + config = dict( + model_path="BAAI/bge-base-en-v1.5", + # (int) The number of environment instances used for data collection. + collector_env_num=1, + # (int) The number of environment instances used for evaluator. + evaluator_env_num=1, + # (int) The number of episodes to evaluate during each evaluation period. + n_evaluator_episode=1, + # (str) The type of the environment, here it's Messenger. + env_type='Messenger', + n_entities=17, + observation_shape=(17, STATE_HEIGHT, STATE_WIDTH), + max_seq_len=256, # maximum token length for the tokenized manual text + gray_scale=True, + channel_last=False, + max_steps=100, + stop_value=int(1e6), + max_action_num=5, + mode="train", + task="s1" + ) + + @classmethod + def default_config(cls: type) -> EasyDict: + """ + Overview: + Return the default configuration for the Messenger LightZero environment. + Arguments: + - cls (:obj:`type`): The class Messenger. + Returns: + - cfg (:obj:`EasyDict`): The default configuration dictionary. + """ + cfg = EasyDict(copy.deepcopy(cls.config)) + cfg.cfg_type = cls.__name__ + 'Dict' + return cfg + + def __init__(self, cfg: Dict[str, Any]): + """ + Overview: + Initialize the Messenger environment. + + Arguments: + - cfg (:obj:`Dict[str, Any]`): Configuration dictionary containing keys like max_steps, task, mode, etc.
+ """ + merged_cfg = self.default_config() + merged_cfg.update(cfg) + self.cfg = merged_cfg + self._init_flag = False + self.channel_last = self.cfg.channel_last + self._timestep = 0 + + self.max_steps = self.cfg.max_steps + self.max_seq_len = self.cfg.max_seq_len + self.n_entities = self.cfg.n_entities + self.task = self.cfg.task + self.mode = self.cfg.mode + self.max_action_num = self.cfg.max_action_num + + self.manual = None + self._init_flag = False + self._eval_episode_return = 0.0 + + # Get current world size and rank for distributed setups. + self.world_size: int = get_world_size() + self.rank: int = get_rank() + + if Messenger.tokenizer is None: + if self.rank == 0: + Messenger.tokenizer = AutoTokenizer.from_pretrained(self.cfg['model_path']) + Messenger.manual_encoder = AutoModel.from_pretrained(self.cfg['model_path']) + if self.world_size > 1: + # Wait until rank 0 finishes loading the tokenizer + torch.distributed.barrier() + if self.rank != 0: + Messenger.tokenizer = AutoTokenizer.from_pretrained(self.cfg['model_path']) + Messenger.manual_encoder = AutoModel.from_pretrained(self.cfg['model_path']) + + print(f"Messenger config: {self.task} {self.mode} max_steps {self.max_steps}") + assert self.task in ("s1", "s2", "s3") + assert self.mode in ("train", "eval") + + if self.task == "s1": + if self.mode == "train": + self._env = TwoEnvWrapper( + stage=1, + split_1='train-mc', + split_2='train-sc', + ) + else: + self._env = StageOne(split="val") + elif self.task == "s2": + if self.mode == "train": + self._env = TwoEnvWrapper( + stage=2, + split_1='train-sc', + split_2='train-mc' + ) + else: + self._env = StageTwo(split='val') + elif self.task == "s3": + if self.mode == "train": + self._env = TwoEnvWrapper( + stage=3, + split_1='train-mc', + split_2='train-sc', + ) + else: + self._env = StageThree(split='val') + + observation_space = ( + self.cfg.observation_shape[0], + self.cfg.observation_shape[1], + self.cfg.observation_shape[2] + ) + self._observation_space = spaces.Dict({ + 'observation': spaces.Box( + low=0, high=1, shape=observation_space, dtype=np.float32 + ), + 'action_mask': spaces.Box( + low=0, high=1, shape=(self.cfg.max_action_num,), dtype=np.int8 + ), + 'to_play': spaces.Box( + low=-1, high=2, shape=(), dtype=np.int8 + ), + 'timestep': spaces.Box( + low=0, high=int(1.08e5), shape=(), dtype=np.int32 + ), + }) + self._action_space = spaces.Discrete(int(self.max_action_num)) + self.reward_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32) + + + + def __repr__(self) -> str: + """ + Overview: + Return a string representation of the environment. + + Returns: + - (:obj:`str`): String representation of the environment. + """ + return "LightZero Messenger Env" + + @property + def observation_space(self) -> spaces.Space: + """ + Property to access the observation space of the environment. + """ + return self._observation_space + + @property + def action_space(self) -> spaces.Space: + """ + Property to access the action space of the environment. + """ + return self._action_space + + def _symbolic_to_multihot(self, obs): + # (h, w, 2) + layers = np.concatenate((obs["entities"], obs["avatar"]), + axis=-1).astype(int) + new_ob = np.maximum.reduce([np.eye(self.n_entities)[layers[..., i]] for i + in range(layers.shape[-1])]) + new_ob[:, :, 0] = 0 + return new_ob.astype(np.float32) + + def observe(self) -> dict: + """ + Overview: + Return the current observation along with the action mask and to_play flag. 
+ Returns: + - observation (:obj:`dict`): The dictionary containing the current observation, action mask, to_play flag, timestep, and manual embeddings. + """ + observation = self.obs + + if not self.channel_last: + # move the channel dim to the first axis + # (10, 10, 17) -> (17, 10, 10) + observation = np.transpose(observation, (2, 0, 1)) + action_mask = np.ones(self.max_action_num, dtype=np.int8) + # return {'observation': {'image': observation, 'manual_embeds': self.manual_embeds}, 'action_mask': action_mask, 'to_play': np.array(-1), 'timestep': np.array(self._timestep), } + return {'observation': observation, 'action_mask': action_mask, 'to_play': np.array(-1), 'timestep': np.array(self._timestep), 'manual_embeds': self.manual_embeds} + + def reset(self): + self._init_flag = True + self._eval_episode_return = 0.0 + self._timestep = 0 + obs, self.manual = self._env.reset() + manual_sentence = ' '.join(self.manual) + + tokenized_output = self.tokenizer( + [manual_sentence], truncation=True, padding="max_length", max_length=self.max_seq_len, return_tensors='pt') + # ts = {k: v.to(self.device) for k, v in ts.items()} + with torch.no_grad(): + self.manual_embeds = self.manual_encoder(**tokenized_output).last_hidden_state[:,0,:].squeeze() + + obs["observation"] = self._symbolic_to_multihot(obs) + del obs["entities"] + del obs["avatar"] + self.obs = to_ndarray(obs['observation']) + obs = self.observe() + return obs + + def step(self, action: int) -> BaseEnvTimestep: + """ + Overview: + Execute the given action and return the resulting environment timestep. + Arguments: + - action (:obj:`int`): The action to be executed. + Returns: + - timestep (:obj:`BaseEnvTimestep`): The environment timestep after executing the action. + """ + obs, reward, done, info = self._env.step(action) + new_obs = self._symbolic_to_multihot(obs) + self.obs = to_ndarray(new_obs) + + self._timestep += 1 + self._eval_episode_return += reward + + observation = self.observe() + if info is None: + info = {} + if self._timestep >= self.max_steps: + done = True + + if done: + print('=' * 20) + print(f'rank {self.rank} one episode done! episode_return:{self._eval_episode_return}') + info['eval_episode_return'] = self._eval_episode_return + + return BaseEnvTimestep(observation, reward, done, info) + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + """ + Overview: + Set the seed for the environment. + + Arguments: + - seed (:obj:`int`): The seed value. + - dynamic_seed (:obj:`bool`, optional): Whether to use a dynamic seed for randomness (defaults to True). + """ + self._seed = seed + + def close(self) -> None: + """ + Overview: + Close the environment and release any resources. + """ + self._init_flag = False + + @staticmethod + def create_collector_env_cfg(cfg: Dict[str, Any]) -> List[Dict[str, Any]]: + """ + Overview: + Create a list of environment configuration dictionaries for the collector phase. + + Arguments: + - cfg (:obj:`Dict[str, Any]`): The original environment configuration. + + Returns: + - (:obj:`List[Dict[str, Any]]`): A list of configuration dictionaries for collector environments. + """ + collector_env_num: int = cfg.pop('collector_env_num') + cfg = copy.deepcopy(cfg) + return [cfg for _ in range(collector_env_num)] + + @staticmethod + def create_evaluator_env_cfg(cfg: Dict[str, Any]) -> List[Dict[str, Any]]: + """ + Overview: + Create a list of environment configuration dictionaries for the evaluator phase.
+ + Arguments: + - cfg (:obj:`Dict[str, Any]`): The original environment configuration. + + Returns: + - (:obj:`List[Dict[str, Any]]`): A list of configuration dictionaries for evaluator environments. + """ + evaluator_env_num: int = cfg.pop('evaluator_env_num') + cfg = copy.deepcopy(cfg) + return [cfg for _ in range(evaluator_env_num)] + + +if __name__ == "__main__": + from easydict import EasyDict + env_type='detective' # zork1, acorncourt, detective, omniquest + # Configuration dictionary for the environment. + env_cfg = EasyDict( + dict( + max_steps=400, + max_action_num=5, + max_seq_len=512, + collector_env_num=1, + evaluator_env_num=1, + mode="train", + task="s1", + vis=False + ) + ) + env = Messenger(env_cfg) + obs = env.reset() + + while True: + action = env.action_space.sample() + obs, reward, done, info = env.step(action) + print(f"Step: {env._timestep}, Action: {action}, Reward: {reward}") + if done: + print(f"Episode done with return: {info['eval_episode_return']}") + break + del env \ No newline at end of file
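Note on the manual-fusion path introduced in this diff: the Messenger env encodes the textual manual once per episode with the BGE encoder ('BAAI/bge-base-en-v1.5') and attaches the [CLS] embedding to every observation dict as 'manual_embeds'; the collector/evaluator forward it to `initial_inference`, the game segments store it alongside observations, and `WorldModel.manual_fuse` broadcasts it over the unroll sequence, normalizes it, concatenates it with the observation embeddings, and projects back to `embed_dim` via `manual_fuse_proj`. The snippet below is a minimal, self-contained sketch of that fusion step, not code from this PR: the class name `ManualFusion` and the shapes in the `__main__` check are illustrative, and it assumes `manual_embed_dim == embed_dim == 768` as in the Messenger configs (the diff sizes the LayerNorm by `embed_dim`, which only matches because the two dimensions coincide here).

import torch
import torch.nn as nn


class ManualFusion(nn.Module):
    """Sketch of the manual/observation fusion performed by WorldModel.manual_fuse."""

    def __init__(self, embed_dim: int = 768, manual_embed_dim: int = 768) -> None:
        super().__init__()
        # Assumption: manual_embed_dim == embed_dim, as in the Messenger configs above.
        self.manual_embeds_norm = nn.LayerNorm(embed_dim, eps=1e-5)
        self.manual_fuse_proj = nn.Linear(embed_dim + manual_embed_dim, embed_dim, bias=False)

    def forward(self, obs_embeddings: torch.Tensor, manual_embeds: torch.Tensor) -> torch.Tensor:
        # obs_embeddings: (B, S, embed_dim); manual_embeds: (B, manual_embed_dim), one vector per episode.
        b, s, _ = obs_embeddings.shape
        manual = manual_embeds.unsqueeze(1).expand(b, s, -1).to(obs_embeddings.device)
        manual = self.manual_embeds_norm(manual)
        # Manual embedding first, observation embedding second, matching the torch.cat order in manual_fuse.
        return self.manual_fuse_proj(torch.cat([manual, obs_embeddings], dim=-1))


if __name__ == "__main__":
    fuse = ManualFusion()
    obs = torch.randn(2, 10, 768)   # batch of 2, unroll length 10
    manual = torch.randn(2, 768)    # e.g. the BGE [CLS] embedding of the manual text
    print(fuse(obs, manual).shape)  # torch.Size([2, 10, 768])

With the default single-environment Messenger configs (collector_env_num = evaluator_env_num = 1), the list form of `manual_embeds` seen at collect/eval time reduces to the same broadcast as in this sketch.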