From 4e2fe99132231260f42beb7f52ac417f6059b2fb Mon Sep 17 00:00:00 2001 From: puyuan Date: Fri, 18 Jul 2025 07:55:53 +0000 Subject: [PATCH 1/5] Qwen is tested as a policy in the jericho environment --- zoo/jericho/envs/test_qwen.py | 234 ++++++++++++++++++++++++++++++++++ 1 file changed, 234 insertions(+) create mode 100644 zoo/jericho/envs/test_qwen.py diff --git a/zoo/jericho/envs/test_qwen.py b/zoo/jericho/envs/test_qwen.py new file mode 100644 index 000000000..bb3eb0220 --- /dev/null +++ b/zoo/jericho/envs/test_qwen.py @@ -0,0 +1,234 @@ +import logging +import copy +import os +import json +from datetime import datetime +from typing import Any, Dict, List, Optional, Union +from easydict import EasyDict +from collections import deque + +import numpy as np +import torch +from transformers import AutoTokenizer + +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig +import torch.distributed as dist +import torch + +from jericho_env import JerichoEnv + + +def init_distributed(): + if not dist.is_initialized(): + dist.init_process_group(backend='nccl' if torch.cuda.is_available() else 'gloo') + rank = dist.get_rank() + world_size = dist.get_world_size() + return rank, world_size + +class Qwen3Policy: + def __init__(self, model_path=None, local_rank=0): + self.device = torch.device(f"cuda:{local_rank}") + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto").to(self.device) + self.max_history_len = 5 + + def format_prompt(self, history: List[Dict[str, str]], current_obs: str, valid_actions: List[str], good_memory: List[str] = None, bad_memory: List[str] = None) -> str: + system_prompt = ( + "You are an expert at playing text-based games.\n" + "You will be given a history of game states and actions taken, as well as your memory from previous gameplay.\n" + "Use this information to select the best next action.\n\n" + ) + if good_memory: + system_prompt += "Good memory from past games:\n" + for mem in good_memory: + system_prompt += f"- {mem}\n" + else: + system_prompt += "Good memory from past games: None\n" + + if bad_memory: + system_prompt += "Bad memory from past failures:\n" + for mem in bad_memory: + system_prompt += f"- {mem}\n" + else: + system_prompt += "Bad memory from past failures: None\n" + + + system_prompt += "\nHistory:\n" + if history is None or len(history) == 0: + system_prompt += "None\n" + for h in history: + system_prompt += f"[State]: {h['obs']}\n[Action]: {h['action']}\n" + + state = current_obs.split('Valid actions:')[0] + + system_prompt += ( + f"\nCurrent:\n[State]: {state}\n" + f"[Valid Actions]: {', '.join(valid_actions)}\n" + "Please choose the best next action from the valid actions above, and answer with only one word or phrase, without any explanation.\n" + "[Action]:" + ) + return system_prompt + + def generate_reflection(self, history: List[Dict[str, str]], positive: bool) -> str: + trajectory_str = "\n".join([f"[State]: {h['obs']}\n[Action]: {h['action']}" for h in history]) + if positive: + prompt = ( + "You will receive a log of successful gameplay from a text-based adventure game.\n" + "Summarize a good strategy or useful lesson learned from the following playthrough in one sentence.\n" + f"{trajectory_str}\n" + ) + else: + prompt = ( + "You will receive a log of unsuccessful gameplay from a text-based adventure game.\n" + "Please identify the reasons for failure and provide a short suggestion for improving the strategy next time.\n" + "Respond 
with one sentence only.\n" + f"{trajectory_str}\n" + ) + + messages = [{"role": "user", "content": prompt}] + text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device) + gen_config = GenerationConfig( + temperature=0.7, + top_p=0.9, + do_sample=True, + max_new_tokens=64, + pad_token_id=self.tokenizer.eos_token_id + ) + output = self.model.generate( + **model_inputs, + generation_config=gen_config + ) + output_ids = output[0][len(model_inputs.input_ids[0]):].tolist() + content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip() + return content + + def sample_action(self, history: List[Dict[str, str]], current_obs: str, valid_actions: List[str], good_memory: List[str] = None, bad_memory: List[str] = None) -> str: + prompt = self.format_prompt(history, current_obs, valid_actions, good_memory, bad_memory) + messages = [ + {"role": "user", "content": prompt} + ] + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False + ) + model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device) + gen_config = GenerationConfig( + temperature=0.7, + top_p=0.9, + do_sample=True, + max_new_tokens=64, + pad_token_id=self.tokenizer.eos_token_id + ) + generated_ids = self.model.generate( + **model_inputs, + generation_config=gen_config + ) + output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist() + content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n") + + res_action = None + + for va in valid_actions: + if va.lower() in content: + res_action = va + if res_action is None: + if valid_actions: + res_action = valid_actions[0] + else: + res_action = 'go' + + return res_action, prompt + +if __name__ == '__main__': + rank, world_size = init_distributed() + print(f"[RANK {rank}] Initialized. World size: {world_size}") + + + # env_type='detective' # zork1, acorncourt, detective, omniquest + env_type='zork1' + model_name = "Qwen2.5-7B-Instruct" # Path to the Qwen model + # Configuration dictionary for the environment. 
+    env_cfg = EasyDict(
+        dict(
+            max_steps=100,
+            game_path="./zoo/jericho/envs/z-machine-games-master/jericho-game-suite/" + f"{env_type}.z5",
+            max_action_num=12,
+            tokenizer_path="google-bert/bert-base-uncased",
+            max_seq_len=512,
+            remove_stuck_actions=False,
+            # add_location_and_inventory=True, # TODO: try running with this option enabled or disabled
+            add_location_and_inventory=False, # TODO: try running with this option enabled or disabled
+            for_unizero=False,
+            collector_env_num=1,
+            evaluator_env_num=1,
+            save_replay=False,
+            save_replay_path=None,
+            env_type=env_type,
+            collect_policy_mode='expert' # random, human, expert
+        )
+    )
+
+    if env_cfg.add_location_and_inventory:
+        log_dir = f'/fs-computility/niuyazhe/shared/xiongjyu/jericho/LightZero/log/{model_name}/{env_type}_add_locAndinv'
+    else:
+        log_dir = f'/fs-computility/niuyazhe/shared/xiongjyu/jericho/LightZero/log/{model_name}/{env_type}'
+    os.makedirs(log_dir, exist_ok=True)
+
+    log_file = os.path.join(log_dir, f"rank_{rank}.txt")
+    f = open(log_file, "w", encoding="utf-8")
+
+    env = JerichoEnv(env_cfg)
+    qwen_policy = Qwen3Policy(model_path=f"/fs-computility/niuyazhe/shared/xiongjyu/model/{model_name}", local_rank=rank)
+    history = deque(maxlen=qwen_policy.max_history_len)
+
+    num_episodes = 20 # can be set to any N
+    good_trial_memory = deque(maxlen=5)
+    bad_trial_memory = deque(maxlen=5)
+
+    for episode_id in range(num_episodes):
+        f.write(f"{'='*60}\n")
+        f.write(f'current episode: {episode_id}\n')
+        f.write(f"{'='*60}\n")
+        f.flush()
+        obs = env.reset(return_str=True)
+        done = False
+        step_count = 0
+        history.clear()
+
+        while not done:
+            obs_str = obs['observation']
+            action, prompt = qwen_policy.sample_action(history=list(history),current_obs=obs_str, valid_actions=env._action_list,
+                                                       good_memory=list(good_trial_memory), bad_memory=list(bad_trial_memory))
+            obs, reward, done, info = env.step(action, return_str=True)
+            history.append({'obs': obs_str, 'action': action})
+
+            # Write the log for each step
+            f.write(f"Step {step_count}\n")
+            f.write(f"[Prompt]:\n{prompt}\n")
+            f.write(f"[Qwen Action]: {action}\n")
+            f.write(f"[Env Feedback] Reward: {reward}, Done: {done}\n")
+            f.write(f"{'-'*60}\n")
+            f.flush()
+
+            step_count += 1
+
+            if "*** you have died ***" in obs_str.lower():
+                reflection = qwen_policy.generate_reflection(list(history), positive=False)
+                bad_trial_memory.append(reflection)
+                f.write(f"[BAD Reflection]: {reflection}\n")
+                print(f'[BAD Reflection]: {reflection}')
+            elif "your score has just gone up by" in obs_str.lower():
+                reflection = qwen_policy.generate_reflection(list(history), positive=True)
+                good_trial_memory.append(reflection)
+                f.write(f"[GOOD Reflection]: {reflection}\n")
+                print(f'[GOOD Reflection]: {reflection}')
+
+
+        f.write(f"Episode finished. Final return: {info.get('eval_episode_return', 0.0)}\n")
+
+    f.close()
+    print(f"[RANK {rank}] Finished. 
Log written to {log_file}")
+    del env
\ No newline at end of file

From d96909645d2091663457ca7e9fee84ca4500ead3 Mon Sep 17 00:00:00 2001
From: puyuan
Date: Thu, 24 Jul 2025 07:59:09 +0000
Subject: [PATCH 2/5] fixed the bug that bad reflection cannot be collected

---
 zoo/jericho/envs/test_qwen.py | 31 +++++++++++++++----------------
 1 file changed, 15 insertions(+), 16 deletions(-)

diff --git a/zoo/jericho/envs/test_qwen.py b/zoo/jericho/envs/test_qwen.py
index bb3eb0220..a3a8ea47e 100644
--- a/zoo/jericho/envs/test_qwen.py
+++ b/zoo/jericho/envs/test_qwen.py
@@ -79,9 +79,7 @@ def generate_reflection(self, history: List[Dict[str, str]], positive: bool) ->
             )
         else:
             prompt = (
-                "You will receive a log of unsuccessful gameplay from a text-based adventure game.\n"
-                "Please identify the reasons for failure and provide a short suggestion for improving the strategy next time.\n"
-                "Respond with one sentence only.\n"
+                "You will receive a log of unsuccessful gameplay from a text-based adventure game. Please identify the reasons for this game failure and provide a short suggestion for improving the game strategy next time. Do not summarize the gameplay trajectory; respond with your suggestion in a single sentence. For instance: 'Remember to light a lamp before entering dark areas to avoid being eaten by a grue.\n"
                 f"{trajectory_str}\n"
             )
 
@@ -153,9 +151,9 @@ def sample_action(self, history: List[Dict[str, str]], current_obs: str, valid_a
     # Configuration dictionary for the environment.
     env_cfg = EasyDict(
         dict(
-            max_steps=100,
+            max_steps=500,
             game_path="./zoo/jericho/envs/z-machine-games-master/jericho-game-suite/" + f"{env_type}.z5",
-            max_action_num=12,
+            max_action_num=55,
             tokenizer_path="google-bert/bert-base-uncased",
             max_seq_len=512,
             remove_stuck_actions=False,
@@ -186,7 +184,7 @@ def sample_action(self, history: List[Dict[str, str]], current_obs: str, valid_a
 
     num_episodes = 20 # can be set to any N
     good_trial_memory = deque(maxlen=5)
-    bad_trial_memory = deque(maxlen=5)
+    bad_trial_memory = deque(maxlen=20)
 
     for episode_id in range(num_episodes):
         f.write(f"{'='*60}\n")
@@ -215,16 +213,17 @@ def sample_action(self, history: List[Dict[str, str]], current_obs: str, valid_a
 
             step_count += 1
 
-            if "*** you have died ***" in obs_str.lower():
-                reflection = qwen_policy.generate_reflection(list(history), positive=False)
-                bad_trial_memory.append(reflection)
-                f.write(f"[BAD Reflection]: {reflection}\n")
-                print(f'[BAD Reflection]: {reflection}')
-            elif "your score has just gone up by" in obs_str.lower():
-                reflection = qwen_policy.generate_reflection(list(history), positive=True)
-                good_trial_memory.append(reflection)
-                f.write(f"[GOOD Reflection]: {reflection}\n")
-                print(f'[GOOD Reflection]: {reflection}')
+
+        reflection = qwen_policy.generate_reflection(list(history), positive=False)
+        bad_trial_memory.append(reflection)
+        f.write(f"[BAD Reflection]: {reflection}\n")
+        print(f'[BAD Reflection]: {reflection}')
+
+        ## Whether to generate reflections for good trajectories
+        # reflection = qwen_policy.generate_reflection(list(history), positive=True)
+        # good_trial_memory.append(reflection)
+        # f.write(f"[GOOD Reflection]: {reflection}\n")
+        # print(f'[GOOD Reflection]: {reflection}')
 
 
         f.write(f"Episode finished. 
Final return: {info.get('eval_episode_return', 0.0)}\n") From 00d47975a91762ad186c913b506de1c036a57909 Mon Sep 17 00:00:00 2001 From: puyuan Date: Thu, 24 Jul 2025 15:13:52 +0000 Subject: [PATCH 3/5] supports options for selecting encoder/decoder --- lzero/entry/utils.py | 22 +++ lzero/mcts/buffer/game_buffer.py | 2 +- lzero/mcts/tree_search/mcts_ctree.py | 2 +- lzero/model/common.py | 112 ++++++++++++- lzero/model/unizero_model.py | 51 ++++-- lzero/model/unizero_world_models/tokenizer.py | 155 ++++++++++-------- .../model/unizero_world_models/world_model.py | 4 +- lzero/policy/unizero.py | 33 ++-- zoo/jericho/configs/jericho_unizero_config.py | 10 +- .../configs/jericho_unizero_ddp_config.py | 24 ++- .../configs/jericho_unizero_segment_config.py | 14 +- 11 files changed, 320 insertions(+), 109 deletions(-) diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index e107beae6..2569f6bd5 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -111,6 +111,28 @@ def initialize_zeros_batch(observation_shape: Union[int, List[int], Tuple[int]], return torch.zeros(shape).to(device) +def initialize_pad_batch(observation_shape: Union[int, List[int], Tuple[int]], batch_size: int, device: str, pad_token_id: int) -> torch.Tensor: + """ + Overview: + Initialize a tensor filled with `pad_token_id` for batch observations. + This is typically used to initialize input_ids with padding tokens. + Arguments: + - observation_shape (:obj:`Union[int, List[int], Tuple[int]]`): The shape of the observation tensor. + - batch_size (:obj:`int`): The batch size. + - device (:obj:`str`): The device to store the tensor. + - pad_token_id (:obj:`int`): The token ID used for padding. + Returns: + - padded_tensor (:obj:`torch.Tensor`): The tensor filled with pad_token_id. + """ + if isinstance(observation_shape, (list, tuple)): + shape = [batch_size, *observation_shape] + elif isinstance(observation_shape, int): + shape = [batch_size, observation_shape] + else: + raise TypeError(f"observation_shape must be int, list, or tuple, but got {type(observation_shape).__name__}") + + return torch.full(shape, fill_value=pad_token_id, dtype=torch.long, device=device) + def random_collect( policy_cfg: 'EasyDict', # noqa policy: 'Policy', # noqa diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index 61ba751a9..df09cebc8 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -156,7 +156,7 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: # For some environments (e.g., Jericho), the action space size may be different. # To ensure we can always unroll `num_unroll_steps` steps starting from the sampled position (without exceeding segment length), # we avoid sampling from the last `num_unroll_steps` steps of the game segment. 
- if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps: + if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps: pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps, 1).item() else: # For environments with a fixed action space (e.g., Atari), diff --git a/lzero/mcts/tree_search/mcts_ctree.py b/lzero/mcts/tree_search/mcts_ctree.py index 118f614d7..969235ebb 100644 --- a/lzero/mcts/tree_search/mcts_ctree.py +++ b/lzero/mcts/tree_search/mcts_ctree.py @@ -184,7 +184,7 @@ def search( current_latent_state_index, discount_factor, reward_batch, value_batch, policy_logits_batch, min_max_stats_lst, results, virtual_to_play_batch ) - + return first_action_latent_map diff --git a/lzero/model/common.py b/lzero/model/common.py index 795eb72a3..731891023 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -21,6 +21,8 @@ from ditk import logging from ding.utils import set_pkg_seed, get_rank, get_world_size import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + def MLP_V2( in_channels: int, @@ -361,6 +363,115 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return output +class QwenNetwork(nn.Module): + def __init__(self, + model_path: str = 'Qwen/Qwen3-1.7B', + embedding_size: int = 768, + final_norm_option_in_encoder: str = "layernorm", + group_size: int = 8, + tokenizer=None): + super().__init__() + + logging.info(f"Loading Qwen model from: {model_path}") + + local_rank = get_rank() + if local_rank == 0: + self.pretrained_model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype="auto", + device_map={"": local_rank}, + attn_implementation="flash_attention_2" + ) + if get_world_size() > 1: + torch.distributed.barrier() + if local_rank != 0: + self.pretrained_model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype="auto", + device_map={"": local_rank}, + attn_implementation="flash_attention_2" + ) + + for p in self.pretrained_model.parameters(): + p.requires_grad = False + + if tokenizer is None: + if local_rank == 0: + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + if get_world_size() > 1: + torch.distributed.barrier() + if local_rank != 0: + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + else: + self.tokenizer = tokenizer + + qwen_hidden_size = self.pretrained_model.config.hidden_size + + self.embedding_head = nn.Sequential( + nn.Linear(qwen_hidden_size, embedding_size), + self._create_norm_layer(final_norm_option_in_encoder, embedding_size, group_size) + ) + + def _create_norm_layer(self, norm_option, embedding_size, group_size): + if norm_option.lower() == "simnorm": + return SimNorm(simnorm_dim=group_size) + elif norm_option.lower() == "layernorm": + return nn.LayerNorm(embedding_size) + else: + raise NotImplementedError(f"Normalization type '{norm_option}' is not implemented.") + + def encode(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor: + """ + Overview: + Encode the input tensor `x` to a latent state. + Arguments: + - x (:obj:`torch.Tensor`): Input tensor of shape (B, C_in, W, H). + Returns: + - latent_state (:obj:`torch.Tensor`): Encoded latent state of shape (B, embedding_dim). 
+ """ + pad_id = self.tokenizer.pad_token_id + attention_mask = (x != pad_id).long().to(x.device) + context = {'input_ids': x.long(), 'attention_mask': attention_mask} + no_grad = True + if no_grad: + with torch.no_grad(): + outputs = self.pretrained_model(**context, output_hidden_states=True, return_dict=True) + else: + outputs = self.pretrained_model(**context, output_hidden_states=True, return_dict=True) + last_hidden = outputs.hidden_states[-1] + + B, L, H = last_hidden.size() + lengths = attention_mask.sum(dim=1) # [B] + positions = torch.clamp(lengths - 1, min=0) # [B] + batch_idx = torch.arange(B, device=last_hidden.device) + + selected = last_hidden[batch_idx, positions] # [B, H] + + latent = self.embedding_head(selected.to(self.embedding_head[0].weight.dtype)) + return latent + + def decode(self, embeddings: torch.Tensor, max_length: int = 512) -> str: + """ + Decodes embeddings into text via the decoder network. + """ + embeddings_detached = embeddings.detach() + self.pretrained_model.eval() + + # Directly generate using provided embeddings + with torch.no_grad(): + param = next(self.pretrained_model.parameters()) + embeddings = embeddings_detached.to(device=param.device, dtype=param.dtype) + gen_ids = self.pretrained_model.generate( + inputs_embeds=embeddings, + max_length=max_length + ) + texts = self.tokenizer.batch_decode(gen_ids, skip_special_tokens=True) + self.pretrained_model.train() + return texts[0] if len(texts) == 1 else texts + + def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor: + return self.encode(x, no_grad=no_grad) + class HFLanguageRepresentationNetwork(nn.Module): def __init__(self, @@ -542,7 +653,6 @@ def __init__( else: raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}") - def forward(self, x: torch.Tensor) -> torch.Tensor: """ Shapes: diff --git a/lzero/model/unizero_model.py b/lzero/model/unizero_model.py index 4ea6500f3..d2d1b88b6 100644 --- a/lzero/model/unizero_model.py +++ b/lzero/model/unizero_model.py @@ -8,7 +8,7 @@ from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \ VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \ - HFLanguageRepresentationNetwork + HFLanguageRepresentationNetwork, QwenNetwork from .unizero_world_models.tokenizer import Tokenizer from .unizero_world_models.world_model import WorldModel from ding.utils import ENV_REGISTRY, set_pkg_seed, get_rank, get_world_size @@ -96,21 +96,37 @@ def __init__( print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') print('==' * 20) elif world_model_cfg.obs_type == 'text': - self.representation_network = HFLanguageRepresentationNetwork(model_path=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder) - # print(self.representation_network.model.encoder.layer[0].attention.output.LayerNorm.weight) - - if self.rank == 0: - self.decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small") - self.decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small") - if self.world_size > 1: - # Wait until rank 0 finishes loading the tokenizer - torch.distributed.barrier() - if self.rank != 0: - self.decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small") - self.decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small") - - projection = 
[self.representation_network.pretrained_model.config.hidden_size, self.decoder_network.config.d_model] - self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=self.decoder_network, decoder_network_tokenizer=self.decoder_network_tokenizer, with_lpips=False, projection=projection) + if kwargs['encoder_option'] == 'legacy': + self.representation_network = HFLanguageRepresentationNetwork(model_path=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder) + if world_model_cfg.decode_loss_mode is None or world_model_cfg.decode_loss_mode.lower() == 'none': + self.decoder_network = None + self.decoder_network_tokenizer = None + projection = None + else: + if self.rank == 0: + self.decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small") + self.decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small") + if self.world_size > 1: + # Wait until rank 0 finishes loading the tokenizer + torch.distributed.barrier() + if self.rank != 0: + self.decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small") + self.decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small") + projection = [world_model_cfg.embed_dim, self.decoder_network.config.d_model] + elif kwargs['encoder_option'] == 'qwen': + self.representation_network = QwenNetwork(model_path=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder) + if world_model_cfg.decode_loss_mode is None or world_model_cfg.decode_loss_mode.lower() == 'none': + self.decoder_network = None + self.decoder_network_tokenizer = None + projection = None + else: + projection = [world_model_cfg.embed_dim, self.representation_network.pretrained_model.config.hidden_size] + self.decoder_network = self.representation_network + self.decoder_network_tokenizer = None + else: + raise ValueError(f"Unsupported encoder option: {kwargs['encoder_option']}") + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=self.decoder_network, decoder_network_tokenizer=self.decoder_network_tokenizer, + with_lpips=False, projection=projection, encoder_option=kwargs['encoder_option']) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') print('==' * 20) @@ -216,6 +232,9 @@ def initial_inference(self, obs_batch: torch.Tensor, action_batch: Optional[torc reward = logits_rewards policy_logits = logits_policy.squeeze(1) value = logits_value.squeeze(1) + if torch.isnan(value).any() or torch.isnan(latent_state).any() or torch.isnan(policy_logits).any(): + print(f'NaN detected in value, latent_state, or policy_logits at start_pos {start_pos}') + print(f'value: {value}, latent_state: {latent_state}, policy_logits: {policy_logits}') return MZNetworkOutput( value=value, diff --git a/lzero/model/unizero_world_models/tokenizer.py b/lzero/model/unizero_world_models/tokenizer.py index bbc4e6c87..7c476e6f4 100644 --- a/lzero/model/unizero_world_models/tokenizer.py +++ b/lzero/model/unizero_world_models/tokenizer.py @@ -39,7 +39,7 @@ class Tokenizer(nn.Module): Can operate on visual or textual data, supporting optional LPIPS perceptual loss. It optionally includes a linear projection layer and can be paired with a decoder tokenizer. 
""" - def __init__(self, encoder=None, decoder_network=None, decoder_network_tokenizer=None, with_lpips: bool = False, projection: list = None) -> None: + def __init__(self, encoder=None, decoder_network=None, decoder_network_tokenizer=None, with_lpips: bool = False, projection: list = None, encoder_option='legacy') -> None: """Initialize the Tokenizer. Arguments: @@ -49,6 +49,7 @@ def __init__(self, encoder=None, decoder_network=None, decoder_network_tokenizer with_lpips (bool, optional): If True, enable perceptual loss computation via LPIPS. Defaults to False. projection (list[int], optional): If provided, defines a linear projection layer from projection[0] → projection[1]. If None, an identity layer is used. + encoder_option (str, optional): Option to specify the encoder type, e.g., 'legacy' for T5 decoder or 'qwen' for Qwen decoder. Defaults to 'legacy'. """ super().__init__() if with_lpips: @@ -59,27 +60,14 @@ def __init__(self, encoder=None, decoder_network=None, decoder_network_tokenizer self.encoder = encoder self.decoder_network = decoder_network - self.decoder_network_tokenizer = decoder_network_tokenizer + self.decoder_network_tokenizer = decoder_network_tokenizer + self.encoder_option = encoder_option if projection is None: self.projection_layer = nn.Identity() else: self.projection_layer = nn.Linear(projection[0], projection[1]) - - def decode_to_plain_text(self, x) -> str: - """ - Decode the input tensor to plain text. - - Arguments: - x (torch.Tensor): Input tensor of shape (B, ...). - - Returns: - str: Decoded plain text. - """ - # Convert the input tensor to a numpy array and decode it - return self.encoder.tokenizer.batch_decode(x, skip_special_tokens=True) - def encode_to_obs_embeddings(self, x: torch.Tensor) -> torch.Tensor: """ Encode observations to embeddings. @@ -146,34 +134,69 @@ def decode_to_reconstruction_outputs(self, embeddings: torch.Tensor, target_ids: embeddings = embeddings.reshape(B*T,1,E) target_ids = target_ids.reshape(B*T, -1) - # Instead of using raw target_ids, convert them to plain text and re-tokenize using the decoder's tokenizer. - # This guarantees alignment with the decoder's vocabulary, special tokens, and tokenization rules. - text_list = self.decode_to_plain_text(target_ids) - t5_target_ids = self.decoder_network_tokenizer(text_list, - padding="max_length", - truncation=True, - max_length=512, - return_tensors="pt") - labels = t5_target_ids.input_ids - labels[labels == self.decoder_network_tokenizer.pad_token_id] = -100 - - embeddings = self.projection_layer(embeddings) # (B', 1, E) -> (B', 1, E'), B' = B*T - encoder_outputs_tuple = BaseModelOutput(last_hidden_state=embeddings) - encoder_attention_mask = torch.ones( - embeddings.size(0), embeddings.size(1), - device=embeddings.device, dtype=torch.long - ) - - labels = labels.to(embeddings.device) - - outputs = self.decoder_network(encoder_outputs=encoder_outputs_tuple, - attention_mask=encoder_attention_mask, - labels=labels, - return_dict=True) - - return outputs + if self.encoder_option == 'legacy': # T5 decoder + # Instead of using raw target_ids, convert them to plain text and re-tokenize using the decoder's tokenizer. + # This guarantees alignment with the decoder's vocabulary, special tokens, and tokenization rules. 
+ text_list = self.encoder.tokenizer.batch_decode(target_ids, skip_special_tokens=True) + t5_target_ids = self.decoder_network_tokenizer(text_list, + padding="max_length", + truncation=True, + max_length=512, + return_tensors="pt") + labels = t5_target_ids.input_ids + labels[labels == self.decoder_network_tokenizer.pad_token_id] = -100 + + embeddings = self.projection_layer(embeddings) # (B', 1, E) -> (B', 1, E'), B' = B*T + encoder_outputs_tuple = BaseModelOutput(last_hidden_state=embeddings) + encoder_attention_mask = torch.ones( + embeddings.size(0), embeddings.size(1), + device=embeddings.device, dtype=torch.long + ) + + labels = labels.to(embeddings.device) + + outputs = self.decoder_network(encoder_outputs=encoder_outputs_tuple, + attention_mask=encoder_attention_mask, + labels=labels, + return_dict=True) + return outputs + + elif self.encoder_option == 'qwen': + hidden = self.projection_layer(embeddings) + lm = self.decoder_network.pretrained_model + param = next(lm.parameters()) + + try: + input_embedding_layer = lm.get_input_embeddings() + except: + raise ValueError('Error... Could not retrieve input embedding layer from the decoder network.') + + target_embeds = input_embedding_layer(target_ids) + inputs_embeds = torch.cat([hidden, target_embeds.detach()], dim=1) + + inputs_embeds = inputs_embeds.to(device=param.device, dtype=param.dtype) + + prompt_attention_mask = torch.ones(hidden.size(0), 1, device=param.device, dtype=torch.long) + target_attention_mask = (target_ids != self.decoder_network.tokenizer.pad_token_id).to(device=param.device, dtype=torch.long) + attention_mask = torch.cat([prompt_attention_mask, target_attention_mask], dim=1) + + prompt_labels = torch.full((hidden.size(0), 1), -100, device=param.device, dtype=torch.long) + + labels = target_ids.clone().to(param.device) + labels[labels == self.decoder_network.tokenizer.pad_token_id] = -100 + + final_labels = torch.cat([prompt_labels, labels], dim=1) + + outputs = lm( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + labels=final_labels, + return_dict=True + ) + + return outputs - def decode_to_plain_text_for_decoder( + def decode_to_plain_text( self, embeddings: torch.Tensor, max_length: int = 512 ) -> List[List[int]]: @@ -209,29 +232,31 @@ def decode_to_plain_text_for_decoder( embeddings = embeddings.unsqueeze(1) embeddings = self.projection_layer(embeddings) - - encoder_outputs_tuple = BaseModelOutput(last_hidden_state=embeddings) - encoder_attention_mask = torch.ones( - embeddings.size(0), embeddings.size(1), - device=device, dtype=torch.long - ) - - # Use the decoder's generate() method to autoregressively decode text from the input embeddings. - # The projected embeddings serve as encoder outputs in a typical encoder-decoder architecture, - # where the decoder attends to them via cross-attention at each step until max_length or EOS is reached. - generated_t5_ids = self.decoder_network.generate( - encoder_outputs=encoder_outputs_tuple, - attention_mask=encoder_attention_mask, - max_length=max_length - ) - - # Convert the generated output to a list of strings on CPU, skipping special tokens. 
- generated_text = self.decoder_network_tokenizer.batch_decode( - generated_t5_ids, skip_special_tokens=True) - - assert len(generated_text) == 1, f"Expected 1 generated text, got {len(generated_text)}" + if self.encoder_option == 'legacy': # T5 decoder + encoder_outputs_tuple = BaseModelOutput(last_hidden_state=embeddings) + encoder_attention_mask = torch.ones( + embeddings.size(0), embeddings.size(1), + device=device, dtype=torch.long + ) + + # Use the decoder's generate() method to autoregressively decode text from the input embeddings. + # The projected embeddings serve as encoder outputs in a typical encoder-decoder architecture, + # where the decoder attends to them via cross-attention at each step until max_length or EOS is reached. + generated_t5_ids = self.decoder_network.generate( + encoder_outputs=encoder_outputs_tuple, + attention_mask=encoder_attention_mask, + max_length=max_length + ) + + # Convert the generated output to a list of strings on CPU, skipping special tokens. + generated_text = self.decoder_network_tokenizer.batch_decode( + generated_t5_ids, skip_special_tokens=True) + + assert len(generated_text) == 1, f"Expected 1 generated text, got {len(generated_text)}" + return generated_text[0] - return generated_text[0] + elif self.encoder_option == 'qwen': + return self.decoder_network.decode(embeddings=embeddings, max_length=max_length) @staticmethod def reconstruction_loss(original_images: torch.Tensor, reconstructed_images: torch.Tensor) -> torch.Tensor: diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py index e8df2a6e0..b58f5d9fb 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -100,7 +100,7 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: skip_modules = set() if hasattr(self.tokenizer.encoder, 'pretrained_model'): skip_modules.update(self.tokenizer.encoder.pretrained_model.modules()) - if hasattr(self.tokenizer, 'decoder_network'): + if hasattr(self.tokenizer, 'decoder_network') and self.tokenizer.decoder_network is not None: skip_modules.update(self.tokenizer.decoder_network.modules()) def custom_init(module): @@ -1372,7 +1372,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar if decode_loss_mode == "after_backbone": next_latent_state = outputs.logits_observations[:, :-1, :] next_target_ids = batch['observations'][:, 1:, :] - + latent_recon_loss = self.tokenizer.decode_to_reconstruction_outputs( embeddings=next_latent_state, target_ids=next_target_ids, diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py index 9ff2c1333..7ac2b070d 100644 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -8,7 +8,7 @@ from ding.model import model_wrap from ding.utils import POLICY_REGISTRY -from lzero.entry.utils import initialize_zeros_batch +from lzero.entry.utils import initialize_zeros_batch, initialize_pad_batch from lzero.mcts import UniZeroMCTSCtree as MCTSCtree from lzero.model import ImageTransforms from lzero.policy import scalar_transform, InverseScalarTransform, phi_transform, \ @@ -348,6 +348,9 @@ def _init_learn(self) -> None: self.l2_norm_after = 0. self.grad_norm_before = 0. self.grad_norm_after = 0. 
+ + encoder_tokenizer = getattr(self._model.tokenizer.encoder, 'tokenizer', None) + self.pad_token_id = encoder_tokenizer.pad_token_id if encoder_tokenizer is not None else 0 if self._cfg.use_wandb: # TODO: add the model to wandb @@ -592,7 +595,9 @@ def _init_collect(self) -> None: self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) self.last_batch_action = [-1 for i in range(self.collector_env_num)] elif self._cfg.model.model_type == 'mlp': - self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) + self.last_batch_obs = torch.full( + [self.collector_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id, + ).to(self._cfg.device) self.last_batch_action = [-1 for i in range(self.collector_env_num)] # @profile @@ -696,7 +701,7 @@ def _forward_collect( if self._cfg.model.world_model_cfg.obs_type == 'text': # Output the plain text content decoded by the decoder from the next latent state - predicted_next = self._collect_model.tokenizer.decode_to_plain_text_for_decoder(embeddings=next_latent_state, max_length=256) + predicted_next = self._collect_model.tokenizer.decode_to_plain_text(embeddings=next_latent_state, max_length=256) else: predicted_next = None @@ -745,11 +750,13 @@ def _init_eval(self) -> None: self.evaluator_env_num = self._cfg.evaluator_env_num if self._cfg.model.model_type == 'conv': - self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) - self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] elif self._cfg.model.model_type == 'mlp': - self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) - self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + self.last_batch_obs = torch.full( + [self.collector_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id, + ).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [-1], ready_env_id: np.array = None, timestep: List = [0]) -> Dict: @@ -827,7 +834,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ if self._cfg.model.world_model_cfg.obs_type == 'text': # Output the plain text content decoded by the decoder from the next latent state - predicted_next = self._eval_model.tokenizer.decode_to_plain_text_for_decoder(embeddings=next_latent_state, max_length=256) + predicted_next = self._eval_model.tokenizer.decode_to_plain_text(embeddings=next_latent_state, max_length=256) else: predicted_next = None @@ -861,10 +868,11 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. 
""" if reset_init_data: - self.last_batch_obs = initialize_zeros_batch( + self.last_batch_obs = initialize_pad_batch( self._cfg.model.observation_shape, self._cfg.collector_env_num, - self._cfg.device + self._cfg.device, + pad_token_id=self.pad_token_id ) self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] @@ -905,10 +913,11 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_ - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. """ if reset_init_data: - self.last_batch_obs = initialize_zeros_batch( + self.last_batch_obs = initialize_pad_batch( self._cfg.model.observation_shape, self._cfg.evaluator_env_num, - self._cfg.device + self._cfg.device, + pad_token_id=self.pad_token_id ) self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] diff --git a/zoo/jericho/configs/jericho_unizero_config.py b/zoo/jericho/configs/jericho_unizero_config.py index 30fbe5a7e..cc66e045b 100644 --- a/zoo/jericho/configs/jericho_unizero_config.py +++ b/zoo/jericho/configs/jericho_unizero_config.py @@ -59,7 +59,14 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e reanalyze_partition: float = 0.75 # Model name or path - configurable according to the predefined model paths or names - model_name: str = 'BAAI/bge-base-en-v1.5' + encoder_option = 'legacy' # ['qwen', 'legacy']. Legacy uses the bge encoder + + if encoder_option == 'qwen': + model_name: str = 'Qwen/Qwen3-0.6B' + elif encoder_option == 'legacy': + model_name: str = 'BAAI/bge-base-en-v1.5' + else: + raise ValueError(f"Unsupported encoder option: {encoder_option}") # ------------------------------------------------------------------ # TODO: Debug configuration - override some parameters for debugging purposes @@ -104,6 +111,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e model=dict( observation_shape=512, action_space_size=action_space_size, + encoder_option=encoder_option, encoder_url=model_name, model_type="mlp", continuous_action_space=False, diff --git a/zoo/jericho/configs/jericho_unizero_ddp_config.py b/zoo/jericho/configs/jericho_unizero_ddp_config.py index 5cb67a8f8..e78196996 100644 --- a/zoo/jericho/configs/jericho_unizero_ddp_config.py +++ b/zoo/jericho/configs/jericho_unizero_ddp_config.py @@ -16,10 +16,10 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e Returns: None """ - gpu_num = 2 + gpu_num = 4 collector_env_num: int = 4 # Number of collector environments n_episode = int(collector_env_num*gpu_num) - batch_size = int(8*gpu_num) + batch_size = int(64*gpu_num) # TODO # batch_size = batch_size * 2 @@ -35,9 +35,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e 'acorncourt.z5': (45, 50), 'zork1.z5': (55, 500), } - env_id = 'detective.z5' - # Set action_space_size and max_steps based on env_id action_space_size, max_steps = env_configurations.get(env_id, (10, 50)) # Default values if env_id not found @@ -64,7 +62,14 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e reanalyze_partition: float = 0.75 # Model name or path - configurable according to the predefined model paths or names - model_name: str = 'BAAI/bge-base-en-v1.5' + encoder_option = 'legacy' # ['qwen', 'legacy']. 
Legacy uses the bge encoder + + if encoder_option == 'qwen': + model_name: str = 'Qwen/Qwen3-0.6B' + elif encoder_option == 'legacy': + model_name: str = 'BAAI/bge-base-en-v1.5' + else: + raise ValueError(f"Unsupported encoder option: {encoder_option}") # ------------------------------------------------------------------ # TODO: Debug configuration - override some parameters for debugging purposes @@ -105,11 +110,12 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e ), ), ), - accumulation_steps=4, # TODO: Accumulated gradient steps (currently default) + accumulation_steps=1, # TODO: Accumulated gradient steps (currently default) model=dict( observation_shape=512, action_space_size=action_space_size, encoder_url=model_name, + encoder_option=encoder_option, model_type="mlp", continuous_action_space=False, world_model_cfg=dict( @@ -129,12 +135,12 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e embed_dim=embed_dim, obs_type="text", # TODO: Modify as needed. env_num=max(collector_env_num, evaluator_env_num), - decode_loss_mode='after_backbone', # Controls where to compute reconstruction loss: after_backbone, before_backbone, or None. + decode_loss_mode='None', # Controls where to compute reconstruction loss: after_backbone, before_backbone, or None. latent_recon_loss_weight=0.1 # TODO: decoder loss weight ), ), # TODO - update_per_collect=int(collector_env_num*max_steps*replay_ratio*4 ), # Important for DDP + update_per_collect=int(collector_env_num*max_steps*replay_ratio*1), # Important for DDP action_type="varied_action_space", model_path=None, num_unroll_steps=num_unroll_steps, @@ -193,7 +199,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e main_config = lz_to_ddp_config(main_config) # Construct experiment name containing key parameters main_config.exp_name = ( - f"data_lz/data_unizero_jericho/bge-base-en-v1.5/{env_id}/uz_ddp-{gpu_num}gpu_cen{collector_env_num}_rr{replay_ratio}_ftemp025_{env_id[:8]}_ms{max_steps}_ass-{action_space_size}_" + f"data_lz/data_unizero_jericho/{model_name}/{env_id}/uz_ddp-{gpu_num}gpu_cen{collector_env_num}_rr{replay_ratio}_ftemp025_{env_id[:8]}_ms{max_steps}_ass-{action_space_size}_" f"nlayer{num_layers}_embed{embed_dim}_Htrain{num_unroll_steps}-" f"Hinfer{infer_context_length}_bs{batch_size}_seed{seed}" ) diff --git a/zoo/jericho/configs/jericho_unizero_segment_config.py b/zoo/jericho/configs/jericho_unizero_segment_config.py index 6d7c4768b..a44b9cf75 100644 --- a/zoo/jericho/configs/jericho_unizero_segment_config.py +++ b/zoo/jericho/configs/jericho_unizero_segment_config.py @@ -22,7 +22,16 @@ def main(env_id: str = 'detective.z5', seed: int = 0) -> None: # Frequently changed configurations (user-specified) # ============================================================== # Model name or path - configurable according to the predefined model paths or names - model_name: str = 'BAAI/bge-base-en-v1.5' + encoder_option = 'legacy' # ['qwen', 'legacy']. 
Legacy uses the bge encoder + + if encoder_option == 'qwen': + model_name: str = 'Qwen/Qwen3-0.6B' + elif encoder_option == 'legacy': + model_name: str = 'BAAI/bge-base-en-v1.5' + else: + raise ValueError(f"Unsupported encoder option: {encoder_option}") + + collector_env_num = 8 game_segment_length = 20 evaluator_env_num = 5 @@ -86,6 +95,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0) -> None: model=dict( observation_shape=512, action_space_size=action_space_size, + encoder_option=encoder_option, encoder_url=model_name, model_type="mlp", world_model_cfg=dict( @@ -104,6 +114,8 @@ def main(env_id: str = 'detective.z5', seed: int = 0) -> None: embed_dim=embed_dim, obs_type="text", env_num=max(collector_env_num, evaluator_env_num), + decode_loss_mode='None', # Controls where to compute reconstruction loss: after_backbone, before_backbone, or None. + latent_recon_loss_weight=0.1 ), ), action_type="varied_action_space", From d189a737438ae53f666a46d9ba2e2d5c28e90785 Mon Sep 17 00:00:00 2001 From: xiongjyu Date: Fri, 22 Aug 2025 09:18:01 +0000 Subject: [PATCH 4/5] fixed a few bugs and standardized the format --- lzero/mcts/buffer/game_buffer.py | 7 +++++-- lzero/model/common.py | 13 +++++++------ lzero/model/unizero_model.py | 3 --- lzero/model/unizero_world_models/tokenizer.py | 14 +++++++++++--- lzero/model/unizero_world_models/world_model.py | 4 ---- lzero/policy/unizero.py | 4 ++-- zoo/jericho/configs/jericho_unizero_ddp_config.py | 7 ++++--- .../envs/{test_qwen.py => test_qwen_prior.py} | 0 8 files changed, 29 insertions(+), 23 deletions(-) rename zoo/jericho/envs/{test_qwen.py => test_qwen_prior.py} (100%) diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index df09cebc8..f7dfb040c 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -157,18 +157,21 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: # To ensure we can always unroll `num_unroll_steps` steps starting from the sampled position (without exceeding segment length), # we avoid sampling from the last `num_unroll_steps` steps of the game segment. if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps: - pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps, 1).item() + pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps, 1).item() + if pos_in_game_segment >= len(game_segment.action_segment) - 1: + pos_in_game_segment = np.random.choice(len(game_segment.action_segment) - 1, 1).item() else: # For environments with a fixed action space (e.g., Atari), # we can safely sample from the entire game segment range. 
if pos_in_game_segment >= self._cfg.game_segment_length: pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item() + if pos_in_game_segment >= len(game_segment.action_segment) - 1: + pos_in_game_segment = np.random.choice(len(game_segment.action_segment) - 1, 1).item() pos_in_game_segment_list.append(pos_in_game_segment) make_time = [time.time() for _ in range(len(batch_index_list))] - orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) return orig_data diff --git a/lzero/model/common.py b/lzero/model/common.py index 731891023..7b1bbeeae 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -15,13 +15,13 @@ import torch.nn as nn import torch.nn.functional as F import torch.nn.init as init +from transformers import AutoModelForCausalLM, AutoTokenizer from ding.torch_utils import MLP, ResBlock from ding.torch_utils.network.normalization import build_normalization from ding.utils import SequenceType from ditk import logging from ding.utils import set_pkg_seed, get_rank, get_world_size -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer + def MLP_V2( @@ -423,16 +423,17 @@ def _create_norm_layer(self, norm_option, embedding_size, group_size): def encode(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor: """ Overview: - Encode the input tensor `x` to a latent state. + Encode the input token sequence `x` into a latent representation + using a pretrained language model backbone followed by a projection head. Arguments: - - x (:obj:`torch.Tensor`): Input tensor of shape (B, C_in, W, H). + - x (:obj:`torch.Tensor`): Input token ids of shape (B, L) + - no_grad (:obj:`bool`, optional, default=True): If True, encoding is performed under `torch.no_grad()` to save memory and computation (no gradient tracking). Returns: - - latent_state (:obj:`torch.Tensor`): Encoded latent state of shape (B, embedding_dim). + - latent (:obj:`torch.Tensor`): Encoded latent state of shape (B, D). 
""" pad_id = self.tokenizer.pad_token_id attention_mask = (x != pad_id).long().to(x.device) context = {'input_ids': x.long(), 'attention_mask': attention_mask} - no_grad = True if no_grad: with torch.no_grad(): outputs = self.pretrained_model(**context, output_hidden_states=True, return_dict=True) diff --git a/lzero/model/unizero_model.py b/lzero/model/unizero_model.py index d2d1b88b6..7a9ec84d6 100644 --- a/lzero/model/unizero_model.py +++ b/lzero/model/unizero_model.py @@ -232,9 +232,6 @@ def initial_inference(self, obs_batch: torch.Tensor, action_batch: Optional[torc reward = logits_rewards policy_logits = logits_policy.squeeze(1) value = logits_value.squeeze(1) - if torch.isnan(value).any() or torch.isnan(latent_state).any() or torch.isnan(policy_logits).any(): - print(f'NaN detected in value, latent_state, or policy_logits at start_pos {start_pos}') - print(f'value: {value}, latent_state: {latent_state}, policy_logits: {policy_logits}') return MZNetworkOutput( value=value, diff --git a/lzero/model/unizero_world_models/tokenizer.py b/lzero/model/unizero_world_models/tokenizer.py index 7c476e6f4..e5e18461f 100644 --- a/lzero/model/unizero_world_models/tokenizer.py +++ b/lzero/model/unizero_world_models/tokenizer.py @@ -164,24 +164,32 @@ def decode_to_reconstruction_outputs(self, embeddings: torch.Tensor, target_ids: elif self.encoder_option == 'qwen': hidden = self.projection_layer(embeddings) lm = self.decoder_network.pretrained_model - param = next(lm.parameters()) + # Get a reference parameter for device/dtype info + param = next(lm.parameters()) try: + # Retrieve the input embedding layer of the language model input_embedding_layer = lm.get_input_embeddings() except: raise ValueError('Error... Could not retrieve input embedding layer from the decoder network.') - + + # Convert target token IDs into embeddings using the LM's input embedding layer target_embeds = input_embedding_layer(target_ids) + + # Concatenate the projected hidden embeddings (prompt) with target embeddings + # hidden: (B, 1, D), target_embeds: (B, L, D) → inputs_embeds: (B, 1+L, D) inputs_embeds = torch.cat([hidden, target_embeds.detach()], dim=1) inputs_embeds = inputs_embeds.to(device=param.device, dtype=param.dtype) prompt_attention_mask = torch.ones(hidden.size(0), 1, device=param.device, dtype=torch.long) target_attention_mask = (target_ids != self.decoder_network.tokenizer.pad_token_id).to(device=param.device, dtype=torch.long) + # Concatenate prompt mask and target mask along sequence length attention_mask = torch.cat([prompt_attention_mask, target_attention_mask], dim=1) - + # Construct labels: for the prompt part, use -100 (ignored by loss function) prompt_labels = torch.full((hidden.size(0), 1), -100, device=param.device, dtype=torch.long) + # Copy target token IDs as labels, masking pad positions with -100 labels = target_ids.clone().to(param.device) labels[labels == self.decoder_network.tokenizer.pad_token_id] = -100 diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py index b58f5d9fb..7f1a0f68e 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -1506,9 +1506,6 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # Compute discount coefficients for each timestep discounts = self.gamma ** timesteps - if batch['mask_padding'].sum() == 0: - assert False, "mask_padding is all zeros" - # Group losses into first step, middle step, and last step 
first_step_losses = {} middle_step_losses = {} @@ -1547,7 +1544,6 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # Discount reconstruction loss and perceptual loss discounted_latent_recon_loss = latent_recon_loss discounted_perceptual_loss = perceptual_loss - # Calculate overall discounted loss discounted_loss_obs = (loss_obs.view(-1, batch['actions'].shape[1] - 1) * discounts[1:]).sum()/ batch['mask_padding'][:,1:].sum() discounted_loss_rewards = (loss_rewards.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py index 7ac2b070d..fdcff20fa 100644 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -699,7 +699,7 @@ def _forward_collect( next_latent_state = next_latent_state_with_env[i][action] - if self._cfg.model.world_model_cfg.obs_type == 'text': + if self._cfg.model.world_model_cfg.obs_type == 'text' and self._cfg.model.world_model_cfg.decode_loss_mode is not None and self._cfg.model.world_model_cfg.decode_loss_mode.lower() != 'none': # Output the plain text content decoded by the decoder from the next latent state predicted_next = self._collect_model.tokenizer.decode_to_plain_text(embeddings=next_latent_state, max_length=256) else: @@ -832,7 +832,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ # Predict the next latent state based on the selected action and policy next_latent_state = next_latent_state_with_env[i][action] - if self._cfg.model.world_model_cfg.obs_type == 'text': + if self._cfg.model.world_model_cfg.obs_type == 'text' and self._cfg.model.world_model_cfg.decode_loss_mode is not None and self._cfg.model.world_model_cfg.decode_loss_mode.lower() != 'none': # Output the plain text content decoded by the decoder from the next latent state predicted_next = self._eval_model.tokenizer.decode_to_plain_text(embeddings=next_latent_state, max_length=256) else: diff --git a/zoo/jericho/configs/jericho_unizero_ddp_config.py b/zoo/jericho/configs/jericho_unizero_ddp_config.py index e78196996..4fefd717d 100644 --- a/zoo/jericho/configs/jericho_unizero_ddp_config.py +++ b/zoo/jericho/configs/jericho_unizero_ddp_config.py @@ -19,7 +19,8 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e gpu_num = 4 collector_env_num: int = 4 # Number of collector environments n_episode = int(collector_env_num*gpu_num) - batch_size = int(64*gpu_num) + batch_size = int(1*gpu_num) + accumulation_steps=1 # TODO # batch_size = batch_size * 2 @@ -110,7 +111,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e ), ), ), - accumulation_steps=1, # TODO: Accumulated gradient steps (currently default) + accumulation_steps=accumulation_steps, # TODO: Accumulated gradient steps (currently default) model=dict( observation_shape=512, action_space_size=action_space_size, @@ -140,7 +141,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e ), ), # TODO - update_per_collect=int(collector_env_num*max_steps*replay_ratio*1), # Important for DDP + update_per_collect=int(collector_env_num*max_steps*replay_ratio*accumulation_steps), # Important for DDP action_type="varied_action_space", model_path=None, num_unroll_steps=num_unroll_steps, diff --git a/zoo/jericho/envs/test_qwen.py b/zoo/jericho/envs/test_qwen_prior.py similarity index 100% rename from zoo/jericho/envs/test_qwen.py rename to zoo/jericho/envs/test_qwen_prior.py From 
cd811ac9f0254bfb9ee00978ef452e0491f1f823 Mon Sep 17 00:00:00 2001 From: xiongjyu Date: Mon, 25 Aug 2025 12:56:30 +0000 Subject: [PATCH 5/5] standardize the format again --- lzero/entry/utils.py | 16 +- lzero/policy/unizero.py | 4 + zoo/jericho/envs/test_qwen_prior.py | 233 ---------------------------- 3 files changed, 16 insertions(+), 237 deletions(-) delete mode 100644 zoo/jericho/envs/test_qwen_prior.py diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index 2569f6bd5..702652a83 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -111,18 +111,26 @@ def initialize_zeros_batch(observation_shape: Union[int, List[int], Tuple[int]], return torch.zeros(shape).to(device) -def initialize_pad_batch(observation_shape: Union[int, List[int], Tuple[int]], batch_size: int, device: str, pad_token_id: int) -> torch.Tensor: +def initialize_pad_batch(observation_shape: Union[int, List[int], Tuple[int]], batch_size: int, device: str, pad_token_id: int = 0) -> torch.Tensor: """ Overview: Initialize a tensor filled with `pad_token_id` for batch observations. - This is typically used to initialize input_ids with padding tokens. + This function is designed to be flexible and can handle both textual + and non-textual observations: + + - For textual observations: it initializes `input_ids` with padding tokens, + ensuring consistent sequence lengths within a batch. + - For non-textual observations: it provides a convenient way to fill + observation tensors with a default of 0, + ensuring shape compatibility and preventing uninitialized values. Arguments: - observation_shape (:obj:`Union[int, List[int], Tuple[int]]`): The shape of the observation tensor. - batch_size (:obj:`int`): The batch size. - device (:obj:`str`): The device to store the tensor. - - pad_token_id (:obj:`int`): The token ID used for padding. + - pad_token_id (:obj:`int`): The token ID (or placeholder value) used for padding. Returns: - - padded_tensor (:obj:`torch.Tensor`): The tensor filled with pad_token_id. + - padded_tensor (:obj:`torch.Tensor`): A tensor of the given shape, + filled with `pad_token_id`. """ if isinstance(observation_shape, (list, tuple)): shape = [batch_size, *observation_shape] diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py index fdcff20fa..bbf27a0e2 100644 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -129,6 +129,10 @@ class UniZeroPolicy(MuZeroPolicy): rope_theta=10000, # (int) The maximum sequence length for position encoding. max_seq_len=8192, + # Controls where to compute reconstruction loss: 'after_backbone', 'before_backbone', or None. + # - after_backbone: The reconstruction loss is computed after the encoded representation passes through the backbone. + # - before_backbone: The reconstruction loss is computed directly on the encoded representation, without the backbone. 
diff --git a/zoo/jericho/envs/test_qwen_prior.py b/zoo/jericho/envs/test_qwen_prior.py
deleted file mode 100644
index a3a8ea47e..000000000
--- a/zoo/jericho/envs/test_qwen_prior.py
+++ /dev/null
@@ -1,233 +0,0 @@
-import logging
-import copy
-import os
-import json
-from datetime import datetime
-from typing import Any, Dict, List, Optional, Union
-from easydict import EasyDict
-from collections import deque
-
-import numpy as np
-import torch
-from transformers import AutoTokenizer
-
-from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
-import torch.distributed as dist
-import torch
-
-from jericho_env import JerichoEnv
-
-
-def init_distributed():
-    if not dist.is_initialized():
-        dist.init_process_group(backend='nccl' if torch.cuda.is_available() else 'gloo')
-    rank = dist.get_rank()
-    world_size = dist.get_world_size()
-    return rank, world_size
-
-class Qwen3Policy:
-    def __init__(self, model_path=None, local_rank=0):
-        self.device = torch.device(f"cuda:{local_rank}")
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
-        self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto").to(self.device)
-        self.max_history_len = 5
-
-    def format_prompt(self, history: List[Dict[str, str]], current_obs: str, valid_actions: List[str], good_memory: List[str] = None, bad_memory: List[str] = None) -> str:
-        system_prompt = (
-            "You are an expert at playing text-based games.\n"
-            "You will be given a history of game states and actions taken, as well as your memory from previous gameplay.\n"
-            "Use this information to select the best next action.\n\n"
-        )
-        if good_memory:
-            system_prompt += "Good memory from past games:\n"
-            for mem in good_memory:
-                system_prompt += f"- {mem}\n"
-        else:
-            system_prompt += "Good memory from past games: None\n"
-
-        if bad_memory:
-            system_prompt += "Bad memory from past failures:\n"
-            for mem in bad_memory:
-                system_prompt += f"- {mem}\n"
-        else:
-            system_prompt += "Bad memory from past failures: None\n"
-
-
-        system_prompt += "\nHistory:\n"
-        if history is None or len(history) == 0:
-            system_prompt += "None\n"
-        for h in history:
-            system_prompt += f"[State]: {h['obs']}\n[Action]: {h['action']}\n"
-
-        state = current_obs.split('Valid actions:')[0]
-
-        system_prompt += (
-            f"\nCurrent:\n[State]: {state}\n"
-            f"[Valid Actions]: {', '.join(valid_actions)}\n"
-            "Please choose the best next action from the valid actions above, and answer with only one word or phrase, without any explanation.\n"
-            "[Action]:"
-        )
-        return system_prompt
-
-    def generate_reflection(self, history: List[Dict[str, str]], positive: bool) -> str:
-        trajectory_str = "\n".join([f"[State]: {h['obs']}\n[Action]: {h['action']}" for h in history])
-        if positive:
-            prompt = (
-                "You will receive a log of successful gameplay from a text-based adventure game.\n"
-                "Summarize a good strategy or useful lesson learned from the following playthrough in one sentence.\n"
-                f"{trajectory_str}\n"
-            )
-        else:
-            prompt = (
-                "You will receive a log of unsuccessful gameplay from a text-based adventure game. Please identify the reasons for this game failure and provide a short suggestion for improving the game strategy next time. Do not summarize the gameplay trajectory; respond with your suggestion in a single sentence. For instance: 'Remember to light a lamp before entering dark areas to avoid being eaten by a grue.\n"
                f"{trajectory_str}\n"
-            )
-
-        messages = [{"role": "user", "content": prompt}]
-        text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
-        gen_config = GenerationConfig(
-            temperature=0.7,
-            top_p=0.9,
-            do_sample=True,
-            max_new_tokens=64,
-            pad_token_id=self.tokenizer.eos_token_id
-        )
-        output = self.model.generate(
-            **model_inputs,
-            generation_config=gen_config
-        )
-        output_ids = output[0][len(model_inputs.input_ids[0]):].tolist()
-        content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip()
-        return content
-
-    def sample_action(self, history: List[Dict[str, str]], current_obs: str, valid_actions: List[str], good_memory: List[str] = None, bad_memory: List[str] = None) -> str:
-        prompt = self.format_prompt(history, current_obs, valid_actions, good_memory, bad_memory)
-        messages = [
-            {"role": "user", "content": prompt}
-        ]
-        text = self.tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True,
-            enable_thinking=False
-        )
-        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
-        gen_config = GenerationConfig(
-            temperature=0.7,
-            top_p=0.9,
-            do_sample=True,
-            max_new_tokens=64,
-            pad_token_id=self.tokenizer.eos_token_id
-        )
-        generated_ids = self.model.generate(
-            **model_inputs,
-            generation_config=gen_config
-        )
-        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
-        content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
-
-        res_action = None
-
-        for va in valid_actions:
-            if va.lower() in content:
-                res_action = va
-        if res_action is None:
-            if valid_actions:
-                res_action = valid_actions[0]
-            else:
-                res_action = 'go'
-
-        return res_action, prompt
-
-if __name__ == '__main__':
-    rank, world_size = init_distributed()
-    print(f"[RANK {rank}] Initialized. World size: {world_size}")
-
-
-    # env_type='detective' # zork1, acorncourt, detective, omniquest
-    env_type='zork1'
-    model_name = "Qwen2.5-7B-Instruct" # Path to the Qwen model
-    # Configuration dictionary for the environment.
-    env_cfg = EasyDict(
-        dict(
-            max_steps=500,
-            game_path="./zoo/jericho/envs/z-machine-games-master/jericho-game-suite/" + f"{env_type}.z5",
-            max_action_num=55,
-            tokenizer_path="google-bert/bert-base-uncased",
-            max_seq_len=512,
-            remove_stuck_actions=False,
-            # add_location_and_inventory=True, # TODO: try enabling or disabling this option
-            add_location_and_inventory=False, # TODO: try enabling or disabling this option
-            for_unizero=False,
-            collector_env_num=1,
-            evaluator_env_num=1,
-            save_replay=False,
-            save_replay_path=None,
-            env_type=env_type,
-            collect_policy_mode='expert' # random, human, expert
-        )
-    )
-
-    if env_cfg.add_location_and_inventory:
-        log_dir = f'/fs-computility/niuyazhe/shared/xiongjyu/jericho/LightZero/log/{model_name}/{env_type}_add_locAndinv'
-    else:
-        log_dir = f'/fs-computility/niuyazhe/shared/xiongjyu/jericho/LightZero/log/{model_name}/{env_type}'
-    os.makedirs(log_dir, exist_ok=True)
-
-    log_file = os.path.join(log_dir, f"rank_{rank}.txt")
-    f = open(log_file, "w", encoding="utf-8")
-
-    env = JerichoEnv(env_cfg)
-    qwen_policy = Qwen3Policy(model_path=f"/fs-computility/niuyazhe/shared/xiongjyu/model/{model_name}", local_rank=rank)
-    history = deque(maxlen=qwen_policy.max_history_len)
-
-    num_episodes = 20 # can be set to any N
-    good_trial_memory = deque(maxlen=5)
-    bad_trial_memory = deque(maxlen=20)
-
-    for episode_id in range(num_episodes):
-        f.write(f"{'='*60}\n")
-        f.write(f'current episode: {episode_id}\n')
-        f.write(f"{'='*60}\n")
-        f.flush()
-        obs = env.reset(return_str=True)
-        done = False
-        step_count = 0
-        history.clear()
-
-        while not done:
-            obs_str = obs['observation']
-            action, prompt = qwen_policy.sample_action(history=list(history),current_obs=obs_str, valid_actions=env._action_list,
-                                                       good_memory=list(good_trial_memory), bad_memory=list(bad_trial_memory))
-            obs, reward, done, info = env.step(action, return_str=True)
-            history.append({'obs': obs_str, 'action': action})
-
-            # write the log at every step
-            f.write(f"Step {step_count}\n")
-            f.write(f"[Prompt]:\n{prompt}\n")
-            f.write(f"[Qwen Action]: {action}\n")
-            f.write(f"[Env Feedback] Reward: {reward}, Done: {done}\n")
-            f.write(f"{'-'*60}\n")
-            f.flush()
-
-            step_count += 1
-
-
-        reflection = qwen_policy.generate_reflection(list(history), positive=False)
-        bad_trial_memory.append(reflection)
-        f.write(f"[BAD Reflection]: {reflection}\n")
-        print(f'[BAD Reflection]: {reflection}')
-
-        ## whether to generate a reflection for good trajectories
-        # reflection = qwen_policy.generate_reflection(list(history), positive=True)
-        # good_trial_memory.append(reflection)
-        # f.write(f"[GOOD Reflection]: {reflection}\n")
-        # print(f'[GOOD Reflection]: {reflection}')
-
-
-        f.write(f"Episode finished. Final return: {info.get('eval_episode_return', 0.0)}\n")
-
-    f.close()
-    print(f"[RANK {rank}] Finished. Log written to {log_file}")
-    del env
\ No newline at end of file
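Taken together, the series makes text decoding during collection and evaluation conditional on
the new `decode_loss_mode` option: decoding runs only when `obs_type == 'text'` and the mode is
neither None nor the string 'none'. The fragment below is a hedged sketch of how a user config
might enable it; the surrounding nesting is abbreviated and should be taken from the actual
config files rather than from this sketch.

    # Sketch only: enabling decoder-based text reconstruction in a UniZero world model config.
    world_model_cfg = dict(
        obs_type='text',
        decode_loss_mode='after_backbone',  # or 'before_backbone'; None / 'none' leaves decoding off
        max_seq_len=8192,
        rope_theta=10000,
    )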