-
Notifications
You must be signed in to change notification settings - Fork 372
Add NLLB (M2M100) support #769
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
vrmer
wants to merge
1
commit into
adapter-hub:main
Choose a base branch
from
vrmer:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| # Copyright 2020 The Adapter-Hub Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| from transformers.utils import _LazyModule | ||
|
|
||
|
|
||
| _import_structure = { | ||
| "adapter_model": ["M2M100AdapterModel"], | ||
| } | ||
|
|
||
|
|
||
| if TYPE_CHECKING: | ||
| from .adapter_model import M2M100AdapterModel | ||
|
|
||
| else: | ||
| import sys | ||
|
|
||
| sys.modules[__name__] = _LazyModule( | ||
| __name__, | ||
| globals()["__file__"], | ||
| _import_structure, | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,158 @@ | ||
| import torch | ||
|
|
||
| from transformers import GenerationMixin | ||
| from transformers.models.m2m_100.modeling_m2m_100 import ( | ||
| M2M_100_INPUTS_DOCSTRING, | ||
| M2M_100_START_DOCSTRING, | ||
| M2M100Config, | ||
| M2M100Model, | ||
| M2M100PreTrainedModel, | ||
| shift_tokens_right, | ||
| ) | ||
| from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward | ||
|
|
||
| from ...heads import ModelWithFlexibleHeadsAdaptersMixin | ||
| from ...model_mixin import EmbeddingAdaptersWrapperMixin | ||
| from ...wrappers import init | ||
|
|
||
|
|
||
| @add_start_docstrings( | ||
| "NLLB Model with the option to add multiple flexible prediction heads on the top.", M2M_100_START_DOCSTRING | ||
| ) | ||
| class M2M100AdapterModel( | ||
| EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, M2M100PreTrainedModel, GenerationMixin | ||
| ): | ||
| head_types = [ | ||
| "classification", | ||
| "multilabel_classification", | ||
| "question_answering", | ||
| "seq2seq_lm", | ||
| ] | ||
|
|
||
| def __init__(self, config: M2M100Config, **kwargs): | ||
| super().__init__(config, **kwargs) | ||
| self.model = M2M100Model(config) | ||
| init(self.model) | ||
|
|
||
| self._init_head_modules() | ||
|
|
||
| self.post_init() | ||
|
|
||
| def get_encoder(self): | ||
| return self.model.get_encoder() | ||
|
|
||
| def get_decoder(self): | ||
| return self.model.get_decoder() | ||
|
|
||
| @add_start_docstrings_to_model_forward(M2M_100_INPUTS_DOCSTRING) | ||
| def forward( | ||
| self, | ||
| input_ids=None, | ||
| attention_mask=None, | ||
| decoder_input_ids=None, | ||
| decoder_attention_mask=None, | ||
| head_mask=None, | ||
| decoder_head_mask=None, | ||
| cross_attn_head_mask=None, | ||
| encoder_outputs=None, | ||
| past_key_values=None, | ||
| inputs_embeds=None, | ||
| decoder_inputs_embeds=None, | ||
| use_cache=None, | ||
| output_attentions=None, | ||
| output_hidden_states=None, | ||
| return_dict=None, | ||
| head=None, | ||
| output_adapter_gating_scores=False, | ||
| output_adapter_fusion_attentions=False, | ||
| **kwargs, | ||
| ): | ||
| r""" | ||
| labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | ||
| Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., | ||
| config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). | ||
| """ | ||
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
|
||
| if "labels" in kwargs or "start_positions" in kwargs and "end_positions" in kwargs: | ||
| use_cache = False | ||
|
|
||
| outputs, context = self.model( | ||
| input_ids, | ||
| attention_mask=attention_mask, | ||
| decoder_input_ids=decoder_input_ids, | ||
| decoder_attention_mask=decoder_attention_mask, | ||
| head_mask=head_mask, | ||
| decoder_head_mask=decoder_head_mask, | ||
| cross_attn_head_mask=cross_attn_head_mask, | ||
| encoder_outputs=encoder_outputs, | ||
| inputs_embeds=inputs_embeds, | ||
| decoder_inputs_embeds=decoder_inputs_embeds, | ||
| use_cache=use_cache, | ||
| output_attentions=output_attentions, | ||
| output_hidden_states=output_hidden_states, | ||
| return_dict=return_dict, | ||
| past_key_values=past_key_values, | ||
| output_adapter_gating_scores=output_adapter_gating_scores, | ||
| output_adapter_fusion_attentions=output_adapter_fusion_attentions, | ||
| adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), | ||
| output_context=True, | ||
| ) | ||
| # required e.g. for prompt tuning in all models | ||
| kwargs["context"] = context | ||
|
|
||
| head_outputs = self.forward_head( | ||
| outputs, | ||
| head_name=head, | ||
| attention_mask=attention_mask, | ||
| return_dict=return_dict, | ||
| get_cls_from_eos_tokens=True, | ||
| # `get_cls_from_eos_tokens` requires passing eos mask | ||
| eos_mask=input_ids.eq(self.config.eos_token_id) if input_ids is not None else None, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| return head_outputs | ||
|
|
||
| # Copied from M2M100ForConditionalGeneration | ||
| def prepare_inputs_for_generation( | ||
| self, | ||
| decoder_input_ids, | ||
| past=None, | ||
| attention_mask=None, | ||
| head_mask=None, | ||
| decoder_head_mask=None, | ||
| cross_attn_head_mask=None, | ||
| use_cache=None, | ||
| encoder_outputs=None, | ||
| **kwargs, | ||
| ): | ||
| # cut decoder_input_ids if past is used | ||
| if past is not None: | ||
| decoder_input_ids = decoder_input_ids[:, -1:] | ||
|
|
||
| return { | ||
| "input_ids": None, # encoder_outputs is defined. input_ids not needed | ||
| "encoder_outputs": encoder_outputs, | ||
| "past_key_values": past, | ||
| "decoder_input_ids": decoder_input_ids, | ||
| "attention_mask": attention_mask, | ||
| "head_mask": head_mask, | ||
| "decoder_head_mask": decoder_head_mask, | ||
| "cross_attn_head_mask": cross_attn_head_mask, | ||
| "use_cache": use_cache, # change this to avoid caching (presumably for debugging) | ||
| } | ||
|
|
||
| # Copied from M2M100ForConditionalGeneration | ||
| def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): | ||
| return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) | ||
|
|
||
| # Copied from M2M100ForConditionalGeneration | ||
| @staticmethod | ||
| def _reorder_cache(past_key_values, beam_idx): | ||
| reordered_past = () | ||
| for layer_past in past_key_values: | ||
| reordered_past += ( | ||
| tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), | ||
| ) | ||
| return reordered_past | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,116 @@ | ||
| from typing import Iterable, Optional, Tuple | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
|
|
||
| from ...composition import adjust_tensors_for_parallel | ||
| from ...methods.bottleneck import BottleneckLayer | ||
| from ...methods.lora import LoRALinear | ||
| from ...methods.prefix_tuning import PrefixTuningLayer | ||
| from ...model_mixin import ( | ||
| EmbeddingAdaptersMixin, | ||
| EmbeddingAdaptersWrapperMixin, | ||
| InvertibleAdaptersMixin, | ||
| InvertibleAdaptersWrapperMixin, | ||
| ModelBaseAdaptersMixin, | ||
| ) | ||
| from ...utils import patch_forward | ||
|
|
||
|
|
||
| class M2M100AttentionAdaptersMixin: | ||
| """Adds adapters to the M2M100Attention module.""" | ||
|
|
||
| def init_adapters(self, model_config, adapters_config): | ||
| # Wrap layers for LoRA | ||
| self.k_proj = LoRALinear.wrap(self.k_proj, "selfattn", model_config, adapters_config, attn_key="k") | ||
| self.v_proj = LoRALinear.wrap(self.v_proj, "selfattn", model_config, adapters_config, attn_key="v") | ||
| self.q_proj = LoRALinear.wrap(self.q_proj, "selfattn", model_config, adapters_config, attn_key="q") | ||
|
|
||
| self.prefix_tuning = PrefixTuningLayer( | ||
| self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config | ||
| ) | ||
| patch_forward(self) | ||
|
|
||
|
|
||
| class M2M100EncoderLayerAdaptersMixin: | ||
| """Adds adapters to the M2M100EncoderLayer module.""" | ||
|
|
||
| def init_adapters(self, model_config, adapters_config): | ||
| self.adapters_config = adapters_config | ||
| # Wrap layers for LoRA | ||
| self.fc1 = LoRALinear.wrap(self.fc1, "intermediate", model_config, adapters_config) | ||
| self.fc2 = LoRALinear.wrap(self.fc2, "output", model_config, adapters_config) | ||
|
|
||
| # Set attention layer location key for prefix tuning | ||
| self.self_attn.location_key = "encoder" | ||
| self.attention_adapters = BottleneckLayer("mh_adapter") | ||
| self.output_adapters = BottleneckLayer("output_adapter") | ||
|
|
||
| patch_forward(self) | ||
|
|
||
|
|
||
| class M2M100DecoderLayerAdaptersMixin(M2M100EncoderLayerAdaptersMixin): | ||
| """Adds adapters to the M2M100DecoderLayer module.""" | ||
|
|
||
| def init_adapters(self, model_config, adapters_config): | ||
| super().init_adapters(model_config, adapters_config) | ||
| # Set attention layer location key for prefix tuning | ||
| self.self_attn.location_key = "self" | ||
| self.encoder_attn.location_key = "cross" | ||
| self.cross_attention_adapters = BottleneckLayer("cross_adapter") | ||
|
|
||
|
|
||
| class M2M100EncoderAdaptersMixin(InvertibleAdaptersMixin): | ||
| """Adds adapters to the M2M100Encoder module.""" | ||
|
|
||
| pass | ||
|
|
||
|
|
||
| class M2M100DecoderAdaptersMixin: | ||
| """Adds adapters to the M2M100Decoder module.""" | ||
|
|
||
| def init_adapters(self, model_config, adapters_config): | ||
| patch_forward(self) | ||
|
|
||
| def forward( | ||
| self, input_ids: torch.LongTensor = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, **kwargs | ||
| ): | ||
| (input_ids,) = adjust_tensors_for_parallel(encoder_hidden_states, input_ids) | ||
| return super().forward(input_ids=input_ids, encoder_hidden_states=encoder_hidden_states, **kwargs) | ||
|
|
||
|
|
||
| class M2M100ModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMixin, ModelBaseAdaptersMixin): | ||
| """Adds adapters to the M2M100Model class.""" | ||
|
|
||
| invertible_adapters_base_name = "encoder" | ||
| support_prompt_tuning = False | ||
|
|
||
| def init_adapters(self, model_config, adapters_config): | ||
| super().init_adapters(model_config, adapters_config) | ||
| self.encoder.layer_norm.register_forward_hook(self.post_embedding_forward) | ||
|
|
||
| def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: | ||
| if hasattr(self, "encoder"): | ||
| for i, layer in enumerate(self.encoder.layers): | ||
| yield i, layer | ||
| for i, layer in enumerate(self.decoder.layers, start=len(self.encoder.layers)): | ||
| yield i, layer | ||
| else: | ||
| for i, layer in enumerate(self.decoder.layers): | ||
| yield i, layer | ||
|
|
||
| def post_embedding_forward(self, module, args, embedding_output): | ||
| embedding_output = self.invertible_adapters_forward(embedding_output) | ||
| # Prompt tuning not yet supported | ||
| return embedding_output | ||
|
|
||
|
|
||
| class M2M100DecoderWrapperAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelBaseAdaptersMixin): | ||
| """Adds adapters to the M2M100DecoderWrapper module.""" | ||
|
|
||
| def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: | ||
| for i, layer in enumerate(self.decoder.layers): | ||
| yield i, layer | ||
|
|
||
| def get_input_embeddings(self): | ||
| return self.decoder.get_input_embeddings() |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this defines the range of supported heads. Since I believe we'd only want to support sequence generation, you can remove everything except for "seq2seq_lm" here.