-
Notifications
You must be signed in to change notification settings - Fork 180
feature(xjy): add support for selectable encoder/decoder options for jericho's world model #391
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 3 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
4e2fe99
Qwen is tested as a policy in the jericho environment
d969096
fixed the bug that bad reflection cannot be collected
00d4797
supports options for selecting encoder/decoder
d189a73
fixed a few bugs and standardized the format
xiongjyu cd811ac
standardize the format again
xiongjyu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,8 @@ | |
| from ditk import logging | ||
| from ding.utils import set_pkg_seed, get_rank, get_world_size | ||
| import torch | ||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
|
||
xiongjyu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def MLP_V2( | ||
| in_channels: int, | ||
|
|
@@ -361,6 +363,115 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: | |
|
|
||
| return output | ||
|
|
||
| class QwenNetwork(nn.Module): | ||
| def __init__(self, | ||
| model_path: str = 'Qwen/Qwen3-1.7B', | ||
| embedding_size: int = 768, | ||
| final_norm_option_in_encoder: str = "layernorm", | ||
| group_size: int = 8, | ||
| tokenizer=None): | ||
| super().__init__() | ||
|
|
||
| logging.info(f"Loading Qwen model from: {model_path}") | ||
|
|
||
| local_rank = get_rank() | ||
| if local_rank == 0: | ||
| self.pretrained_model = AutoModelForCausalLM.from_pretrained( | ||
| model_path, | ||
| torch_dtype="auto", | ||
| device_map={"": local_rank}, | ||
xiongjyu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| attn_implementation="flash_attention_2" | ||
| ) | ||
| if get_world_size() > 1: | ||
| torch.distributed.barrier() | ||
| if local_rank != 0: | ||
| self.pretrained_model = AutoModelForCausalLM.from_pretrained( | ||
| model_path, | ||
| torch_dtype="auto", | ||
| device_map={"": local_rank}, | ||
| attn_implementation="flash_attention_2" | ||
| ) | ||
|
|
||
| for p in self.pretrained_model.parameters(): | ||
| p.requires_grad = False | ||
|
|
||
| if tokenizer is None: | ||
| if local_rank == 0: | ||
| self.tokenizer = AutoTokenizer.from_pretrained(model_path) | ||
| if get_world_size() > 1: | ||
| torch.distributed.barrier() | ||
| if local_rank != 0: | ||
| self.tokenizer = AutoTokenizer.from_pretrained(model_path) | ||
| else: | ||
| self.tokenizer = tokenizer | ||
|
|
||
| qwen_hidden_size = self.pretrained_model.config.hidden_size | ||
|
|
||
| self.embedding_head = nn.Sequential( | ||
| nn.Linear(qwen_hidden_size, embedding_size), | ||
| self._create_norm_layer(final_norm_option_in_encoder, embedding_size, group_size) | ||
| ) | ||
|
|
||
| def _create_norm_layer(self, norm_option, embedding_size, group_size): | ||
| if norm_option.lower() == "simnorm": | ||
| return SimNorm(simnorm_dim=group_size) | ||
| elif norm_option.lower() == "layernorm": | ||
| return nn.LayerNorm(embedding_size) | ||
| else: | ||
| raise NotImplementedError(f"Normalization type '{norm_option}' is not implemented.") | ||
|
|
||
| def encode(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor: | ||
| """ | ||
| Overview: | ||
| Encode the input tensor `x` to a latent state. | ||
| Arguments: | ||
| - x (:obj:`torch.Tensor`): Input tensor of shape (B, C_in, W, H). | ||
xiongjyu marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| Returns: | ||
| - latent_state (:obj:`torch.Tensor`): Encoded latent state of shape (B, embedding_dim). | ||
| """ | ||
| pad_id = self.tokenizer.pad_token_id | ||
| attention_mask = (x != pad_id).long().to(x.device) | ||
| context = {'input_ids': x.long(), 'attention_mask': attention_mask} | ||
| no_grad = True | ||
xiongjyu marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if no_grad: | ||
| with torch.no_grad(): | ||
| outputs = self.pretrained_model(**context, output_hidden_states=True, return_dict=True) | ||
| else: | ||
| outputs = self.pretrained_model(**context, output_hidden_states=True, return_dict=True) | ||
| last_hidden = outputs.hidden_states[-1] | ||
|
|
||
| B, L, H = last_hidden.size() | ||
| lengths = attention_mask.sum(dim=1) # [B] | ||
| positions = torch.clamp(lengths - 1, min=0) # [B] | ||
| batch_idx = torch.arange(B, device=last_hidden.device) | ||
|
|
||
| selected = last_hidden[batch_idx, positions] # [B, H] | ||
|
|
||
| latent = self.embedding_head(selected.to(self.embedding_head[0].weight.dtype)) | ||
| return latent | ||
|
|
||
| def decode(self, embeddings: torch.Tensor, max_length: int = 512) -> str: | ||
| """ | ||
| Decodes embeddings into text via the decoder network. | ||
| """ | ||
| embeddings_detached = embeddings.detach() | ||
| self.pretrained_model.eval() | ||
|
|
||
| # Directly generate using provided embeddings | ||
| with torch.no_grad(): | ||
| param = next(self.pretrained_model.parameters()) | ||
puyuan1996 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| embeddings = embeddings_detached.to(device=param.device, dtype=param.dtype) | ||
| gen_ids = self.pretrained_model.generate( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 目前这种方式是能正确运行,且训练后decode出的文本也是bleu很高的吗 |
||
| inputs_embeds=embeddings, | ||
| max_length=max_length | ||
| ) | ||
| texts = self.tokenizer.batch_decode(gen_ids, skip_special_tokens=True) | ||
| self.pretrained_model.train() | ||
| return texts[0] if len(texts) == 1 else texts | ||
|
|
||
| def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor: | ||
| return self.encode(x, no_grad=no_grad) | ||
|
|
||
|
|
||
| class HFLanguageRepresentationNetwork(nn.Module): | ||
| def __init__(self, | ||
|
|
@@ -542,7 +653,6 @@ def __init__( | |
| else: | ||
| raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}") | ||
|
|
||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
| Shapes: | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.