Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions lzero/entry/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,28 @@ def initialize_zeros_batch(observation_shape: Union[int, List[int], Tuple[int]],

return torch.zeros(shape).to(device)

def initialize_pad_batch(observation_shape: Union[int, List[int], Tuple[int]], batch_size: int, device: str, pad_token_id: int) -> torch.Tensor:
    """
    Overview:
        Build a batch-shaped tensor pre-filled with ``pad_token_id``.
        This is typically used to initialize input_ids with padding tokens.
    Arguments:
        - observation_shape (:obj:`Union[int, List[int], Tuple[int]]`): The shape of a single observation.
        - batch_size (:obj:`int`): The number of observations in the batch.
        - device (:obj:`str`): The device on which the tensor is allocated.
        - pad_token_id (:obj:`int`): The token ID used for padding.
    Returns:
        - padded_tensor (:obj:`torch.Tensor`): A ``torch.long`` tensor of shape
          ``[batch_size, *observation_shape]`` filled with ``pad_token_id``.
    Raises:
        - TypeError: If ``observation_shape`` is neither an int nor a list/tuple.
    """
    if isinstance(observation_shape, int):
        obs_dims = [observation_shape]
    elif isinstance(observation_shape, (list, tuple)):
        obs_dims = list(observation_shape)
    else:
        raise TypeError(f"observation_shape must be int, list, or tuple, but got {type(observation_shape).__name__}")

    return torch.full([batch_size, *obs_dims], fill_value=pad_token_id, dtype=torch.long, device=device)

def random_collect(
policy_cfg: 'EasyDict', # noqa
policy: 'Policy', # noqa
Expand Down
2 changes: 1 addition & 1 deletion lzero/mcts/buffer/game_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:
# For some environments (e.g., Jericho), the action space size may be different.
# To ensure we can always unroll `num_unroll_steps` steps starting from the sampled position (without exceeding segment length),
# we avoid sampling from the last `num_unroll_steps` steps of the game segment.
if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps:
if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps:
pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps, 1).item()
else:
# For environments with a fixed action space (e.g., Atari),
Expand Down
2 changes: 1 addition & 1 deletion lzero/mcts/tree_search/mcts_ctree.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def search(
current_latent_state_index, discount_factor, reward_batch, value_batch, policy_logits_batch,
min_max_stats_lst, results, virtual_to_play_batch
)

return first_action_latent_map


Expand Down
112 changes: 111 additions & 1 deletion lzero/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from ditk import logging
from ding.utils import set_pkg_seed, get_rank, get_world_size
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def MLP_V2(
in_channels: int,
Expand Down Expand Up @@ -361,6 +363,115 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

return output

class QwenNetwork(nn.Module):
    """
    Overview:
        Wraps a frozen pretrained Qwen causal LM as a text encoder (and generator).
        ``encode`` maps a batch of token-id sequences to fixed-size latents taken from
        the last non-pad position of the final hidden layer, projected through a small
        trainable head. ``decode`` generates text directly from latent embeddings.
    """

    def __init__(self,
                 model_path: str = 'Qwen/Qwen3-1.7B',
                 embedding_size: int = 768,
                 final_norm_option_in_encoder: str = "layernorm",
                 group_size: int = 8,
                 tokenizer=None):
        """
        Arguments:
            - model_path (:obj:`str`): HuggingFace hub id or local path of the Qwen checkpoint.
            - embedding_size (:obj:`int`): Output dimension of the trainable projection head.
            - final_norm_option_in_encoder (:obj:`str`): Normalization applied after the
              projection; one of "layernorm" or "simnorm".
            - group_size (:obj:`int`): Group dimension used when ``simnorm`` is selected.
            - tokenizer: Optional pre-built tokenizer; loaded from ``model_path`` if None.
        """
        super().__init__()

        logging.info(f"Loading Qwen model from: {model_path}")

        # Stagger model loading across ranks: rank 0 downloads/loads first so the other
        # ranks hit a warm local cache after the barrier instead of racing the download.
        local_rank = get_rank()
        if local_rank == 0:
            self.pretrained_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype="auto",
                device_map={"": local_rank},
                attn_implementation="flash_attention_2"
            )
        if get_world_size() > 1:
            torch.distributed.barrier()
        if local_rank != 0:
            self.pretrained_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype="auto",
                device_map={"": local_rank},
                attn_implementation="flash_attention_2"
            )

        # The backbone LM is frozen; only `embedding_head` below is trainable.
        for p in self.pretrained_model.parameters():
            p.requires_grad = False

        if tokenizer is None:
            # Same rank-staggered loading as the model above.
            if local_rank == 0:
                self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            if get_world_size() > 1:
                torch.distributed.barrier()
            if local_rank != 0:
                self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        else:
            self.tokenizer = tokenizer

        qwen_hidden_size = self.pretrained_model.config.hidden_size

        self.embedding_head = nn.Sequential(
            nn.Linear(qwen_hidden_size, embedding_size),
            self._create_norm_layer(final_norm_option_in_encoder, embedding_size, group_size)
        )

    def _create_norm_layer(self, norm_option, embedding_size, group_size):
        """Build the final normalization layer for the projection head."""
        if norm_option.lower() == "simnorm":
            return SimNorm(simnorm_dim=group_size)
        elif norm_option.lower() == "layernorm":
            return nn.LayerNorm(embedding_size)
        else:
            raise NotImplementedError(f"Normalization type '{norm_option}' is not implemented.")

    def encode(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
        """
        Overview:
            Encode a batch of token-id sequences into latent states.
        Arguments:
            - x (:obj:`torch.Tensor`): Token ids of shape (B, L); positions equal to the
              tokenizer's ``pad_token_id`` are masked out of attention.
            - no_grad (:obj:`bool`): If True, run the frozen backbone under ``torch.no_grad()``.
        Returns:
            - latent (:obj:`torch.Tensor`): Encoded latent state of shape (B, embedding_size).
        """
        pad_id = self.tokenizer.pad_token_id
        attention_mask = (x != pad_id).long().to(x.device)
        context = {'input_ids': x.long(), 'attention_mask': attention_mask}
        # BUGFIX: `no_grad` was previously overridden to True unconditionally here,
        # silently ignoring the caller's argument; honor the parameter instead.
        if no_grad:
            with torch.no_grad():
                outputs = self.pretrained_model(**context, output_hidden_states=True, return_dict=True)
        else:
            outputs = self.pretrained_model(**context, output_hidden_states=True, return_dict=True)
        last_hidden = outputs.hidden_states[-1]

        # Pick the hidden state at the last non-pad token of each sequence
        # (clamped to 0 for fully-padded rows to keep indexing valid).
        batch_size = last_hidden.size(0)
        lengths = attention_mask.sum(dim=1)              # [B]
        positions = torch.clamp(lengths - 1, min=0)      # [B]
        batch_idx = torch.arange(batch_size, device=last_hidden.device)
        selected = last_hidden[batch_idx, positions]     # [B, H]

        # Cast to the head's dtype: the backbone may run in bf16/fp16 while the head is fp32.
        latent = self.embedding_head(selected.to(self.embedding_head[0].weight.dtype))
        return latent

    def decode(self, embeddings: torch.Tensor, max_length: int = 512) -> str:
        """
        Overview:
            Generate text by feeding latent embeddings straight into the LM via
            ``inputs_embeds``.
        Arguments:
            - embeddings (:obj:`torch.Tensor`): Latent embeddings; gradients are detached.
            - max_length (:obj:`int`): Maximum generation length.
        Returns:
            - text: A single string when the batch has one element, else a list of strings.
        """
        # NOTE(review): confirm that generating directly from `inputs_embeds` is correct
        # here and that decoded text after training reaches a high BLEU score.
        embeddings_detached = embeddings.detach()
        self.pretrained_model.eval()

        with torch.no_grad():
            # Match the generation inputs to the backbone's device/dtype.
            param = next(self.pretrained_model.parameters())
            embeddings = embeddings_detached.to(device=param.device, dtype=param.dtype)
            gen_ids = self.pretrained_model.generate(
                inputs_embeds=embeddings,
                max_length=max_length
            )
        texts = self.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
        self.pretrained_model.train()
        return texts[0] if len(texts) == 1 else texts

    def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
        """Alias for :meth:`encode`."""
        return self.encode(x, no_grad=no_grad)


class HFLanguageRepresentationNetwork(nn.Module):
def __init__(self,
Expand Down Expand Up @@ -542,7 +653,6 @@ def __init__(
else:
raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}")


def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Shapes:
Expand Down
51 changes: 35 additions & 16 deletions lzero/model/unizero_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \
VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \
HFLanguageRepresentationNetwork
HFLanguageRepresentationNetwork, QwenNetwork
from .unizero_world_models.tokenizer import Tokenizer
from .unizero_world_models.world_model import WorldModel
from ding.utils import ENV_REGISTRY, set_pkg_seed, get_rank, get_world_size
Expand Down Expand Up @@ -96,21 +96,37 @@ def __init__(
print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder')
print('==' * 20)
elif world_model_cfg.obs_type == 'text':
self.representation_network = HFLanguageRepresentationNetwork(model_path=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder)
# print(self.representation_network.model.encoder.layer[0].attention.output.LayerNorm.weight)

if self.rank == 0:
self.decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small")
self.decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small")
if self.world_size > 1:
# Wait until rank 0 finishes loading the tokenizer
torch.distributed.barrier()
if self.rank != 0:
self.decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small")
self.decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small")

projection = [self.representation_network.pretrained_model.config.hidden_size, self.decoder_network.config.d_model]
self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=self.decoder_network, decoder_network_tokenizer=self.decoder_network_tokenizer, with_lpips=False, projection=projection)
if kwargs['encoder_option'] == 'legacy':
self.representation_network = HFLanguageRepresentationNetwork(model_path=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder)
if world_model_cfg.decode_loss_mode is None or world_model_cfg.decode_loss_mode.lower() == 'none':
self.decoder_network = None
self.decoder_network_tokenizer = None
projection = None
else:
if self.rank == 0:
self.decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small")
self.decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small")
if self.world_size > 1:
# Wait until rank 0 finishes loading the tokenizer
torch.distributed.barrier()
if self.rank != 0:
self.decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small")
self.decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small")
projection = [world_model_cfg.embed_dim, self.decoder_network.config.d_model]
elif kwargs['encoder_option'] == 'qwen':
self.representation_network = QwenNetwork(model_path=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder)
if world_model_cfg.decode_loss_mode is None or world_model_cfg.decode_loss_mode.lower() == 'none':
self.decoder_network = None
self.decoder_network_tokenizer = None
projection = None
else:
projection = [world_model_cfg.embed_dim, self.representation_network.pretrained_model.config.hidden_size]
self.decoder_network = self.representation_network
self.decoder_network_tokenizer = None
else:
raise ValueError(f"Unsupported encoder option: {kwargs['encoder_option']}")
self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=self.decoder_network, decoder_network_tokenizer=self.decoder_network_tokenizer,
with_lpips=False, projection=projection, encoder_option=kwargs['encoder_option'])
self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer)
print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model')
print('==' * 20)
Expand Down Expand Up @@ -216,6 +232,9 @@ def initial_inference(self, obs_batch: torch.Tensor, action_batch: Optional[torc
reward = logits_rewards
policy_logits = logits_policy.squeeze(1)
value = logits_value.squeeze(1)
if torch.isnan(value).any() or torch.isnan(latent_state).any() or torch.isnan(policy_logits).any():
print(f'NaN detected in value, latent_state, or policy_logits at start_pos {start_pos}')
print(f'value: {value}, latent_state: {latent_state}, policy_logits: {policy_logits}')

return MZNetworkOutput(
value=value,
Expand Down
Loading