diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py
index e107beae6..702652a83 100644
--- a/lzero/entry/utils.py
+++ b/lzero/entry/utils.py
@@ -111,6 +111,36 @@ def initialize_zeros_batch(observation_shape: Union[int, List[int], Tuple[int]],
     return torch.zeros(shape).to(device)
 
 
+def initialize_pad_batch(observation_shape: Union[int, List[int], Tuple[int]], batch_size: int, device: str, pad_token_id: int = 0) -> torch.Tensor:
+    """
+    Overview:
+        Initialize a tensor filled with `pad_token_id` for batch observations.
+        This function is designed to be flexible and can handle both textual and non-textual observations:
+
+        - For textual observations: it initializes `input_ids` with padding tokens,
+          ensuring consistent sequence lengths within a batch.
+        - For non-textual observations: with the default `pad_token_id=0` it reduces to zero
+          initialization, ensuring shape compatibility and preventing uninitialized values.
+    Arguments:
+        - observation_shape (:obj:`Union[int, List[int], Tuple[int]]`): The shape of the observation tensor.
+        - batch_size (:obj:`int`): The batch size.
+        - device (:obj:`str`): The device to store the tensor.
+        - pad_token_id (:obj:`int`): The token ID (or placeholder value) used for padding.
+    Returns:
+        - padded_tensor (:obj:`torch.Tensor`): A tensor of the given shape, filled with `pad_token_id`.
+    """
+    if isinstance(observation_shape, (list, tuple)):
+        shape = [batch_size, *observation_shape]
+    elif isinstance(observation_shape, int):
+        shape = [batch_size, observation_shape]
+    else:
+        raise TypeError(f"observation_shape must be int, list, or tuple, but got {type(observation_shape).__name__}")
+
+    return torch.full(shape, fill_value=pad_token_id, dtype=torch.long, device=device)
+
+
 def random_collect(
         policy_cfg: 'EasyDict',  # noqa
         policy: 'Policy',  # noqa
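A quick sketch of how the new helper behaves for both observation types; the pad id below is an illustrative stand-in for whatever the encoder's tokenizer reports, not a value taken from this PR:

```python
import torch

from lzero.entry.utils import initialize_pad_batch

# Text observations: a (batch, seq_len) buffer of input_ids, pre-filled with the
# tokenizer's pad token so every environment slot starts from a valid padded sequence.
obs_text = initialize_pad_batch(observation_shape=512, batch_size=4, device='cpu', pad_token_id=151643)
assert obs_text.shape == (4, 512) and obs_text.dtype == torch.long

# Non-text observations: the default pad_token_id=0 reduces to zero initialization.
obs_vec = initialize_pad_batch(observation_shape=[8, 8], batch_size=4, device='cpu')
assert torch.equal(obs_vec, torch.zeros(4, 8, 8, dtype=torch.long))
```

Note the helper always returns `torch.long`, which is what `input_ids` require; float observation spaces should keep using `initialize_zeros_batch`.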
diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py
index 61ba751a9..f7dfb040c 100644
--- a/lzero/mcts/buffer/game_buffer.py
+++ b/lzero/mcts/buffer/game_buffer.py
@@ -156,19 +156,22 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:
                 # For some environments (e.g., Jericho), the action space size may be different.
                 # To ensure we can always unroll `num_unroll_steps` steps starting from the sampled position (without exceeding segment length),
                 # we avoid sampling from the last `num_unroll_steps + td_steps` steps of the game segment.
-                if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps:
-                    pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps, 1).item()
+                if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps:
+                    pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps, 1).item()
+                if pos_in_game_segment >= len(game_segment.action_segment) - 1:
+                    pos_in_game_segment = np.random.choice(len(game_segment.action_segment) - 1, 1).item()
             else:
                 # For environments with a fixed action space (e.g., Atari),
                 # we can safely sample from the entire game segment range.
                 if pos_in_game_segment >= self._cfg.game_segment_length:
                     pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item()
+                if pos_in_game_segment >= len(game_segment.action_segment) - 1:
+                    pos_in_game_segment = np.random.choice(len(game_segment.action_segment) - 1, 1).item()
 
             pos_in_game_segment_list.append(pos_in_game_segment)
 
         make_time = [time.time() for _ in range(len(batch_index_list))]
-
         orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time)
         return orig_data
diff --git a/lzero/mcts/tree_search/mcts_ctree.py b/lzero/mcts/tree_search/mcts_ctree.py
index 118f614d7..969235ebb 100644
--- a/lzero/mcts/tree_search/mcts_ctree.py
+++ b/lzero/mcts/tree_search/mcts_ctree.py
@@ -184,7 +184,7 @@ def search(
                 current_latent_state_index, discount_factor, reward_batch, value_batch, policy_logits_batch,
                 min_max_stats_lst, results, virtual_to_play_batch
             )
-
+
         return first_action_latent_map
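The two clamps added to `_sample_orig_data` compose as follows; a minimal sketch with made-up config values to make the bounds concrete:

```python
import numpy as np

# Illustrative values only: a 20-step segment, 10 unroll steps, 5 TD steps,
# and an episode that terminated early after 12 recorded actions.
game_segment_length, num_unroll_steps, td_steps = 20, 10, 5
action_segment = list(range(12))

pos = 18
# Keep pos + num_unroll_steps + td_steps inside the segment.
if pos >= game_segment_length - num_unroll_steps - td_steps:
    pos = np.random.choice(game_segment_length - num_unroll_steps - td_steps, 1).item()
# Also clamp to the actions actually recorded, since segments can end early.
if pos >= len(action_segment) - 1:
    pos = np.random.choice(len(action_segment) - 1, 1).item()
assert 0 <= pos < 5
```

One caveat worth flagging: if `game_segment_length - num_unroll_steps - td_steps <= 0`, `np.random.choice` raises, so configs must keep segments longer than `num_unroll_steps + td_steps`.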
diff --git a/lzero/model/common.py b/lzero/model/common.py
index 795eb72a3..7b1bbeeae 100644
--- a/lzero/model/common.py
+++ b/lzero/model/common.py
@@ -15,12 +15,14 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.init as init
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from ding.torch_utils import MLP, ResBlock
 from ding.torch_utils.network.normalization import build_normalization
 from ding.utils import SequenceType
 from ditk import logging
 from ding.utils import set_pkg_seed, get_rank, get_world_size
-import torch
+
+
 
 def MLP_V2(
         in_channels: int,
@@ -361,6 +363,116 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return output
 
 
+class QwenNetwork(nn.Module):
+    def __init__(self,
+                 model_path: str = 'Qwen/Qwen3-1.7B',
+                 embedding_size: int = 768,
+                 final_norm_option_in_encoder: str = "layernorm",
+                 group_size: int = 8,
+                 tokenizer=None):
+        super().__init__()
+
+        logging.info(f"Loading Qwen model from: {model_path}")
+
+        local_rank = get_rank()
+        if local_rank == 0:
+            self.pretrained_model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                torch_dtype="auto",
+                device_map={"": local_rank},
+                attn_implementation="flash_attention_2"
+            )
+        if get_world_size() > 1:
+            torch.distributed.barrier()
+        if local_rank != 0:
+            self.pretrained_model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                torch_dtype="auto",
+                device_map={"": local_rank},
+                attn_implementation="flash_attention_2"
+            )
+
+        for p in self.pretrained_model.parameters():
+            p.requires_grad = False
+
+        if tokenizer is None:
+            if local_rank == 0:
+                self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+            if get_world_size() > 1:
+                torch.distributed.barrier()
+            if local_rank != 0:
+                self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        else:
+            self.tokenizer = tokenizer
+
+        qwen_hidden_size = self.pretrained_model.config.hidden_size
+
+        self.embedding_head = nn.Sequential(
+            nn.Linear(qwen_hidden_size, embedding_size),
+            self._create_norm_layer(final_norm_option_in_encoder, embedding_size, group_size)
+        )
+
+    def _create_norm_layer(self, norm_option, embedding_size, group_size):
+        if norm_option.lower() == "simnorm":
+            return SimNorm(simnorm_dim=group_size)
+        elif norm_option.lower() == "layernorm":
+            return nn.LayerNorm(embedding_size)
+        else:
+            raise NotImplementedError(f"Normalization type '{norm_option}' is not implemented.")
+
+    def encode(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
+        """
+        Overview:
+            Encode the input token sequence `x` into a latent representation
+            using a pretrained language model backbone followed by a projection head.
+        Arguments:
+            - x (:obj:`torch.Tensor`): Input token ids of shape (B, L).
+            - no_grad (:obj:`bool`, optional, default=True): If True, encoding is performed under `torch.no_grad()`
+              to save memory and computation (no gradient tracking).
+        Returns:
+            - latent (:obj:`torch.Tensor`): Encoded latent state of shape (B, D).
+        """
+        pad_id = self.tokenizer.pad_token_id
+        attention_mask = (x != pad_id).long().to(x.device)
+        context = {'input_ids': x.long(), 'attention_mask': attention_mask}
+        if no_grad:
+            with torch.no_grad():
+                outputs = self.pretrained_model(**context, output_hidden_states=True, return_dict=True)
+        else:
+            outputs = self.pretrained_model(**context, output_hidden_states=True, return_dict=True)
+        last_hidden = outputs.hidden_states[-1]
+
+        # Pool the hidden state of the last non-pad token in each sequence.
+        B, L, H = last_hidden.size()
+        lengths = attention_mask.sum(dim=1)              # [B]
+        positions = torch.clamp(lengths - 1, min=0)      # [B]
+        batch_idx = torch.arange(B, device=last_hidden.device)
+
+        selected = last_hidden[batch_idx, positions]     # [B, H]
+
+        latent = self.embedding_head(selected.to(self.embedding_head[0].weight.dtype))
+        return latent
+
+    def decode(self, embeddings: torch.Tensor, max_length: int = 512) -> str:
+        """
+        Decode latent embeddings back into text by generating directly from `inputs_embeds`.
+        """
+        embeddings_detached = embeddings.detach()
+        self.pretrained_model.eval()
+
+        # Generate directly from the provided embeddings.
+        with torch.no_grad():
+            param = next(self.pretrained_model.parameters())
+            embeddings = embeddings_detached.to(device=param.device, dtype=param.dtype)
+            gen_ids = self.pretrained_model.generate(
+                inputs_embeds=embeddings,
+                max_length=max_length
+            )
+        texts = self.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
+        self.pretrained_model.train()
+        return texts[0] if len(texts) == 1 else texts
+
+    def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
+        return self.encode(x, no_grad=no_grad)
+
+
 class HFLanguageRepresentationNetwork(nn.Module):
     def __init__(self,
@@ -542,7 +654,6 @@ def __init__(
         else:
             raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}")
 
-
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Shapes:
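The last-non-pad pooling inside `QwenNetwork.encode` is easy to verify in isolation. A self-contained sketch with toy tensors (shapes and values are assumptions for illustration):

```python
import torch

# Two right-padded sequences (mask 1 = real token), hidden size 4, toy values.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
last_hidden = torch.arange(2 * 4 * 4, dtype=torch.float32).reshape(2, 4, 4)

lengths = attention_mask.sum(dim=1)          # tensor([3, 2])
positions = torch.clamp(lengths - 1, min=0)  # tensor([2, 1]): index of the last real token
batch_idx = torch.arange(2)

selected = last_hidden[batch_idx, positions]  # shape (2, 4)
assert torch.equal(selected[0], last_hidden[0, 2])
assert torch.equal(selected[1], last_hidden[1, 1])
```

Pooling the final non-pad position is the natural summary for a causal LM, since only that hidden state has attended to the entire sequence.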
diff --git a/lzero/model/unizero_model.py b/lzero/model/unizero_model.py
index 4ea6500f3..7a9ec84d6 100644
--- a/lzero/model/unizero_model.py
+++ b/lzero/model/unizero_model.py
@@ -8,7 +8,7 @@
 from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \
     VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \
-    HFLanguageRepresentationNetwork
+    HFLanguageRepresentationNetwork, QwenNetwork
 from .unizero_world_models.tokenizer import Tokenizer
 from .unizero_world_models.world_model import WorldModel
 from ding.utils import ENV_REGISTRY, set_pkg_seed, get_rank, get_world_size
@@ -96,21 +96,37 @@ def __init__(
             print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder')
             print('==' * 20)
         elif world_model_cfg.obs_type == 'text':
-            self.representation_network = HFLanguageRepresentationNetwork(model_path=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder)
-            # print(self.representation_network.model.encoder.layer[0].attention.output.LayerNorm.weight)
-
-            if self.rank == 0:
-                self.decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small")
-                self.decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small")
-            if self.world_size > 1:
-                # Wait until rank 0 finishes loading the tokenizer
-                torch.distributed.barrier()
-            if self.rank != 0:
-                self.decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small")
-                self.decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small")
-
-            projection = [self.representation_network.pretrained_model.config.hidden_size, self.decoder_network.config.d_model]
-            self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=self.decoder_network, decoder_network_tokenizer=self.decoder_network_tokenizer, with_lpips=False, projection=projection)
+            if kwargs['encoder_option'] == 'legacy':
+                self.representation_network = HFLanguageRepresentationNetwork(model_path=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder)
+                if world_model_cfg.decode_loss_mode is None or world_model_cfg.decode_loss_mode.lower() == 'none':
+                    self.decoder_network = None
+                    self.decoder_network_tokenizer = None
+                    projection = None
+                else:
+                    if self.rank == 0:
+                        self.decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small")
+                        self.decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small")
+                    if self.world_size > 1:
+                        # Wait until rank 0 finishes loading the tokenizer
+                        torch.distributed.barrier()
+                    if self.rank != 0:
+                        self.decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small")
+                        self.decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small")
+                    projection = [world_model_cfg.embed_dim, self.decoder_network.config.d_model]
+            elif kwargs['encoder_option'] == 'qwen':
+                self.representation_network = QwenNetwork(model_path=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder)
+                if world_model_cfg.decode_loss_mode is None or world_model_cfg.decode_loss_mode.lower() == 'none':
+                    self.decoder_network = None
+                    self.decoder_network_tokenizer = None
+                    projection = None
+                else:
+                    projection = [world_model_cfg.embed_dim, self.representation_network.pretrained_model.config.hidden_size]
+                    self.decoder_network = self.representation_network
+                    self.decoder_network_tokenizer = None
+            else:
+                raise ValueError(f"Unsupported encoder option: {kwargs['encoder_option']}")
+            self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=self.decoder_network, decoder_network_tokenizer=self.decoder_network_tokenizer,
+                                       with_lpips=False, projection=projection, encoder_option=kwargs['encoder_option'])
             self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer)
             print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model')
             print('==' * 20)
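The branching above is driven entirely by config: `encoder_option` and `encoder_url` select the encoder, and `world_model_cfg.decode_loss_mode` decides whether a decoder is built at all. A hypothetical excerpt of the model config a caller would pass (values mirror the Jericho configs at the end of this diff):

```python
# Hypothetical excerpt; see the Jericho configs below for the full versions.
model = dict(
    observation_shape=512,
    encoder_option='qwen',          # 'legacy' -> HF BGE encoder (+ optional T5 decoder)
    encoder_url='Qwen/Qwen3-0.6B',  # forwarded to QwenNetwork's model_path
    model_type='mlp',
    world_model_cfg=dict(
        obs_type='text',
        embed_dim=768,
        final_norm_option_in_encoder='layernorm',
        decode_loss_mode='None',    # 'None'/None skips building any decoder
        latent_recon_loss_weight=0.1,
    ),
)
```

In the `qwen` case the decoder is the encoder itself, so enabling reconstruction costs no extra weights beyond the projection layer.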
""" - def __init__(self, encoder=None, decoder_network=None, decoder_network_tokenizer=None, with_lpips: bool = False, projection: list = None) -> None: + def __init__(self, encoder=None, decoder_network=None, decoder_network_tokenizer=None, with_lpips: bool = False, projection: list = None, encoder_option='legacy') -> None: """Initialize the Tokenizer. Arguments: @@ -49,6 +49,7 @@ def __init__(self, encoder=None, decoder_network=None, decoder_network_tokenizer with_lpips (bool, optional): If True, enable perceptual loss computation via LPIPS. Defaults to False. projection (list[int], optional): If provided, defines a linear projection layer from projection[0] → projection[1]. If None, an identity layer is used. + encoder_option (str, optional): Option to specify the encoder type, e.g., 'legacy' for T5 decoder or 'qwen' for Qwen decoder. Defaults to 'legacy'. """ super().__init__() if with_lpips: @@ -59,27 +60,14 @@ def __init__(self, encoder=None, decoder_network=None, decoder_network_tokenizer self.encoder = encoder self.decoder_network = decoder_network - self.decoder_network_tokenizer = decoder_network_tokenizer + self.decoder_network_tokenizer = decoder_network_tokenizer + self.encoder_option = encoder_option if projection is None: self.projection_layer = nn.Identity() else: self.projection_layer = nn.Linear(projection[0], projection[1]) - - def decode_to_plain_text(self, x) -> str: - """ - Decode the input tensor to plain text. - - Arguments: - x (torch.Tensor): Input tensor of shape (B, ...). - - Returns: - str: Decoded plain text. - """ - # Convert the input tensor to a numpy array and decode it - return self.encoder.tokenizer.batch_decode(x, skip_special_tokens=True) - def encode_to_obs_embeddings(self, x: torch.Tensor) -> torch.Tensor: """ Encode observations to embeddings. @@ -146,34 +134,77 @@ def decode_to_reconstruction_outputs(self, embeddings: torch.Tensor, target_ids: embeddings = embeddings.reshape(B*T,1,E) target_ids = target_ids.reshape(B*T, -1) - # Instead of using raw target_ids, convert them to plain text and re-tokenize using the decoder's tokenizer. - # This guarantees alignment with the decoder's vocabulary, special tokens, and tokenization rules. - text_list = self.decode_to_plain_text(target_ids) - t5_target_ids = self.decoder_network_tokenizer(text_list, - padding="max_length", - truncation=True, - max_length=512, - return_tensors="pt") - labels = t5_target_ids.input_ids - labels[labels == self.decoder_network_tokenizer.pad_token_id] = -100 - - embeddings = self.projection_layer(embeddings) # (B', 1, E) -> (B', 1, E'), B' = B*T - encoder_outputs_tuple = BaseModelOutput(last_hidden_state=embeddings) - encoder_attention_mask = torch.ones( - embeddings.size(0), embeddings.size(1), - device=embeddings.device, dtype=torch.long - ) - - labels = labels.to(embeddings.device) - - outputs = self.decoder_network(encoder_outputs=encoder_outputs_tuple, - attention_mask=encoder_attention_mask, - labels=labels, - return_dict=True) - - return outputs + if self.encoder_option == 'legacy': # T5 decoder + # Instead of using raw target_ids, convert them to plain text and re-tokenize using the decoder's tokenizer. + # This guarantees alignment with the decoder's vocabulary, special tokens, and tokenization rules. 
+            text_list = self.encoder.tokenizer.batch_decode(target_ids, skip_special_tokens=True)
+            t5_target_ids = self.decoder_network_tokenizer(text_list,
+                                                           padding="max_length",
+                                                           truncation=True,
+                                                           max_length=512,
+                                                           return_tensors="pt")
+            labels = t5_target_ids.input_ids
+            labels[labels == self.decoder_network_tokenizer.pad_token_id] = -100
+
+            embeddings = self.projection_layer(embeddings)  # (B', 1, E) -> (B', 1, E'), B' = B*T
+            encoder_outputs_tuple = BaseModelOutput(last_hidden_state=embeddings)
+            encoder_attention_mask = torch.ones(
+                embeddings.size(0), embeddings.size(1),
+                device=embeddings.device, dtype=torch.long
+            )
+
+            labels = labels.to(embeddings.device)
+
+            outputs = self.decoder_network(encoder_outputs=encoder_outputs_tuple,
+                                           attention_mask=encoder_attention_mask,
+                                           labels=labels,
+                                           return_dict=True)
+            return outputs
+
+        elif self.encoder_option == 'qwen':
+            hidden = self.projection_layer(embeddings)
+            lm = self.decoder_network.pretrained_model
+            # Get a reference parameter for device/dtype info
+            param = next(lm.parameters())
+
+            try:
+                # Retrieve the input embedding layer of the language model
+                input_embedding_layer = lm.get_input_embeddings()
+            except AttributeError:
+                raise ValueError('Could not retrieve the input embedding layer from the decoder network.')
+
+            # Convert target token IDs into embeddings using the LM's input embedding layer
+            target_embeds = input_embedding_layer(target_ids)
+
+            # Concatenate the projected hidden embeddings (prompt) with target embeddings:
+            # hidden: (B, 1, D), target_embeds: (B, L, D) -> inputs_embeds: (B, 1+L, D)
+            inputs_embeds = torch.cat([hidden, target_embeds.detach()], dim=1)
+            inputs_embeds = inputs_embeds.to(device=param.device, dtype=param.dtype)
+
+            prompt_attention_mask = torch.ones(hidden.size(0), 1, device=param.device, dtype=torch.long)
+            target_attention_mask = (target_ids != self.decoder_network.tokenizer.pad_token_id).to(device=param.device, dtype=torch.long)
+            # Concatenate prompt mask and target mask along the sequence dimension
+            attention_mask = torch.cat([prompt_attention_mask, target_attention_mask], dim=1)
+
+            # Construct labels: the prompt position uses -100 (ignored by the loss function)
+            prompt_labels = torch.full((hidden.size(0), 1), -100, device=param.device, dtype=torch.long)
+            # Copy target token IDs as labels, masking pad positions with -100
+            labels = target_ids.clone().to(param.device)
+            labels[labels == self.decoder_network.tokenizer.pad_token_id] = -100
+            final_labels = torch.cat([prompt_labels, labels], dim=1)
+
+            outputs = lm(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                labels=final_labels,
+                return_dict=True
+            )
+            return outputs
 
-    def decode_to_plain_text_for_decoder(
+    def decode_to_plain_text(
             self, embeddings: torch.Tensor, max_length: int = 512
     ) -> List[List[int]]:
@@ -209,29 +240,31 @@ def decode_to_plain_text_for_decoder(
             embeddings = embeddings.unsqueeze(1)
 
         embeddings = self.projection_layer(embeddings)
-
-        encoder_outputs_tuple = BaseModelOutput(last_hidden_state=embeddings)
-        encoder_attention_mask = torch.ones(
-            embeddings.size(0), embeddings.size(1),
-            device=device, dtype=torch.long
-        )
-
-        # Use the decoder's generate() method to autoregressively decode text from the input embeddings.
-        # The projected embeddings serve as encoder outputs in a typical encoder-decoder architecture,
-        # where the decoder attends to them via cross-attention at each step until max_length or EOS is reached.
-        generated_t5_ids = self.decoder_network.generate(
-            encoder_outputs=encoder_outputs_tuple,
-            attention_mask=encoder_attention_mask,
-            max_length=max_length
-        )
-
-        # Convert the generated output to a list of strings on CPU, skipping special tokens.
-        generated_text = self.decoder_network_tokenizer.batch_decode(
-            generated_t5_ids, skip_special_tokens=True)
-
-        assert len(generated_text) == 1, f"Expected 1 generated text, got {len(generated_text)}"
+        if self.encoder_option == 'legacy':  # T5 decoder
+            encoder_outputs_tuple = BaseModelOutput(last_hidden_state=embeddings)
+            encoder_attention_mask = torch.ones(
+                embeddings.size(0), embeddings.size(1),
+                device=device, dtype=torch.long
+            )
+
+            # Use the decoder's generate() method to autoregressively decode text from the input embeddings.
+            # The projected embeddings serve as encoder outputs in a typical encoder-decoder architecture,
+            # where the decoder attends to them via cross-attention at each step until max_length or EOS is reached.
+            generated_t5_ids = self.decoder_network.generate(
+                encoder_outputs=encoder_outputs_tuple,
+                attention_mask=encoder_attention_mask,
+                max_length=max_length
+            )
+
+            # Convert the generated output to a list of strings on CPU, skipping special tokens.
+            generated_text = self.decoder_network_tokenizer.batch_decode(
+                generated_t5_ids, skip_special_tokens=True)
+
+            assert len(generated_text) == 1, f"Expected 1 generated text, got {len(generated_text)}"
+            return generated_text[0]
 
-        return generated_text[0]
+        elif self.encoder_option == 'qwen':
+            return self.decoder_network.decode(embeddings=embeddings, max_length=max_length)
 
     @staticmethod
     def reconstruction_loss(original_images: torch.Tensor, reconstructed_images: torch.Tensor) -> torch.Tensor:
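In the `qwen` branch, reconstruction is plain causal-LM training: the projected latent acts as a one-token soft prompt, and only target tokens contribute to the loss. A toy sketch of just the label construction (pad id and token values are made up):

```python
import torch

pad_token_id = 0
target_ids = torch.tensor([[11, 12, 13, pad_token_id],
                           [21, 22, pad_token_id, pad_token_id]])
B = target_ids.size(0)

# One prompt position per sequence, excluded from the loss via -100.
prompt_labels = torch.full((B, 1), -100, dtype=torch.long)
labels = target_ids.clone()
labels[labels == pad_token_id] = -100
final_labels = torch.cat([prompt_labels, labels], dim=1)

print(final_labels)
# tensor([[-100,   11,   12,   13, -100],
#         [-100,   21,   22, -100, -100]])
```

Since `target_embeds` is detached and the Qwen parameters were frozen in `QwenNetwork.__init__`, gradients reach the latent only through the single prompt position (and the projection layer).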
diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py
index e8df2a6e0..7f1a0f68e 100644
--- a/lzero/model/unizero_world_models/world_model.py
+++ b/lzero/model/unizero_world_models/world_model.py
@@ -100,7 +100,7 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
         skip_modules = set()
         if hasattr(self.tokenizer.encoder, 'pretrained_model'):
             skip_modules.update(self.tokenizer.encoder.pretrained_model.modules())
-        if hasattr(self.tokenizer, 'decoder_network'):
+        if hasattr(self.tokenizer, 'decoder_network') and self.tokenizer.decoder_network is not None:
             skip_modules.update(self.tokenizer.decoder_network.modules())
 
         def custom_init(module):
@@ -1372,7 +1372,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
             if decode_loss_mode == "after_backbone":
                 next_latent_state = outputs.logits_observations[:, :-1, :]
                 next_target_ids = batch['observations'][:, 1:, :]
-
+
                 latent_recon_loss = self.tokenizer.decode_to_reconstruction_outputs(
                     embeddings=next_latent_state,
                     target_ids=next_target_ids,
@@ -1506,9 +1506,6 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
         # Compute discount coefficients for each timestep
         discounts = self.gamma ** timesteps
 
-        if batch['mask_padding'].sum() == 0:
-            assert False, "mask_padding is all zeros"
-
         # Group losses into first step, middle step, and last step
         first_step_losses = {}
         middle_step_losses = {}
@@ -1547,7 +1544,6 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
         # Discount reconstruction loss and perceptual loss
         discounted_latent_recon_loss = latent_recon_loss
         discounted_perceptual_loss = perceptual_loss
-
        # Calculate overall discounted loss
        discounted_loss_obs = (loss_obs.view(-1, batch['actions'].shape[1] - 1) * discounts[1:]).sum() / batch['mask_padding'][:, 1:].sum()
        discounted_loss_rewards = (loss_rewards.view(-1, batch['actions'].shape[1]) * discounts).sum() / batch['mask_padding'].sum()
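One side effect of dropping the `mask_padding` assert: the normalizations above divide by `mask_padding.sum()`, so an all-padding batch would now produce NaN instead of failing loudly, presumably on the assumption that sampling always yields at least one valid step. A toy sketch of the masked, discounted normalization (values are illustrative; per-step losses are assumed already zeroed at padded positions):

```python
import torch

gamma = 0.99
loss_rewards = torch.tensor([[0.5, 0.2, 0.0],   # batch 2, unroll length 3
                             [0.4, 0.0, 0.0]])  # zeros where mask_padding is 0
mask_padding = torch.tensor([[1, 1, 0],
                             [1, 0, 0]])
discounts = gamma ** torch.arange(3)

# Sum of discounted losses over valid steps, averaged over the number of valid steps.
discounted_loss_rewards = (loss_rewards * discounts).sum() / mask_padding.sum()
print(discounted_loss_rewards)  # (0.5 + 0.2*0.99 + 0.4) / 3
```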
diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py
index 9ff2c1333..bbf27a0e2 100644
--- a/lzero/policy/unizero.py
+++ b/lzero/policy/unizero.py
@@ -8,7 +8,7 @@
 from ding.model import model_wrap
 from ding.utils import POLICY_REGISTRY
 
-from lzero.entry.utils import initialize_zeros_batch
+from lzero.entry.utils import initialize_zeros_batch, initialize_pad_batch
 from lzero.mcts import UniZeroMCTSCtree as MCTSCtree
 from lzero.model import ImageTransforms
 from lzero.policy import scalar_transform, InverseScalarTransform, phi_transform, \
@@ -129,6 +129,10 @@ class UniZeroPolicy(MuZeroPolicy):
             rope_theta=10000,
             # (int) The maximum sequence length for position encoding.
             max_seq_len=8192,
+            # Controls where the reconstruction loss is computed: 'after_backbone', 'before_backbone', or None.
+            # - after_backbone: the reconstruction loss is computed after the encoded representation passes through the backbone.
+            # - before_backbone: the reconstruction loss is computed directly on the encoded representation, without the backbone.
+            decode_loss_mode=None,
         ),
     ),
     # ****** common ******
@@ -348,6 +352,9 @@ def _init_learn(self) -> None:
         self.l2_norm_after = 0.
         self.grad_norm_before = 0.
         self.grad_norm_after = 0.
+
+        encoder_tokenizer = getattr(self._model.tokenizer.encoder, 'tokenizer', None)
+        self.pad_token_id = encoder_tokenizer.pad_token_id if encoder_tokenizer is not None else 0
 
         if self._cfg.use_wandb:
             # TODO: add the model to wandb
@@ -592,7 +599,9 @@ def _init_collect(self) -> None:
             self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
             self.last_batch_action = [-1 for i in range(self.collector_env_num)]
         elif self._cfg.model.model_type == 'mlp':
-            self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape]).to(self._cfg.device)
+            self.last_batch_obs = torch.full(
+                [self.collector_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id,
+            ).to(self._cfg.device)
             self.last_batch_action = [-1 for i in range(self.collector_env_num)]
 
     # @profile
@@ -694,9 +703,9 @@ def _forward_collect(
                 next_latent_state = next_latent_state_with_env[i][action]
 
-                if self._cfg.model.world_model_cfg.obs_type == 'text':
+                if self._cfg.model.world_model_cfg.obs_type == 'text' and self._cfg.model.world_model_cfg.decode_loss_mode is not None and self._cfg.model.world_model_cfg.decode_loss_mode.lower() != 'none':
                     # Output the plain text content decoded by the decoder from the next latent state
-                    predicted_next = self._collect_model.tokenizer.decode_to_plain_text_for_decoder(embeddings=next_latent_state, max_length=256)
+                    predicted_next = self._collect_model.tokenizer.decode_to_plain_text(embeddings=next_latent_state, max_length=256)
                 else:
                     predicted_next = None
 
@@ -745,11 +754,13 @@ def _init_eval(self) -> None:
         self.evaluator_env_num = self._cfg.evaluator_env_num
 
         if self._cfg.model.model_type == 'conv':
             self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
             self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)]
         elif self._cfg.model.model_type == 'mlp':
-            self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape]).to(self._cfg.device)
+            self.last_batch_obs = torch.full(
+                [self.evaluator_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id,
+            ).to(self._cfg.device)
             self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)]
 
 def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [-1],
                   ready_env_id: np.array = None, timestep: List = [0]) -> Dict:
@@ -825,9 +836,9 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [
                 # Predict the next latent state based on the selected action and policy
                 next_latent_state = next_latent_state_with_env[i][action]
 
-                if self._cfg.model.world_model_cfg.obs_type == 'text':
+                if self._cfg.model.world_model_cfg.obs_type == 'text' and self._cfg.model.world_model_cfg.decode_loss_mode is not None and self._cfg.model.world_model_cfg.decode_loss_mode.lower() != 'none':
                     # Output the plain text content decoded by the decoder from the next latent state
-                    predicted_next = self._eval_model.tokenizer.decode_to_plain_text_for_decoder(embeddings=next_latent_state, max_length=256)
+                    predicted_next = self._eval_model.tokenizer.decode_to_plain_text(embeddings=next_latent_state, max_length=256)
                 else:
                     predicted_next = None
 
@@ -861,10 +872,11 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in
             - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset.
         """
         if reset_init_data:
-            self.last_batch_obs = initialize_zeros_batch(
+            self.last_batch_obs = initialize_pad_batch(
                 self._cfg.model.observation_shape,
                 self._cfg.collector_env_num,
-                self._cfg.device
+                self._cfg.device,
+                pad_token_id=self.pad_token_id
             )
             self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)]
 
@@ -905,10 +917,11 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_
             - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset.
         """
         if reset_init_data:
-            self.last_batch_obs = initialize_zeros_batch(
+            self.last_batch_obs = initialize_pad_batch(
                 self._cfg.model.observation_shape,
                 self._cfg.evaluator_env_num,
-                self._cfg.device
+                self._cfg.device,
+                pad_token_id=self.pad_token_id
             )
             self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)]
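The `pad_token_id` fallback added in `_init_learn` can be exercised with stand-in objects; all names below are hypothetical stubs, not LightZero classes:

```python
class _StubTokenizer:          # stands in for a HF tokenizer
    pad_token_id = 151643      # illustrative value

class _TextEncoder:            # e.g., QwenNetwork / HFLanguageRepresentationNetwork expose .tokenizer
    tokenizer = _StubTokenizer()

class _VisionEncoder:          # non-text encoders have no tokenizer attribute
    pass

for encoder in (_TextEncoder(), _VisionEncoder()):
    encoder_tokenizer = getattr(encoder, 'tokenizer', None)
    pad_token_id = encoder_tokenizer.pad_token_id if encoder_tokenizer is not None else 0
    print(pad_token_id)  # 151643 for the text encoder, 0 otherwise
```

The 0 fallback keeps `initialize_pad_batch` equivalent to the old `initialize_zeros_batch` for non-text models.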
diff --git a/zoo/jericho/configs/jericho_unizero_config.py b/zoo/jericho/configs/jericho_unizero_config.py
index 30fbe5a7e..cc66e045b 100644
--- a/zoo/jericho/configs/jericho_unizero_config.py
+++ b/zoo/jericho/configs/jericho_unizero_config.py
@@ -59,7 +59,14 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
     reanalyze_partition: float = 0.75
 
     # Model name or path - configurable according to the predefined model paths or names
-    model_name: str = 'BAAI/bge-base-en-v1.5'
+    encoder_option = 'legacy'  # One of ['qwen', 'legacy']; 'legacy' uses the BGE encoder.
+
+    if encoder_option == 'qwen':
+        model_name: str = 'Qwen/Qwen3-0.6B'
+    elif encoder_option == 'legacy':
+        model_name: str = 'BAAI/bge-base-en-v1.5'
+    else:
+        raise ValueError(f"Unsupported encoder option: {encoder_option}")
 
     # ------------------------------------------------------------------
     # TODO: Debug configuration - override some parameters for debugging purposes
@@ -104,6 +111,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
         model=dict(
             observation_shape=512,
             action_space_size=action_space_size,
+            encoder_option=encoder_option,
             encoder_url=model_name,
             model_type="mlp",
             continuous_action_space=False,
diff --git a/zoo/jericho/configs/jericho_unizero_ddp_config.py b/zoo/jericho/configs/jericho_unizero_ddp_config.py
index 5cb67a8f8..4fefd717d 100644
--- a/zoo/jericho/configs/jericho_unizero_ddp_config.py
+++ b/zoo/jericho/configs/jericho_unizero_ddp_config.py
@@ -16,10 +16,11 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
     Returns:
         None
     """
-    gpu_num = 2
+    gpu_num = 4
     collector_env_num: int = 4  # Number of collector environments
     n_episode = int(collector_env_num*gpu_num)
-    batch_size = int(8*gpu_num)
+    batch_size = int(1*gpu_num)
+    accumulation_steps = 1  # TODO
 
     # batch_size = batch_size * 2
@@ -35,9 +36,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
         'acorncourt.z5': (45, 50),
         'zork1.z5': (55, 500),
     }
-    env_id = 'detective.z5'
-
     # Set action_space_size and max_steps based on env_id
     action_space_size, max_steps = env_configurations.get(env_id, (10, 50))  # Default values if env_id not found
@@ -64,7 +63,14 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
     reanalyze_partition: float = 0.75
 
     # Model name or path - configurable according to the predefined model paths or names
-    model_name: str = 'BAAI/bge-base-en-v1.5'
+    encoder_option = 'legacy'  # One of ['qwen', 'legacy']; 'legacy' uses the BGE encoder.
+
+    if encoder_option == 'qwen':
+        model_name: str = 'Qwen/Qwen3-0.6B'
+    elif encoder_option == 'legacy':
+        model_name: str = 'BAAI/bge-base-en-v1.5'
+    else:
+        raise ValueError(f"Unsupported encoder option: {encoder_option}")
 
     # ------------------------------------------------------------------
     # TODO: Debug configuration - override some parameters for debugging purposes
@@ -105,11 +111,12 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
                 ),
             ),
         ),
-        accumulation_steps=4,  # TODO: Accumulated gradient steps (currently default)
+        accumulation_steps=accumulation_steps,  # TODO: Accumulated gradient steps (currently default)
         model=dict(
             observation_shape=512,
             action_space_size=action_space_size,
             encoder_url=model_name,
+            encoder_option=encoder_option,
             model_type="mlp",
             continuous_action_space=False,
             world_model_cfg=dict(
@@ -129,12 +136,12 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
                 embed_dim=embed_dim,
                 obs_type="text",  # TODO: Modify as needed.
                 env_num=max(collector_env_num, evaluator_env_num),
-                decode_loss_mode='after_backbone',  # Controls where to compute reconstruction loss: after_backbone, before_backbone, or None.
+                decode_loss_mode='None',  # Controls where to compute reconstruction loss: after_backbone, before_backbone, or None.
                latent_recon_loss_weight=0.1  # TODO: decoder loss weight
             ),
         ),
         # TODO
-        update_per_collect=int(collector_env_num*max_steps*replay_ratio*4),  # Important for DDP
+        update_per_collect=int(collector_env_num*max_steps*replay_ratio*accumulation_steps),  # Important for DDP
         action_type="varied_action_space",
         model_path=None,
         num_unroll_steps=num_unroll_steps,
@@ -193,7 +200,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
         main_config = lz_to_ddp_config(main_config)
     # Construct experiment name containing key parameters
     main_config.exp_name = (
-        f"data_lz/data_unizero_jericho/bge-base-en-v1.5/{env_id}/uz_ddp-{gpu_num}gpu_cen{collector_env_num}_rr{replay_ratio}_ftemp025_{env_id[:8]}_ms{max_steps}_ass-{action_space_size}_"
+        f"data_lz/data_unizero_jericho/{model_name}/{env_id}/uz_ddp-{gpu_num}gpu_cen{collector_env_num}_rr{replay_ratio}_ftemp025_{env_id[:8]}_ms{max_steps}_ass-{action_space_size}_"
         f"nlayer{num_layers}_embed{embed_dim}_Htrain{num_unroll_steps}-"
         f"Hinfer{infer_context_length}_bs{batch_size}_seed{seed}"
     )
diff --git a/zoo/jericho/configs/jericho_unizero_segment_config.py b/zoo/jericho/configs/jericho_unizero_segment_config.py
index 6d7c4768b..a44b9cf75 100644
--- a/zoo/jericho/configs/jericho_unizero_segment_config.py
+++ b/zoo/jericho/configs/jericho_unizero_segment_config.py
@@ -22,7 +22,16 @@ def main(env_id: str = 'detective.z5', seed: int = 0) -> None:
     # Frequently changed configurations (user-specified)
     # ==============================================================
     # Model name or path - configurable according to the predefined model paths or names
-    model_name: str = 'BAAI/bge-base-en-v1.5'
+    encoder_option = 'legacy'  # One of ['qwen', 'legacy']; 'legacy' uses the BGE encoder.
+
+    if encoder_option == 'qwen':
+        model_name: str = 'Qwen/Qwen3-0.6B'
+    elif encoder_option == 'legacy':
+        model_name: str = 'BAAI/bge-base-en-v1.5'
+    else:
+        raise ValueError(f"Unsupported encoder option: {encoder_option}")
+
     collector_env_num = 8
     game_segment_length = 20
     evaluator_env_num = 5
@@ -86,6 +95,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0) -> None:
         model=dict(
             observation_shape=512,
             action_space_size=action_space_size,
+            encoder_option=encoder_option,
             encoder_url=model_name,
             model_type="mlp",
             world_model_cfg=dict(
@@ -104,6 +114,8 @@ def main(env_id: str = 'detective.z5', seed: int = 0) -> None:
                 embed_dim=embed_dim,
                 obs_type="text",
                 env_num=max(collector_env_num, evaluator_env_num),
+                decode_loss_mode='None',  # Controls where to compute reconstruction loss: after_backbone, before_backbone, or None.
+                latent_recon_loss_weight=0.1
             ),
         ),
         action_type="varied_action_space",
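A closing note on the DDP bookkeeping: with gradient accumulation, each optimizer step consumes `accumulation_steps` mini-batches, so `update_per_collect` scales by the same factor to keep the number of effective optimizer updates per collection unchanged. Concrete arithmetic with the DDP config values above (`max_steps` and `replay_ratio` are assumptions for illustration, since the detective.z5 entry and the replay ratio are set elsewhere in that file):

```python
collector_env_num = 4
max_steps = 100          # assumed detective.z5 entry in env_configurations
replay_ratio = 0.25      # assumed; defined elsewhere in the config
accumulation_steps = 1

update_per_collect = int(collector_env_num * max_steps * replay_ratio * accumulation_steps)
print(update_per_collect)  # 100 train iterations per collect; optimizer steps = 100 / accumulation_steps
```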