2 changes: 2 additions & 0 deletions src/adapters/__init__.py
@@ -109,6 +109,7 @@
"models.gpt2": ["GPT2AdapterModel"],
"models.gptj": ["GPTJAdapterModel"],
"models.llama": ["LlamaAdapterModel"],
"models.m2m_100": ["M2M100AdapterModel"],
"models.mbart": ["MBartAdapterModel"],
"models.mistral": ["MistralAdapterModel"],
"models.mt5": ["MT5AdapterModel"],
@@ -220,6 +221,7 @@
from .models.gpt2 import GPT2AdapterModel
from .models.gptj import GPTJAdapterModel
from .models.llama import LlamaAdapterModel
from .models.m2m_100 import M2M100AdapterModel
from .models.mbart import MBartAdapterModel
from .models.mistral import MistralAdapterModel
from .models.mt5 import MT5AdapterModel
4 changes: 4 additions & 0 deletions src/adapters/models/__init__.py
@@ -19,6 +19,7 @@
from .gpt2.mixin_gpt2 import GPT2ModelAdapterMixin
from .gptj.mixin_gptj import GPTJMLPAdaptersMixin, GPTJModelAdapterMixin
from .llama.mixin_llama import LlamaForQuestionAnsweringAdapterMixin, LlamaModelAdapterMixin
from .m2m_100.mixin_m2m_100 import M2M100DecoderAdaptersMixin, M2M100EncoderAdaptersMixin, M2M100ModelAdaptersMixin
from .mistral.mixin_mistral import MistralModelAdapterMixin
from .plbart.mixin_plbart import (
PLBartDecoderAdaptersMixin,
@@ -70,6 +71,9 @@
"MBartDecoder": BartDecoderAdaptersMixin,
"MBartDecoderWrapper": BartDecoderWrapperAdaptersMixin,
"MBartModel": BartModelAdaptersMixin,
"M2M100Model": M2M100ModelAdaptersMixin,
"M2M100Encoder": M2M100EncoderAdaptersMixin,
"M2M100Decoder": M2M100DecoderAdaptersMixin,
"MT5Block": T5BlockAdaptersMixin,
"MT5Model": T5ModelAdaptersMixin,
"MT5ForConditionalGeneration": T5ForCondiditionalGenerationWithHeadsMixin,
1 change: 1 addition & 0 deletions src/adapters/models/auto/adapter_model.py
@@ -22,6 +22,7 @@
("gpt2", "GPT2AdapterModel"),
("gptj", "GPTJAdapterModel"),
("llama", "LlamaAdapterModel"),
("m2m_100", "M2M100AdapterModel"),
("mbart", "MBartAdapterModel"),
("mistral", "MistralAdapterModel"),
("mt5", "MT5AdapterModel"),
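
With the new ("m2m_100", "M2M100AdapterModel") entry, AutoAdapterModel can resolve M2M100-architecture checkpoints to the new class. A minimal sketch of that lookup; the checkpoint name is illustrative and not part of this PR:

```python
from adapters import AutoAdapterModel

# The mapping entry added above lets the generic entry point dispatch to the new class.
# "facebook/nllb-200-distilled-600M" is an illustrative M2M100-architecture checkpoint.
model = AutoAdapterModel.from_pretrained("facebook/nllb-200-distilled-600M")
print(type(model).__name__)  # expected: M2M100AdapterModel
```
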
35 changes: 35 additions & 0 deletions src/adapters/models/m2m_100/__init__.py
@@ -0,0 +1,35 @@
# Copyright 2020 The Adapter-Hub Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from transformers.utils import _LazyModule


_import_structure = {
"adapter_model": ["M2M100AdapterModel"],
}


if TYPE_CHECKING:
from .adapter_model import M2M100AdapterModel

else:
import sys

sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
)
158 changes: 158 additions & 0 deletions src/adapters/models/m2m_100/adapter_model.py
@@ -0,0 +1,158 @@
import torch

from transformers import GenerationMixin
from transformers.models.m2m_100.modeling_m2m_100 import (
M2M_100_INPUTS_DOCSTRING,
M2M_100_START_DOCSTRING,
M2M100Config,
M2M100Model,
M2M100PreTrainedModel,
shift_tokens_right,
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init


@add_start_docstrings(
"NLLB Model with the option to add multiple flexible prediction heads on the top.", M2M_100_START_DOCSTRING
)
class M2M100AdapterModel(
EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, M2M100PreTrainedModel, GenerationMixin
):
head_types = [
"classification",
"multilabel_classification",
"question_answering",
"seq2seq_lm",
]
Comment on lines +25 to +30 (Member):
This defines the range of supported heads. Since I believe we'd only want to support sequence generation, you can remove everything except for "seq2seq_lm" here.
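
A sketch of the change this suggestion implies (not yet applied in the diff):

```python
    # Per the review suggestion, only sequence-to-sequence generation would be supported.
    head_types = [
        "seq2seq_lm",
    ]
```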


def __init__(self, config: M2M100Config, **kwargs):
super().__init__(config, **kwargs)
self.model = M2M100Model(config)
init(self.model)

self._init_head_modules()

self.post_init()

def get_encoder(self):
return self.model.get_encoder()

def get_decoder(self):
return self.model.get_decoder()

@add_start_docstrings_to_model_forward(M2M_100_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if "labels" in kwargs or "start_positions" in kwargs and "end_positions" in kwargs:
use_cache = False

outputs, context = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
past_key_values=past_key_values,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context

head_outputs = self.forward_head(
outputs,
head_name=head,
attention_mask=attention_mask,
return_dict=return_dict,
get_cls_from_eos_tokens=True,
# `get_cls_from_eos_tokens` requires passing eos mask
eos_mask=input_ids.eq(self.config.eos_token_id) if input_ids is not None else None,
**kwargs,
)

return head_outputs

# Copied from M2M100ForConditionalGeneration
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

# Copied from M2M100ForConditionalGeneration
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

# Copied from M2M100ForConditionalGeneration
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
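
For trying the new model class end to end, a minimal usage sketch; the checkpoint, adapter name, and adapter config are illustrative assumptions, not fixed by this PR:

```python
from adapters import AutoAdapterModel
from transformers import AutoTokenizer

# Illustrative M2M100-architecture (NLLB) checkpoint.
model = AutoAdapterModel.from_pretrained("facebook/nllb-200-distilled-600M")
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

# Add a bottleneck adapter and the seq2seq LM head declared in head_types above.
model.add_adapter("translation_adapter", config="seq_bn")
model.add_seq2seq_lm_head("translation_adapter")
model.train_adapter("translation_adapter")  # freeze base weights, train only the adapter

# Quick generation smoke test; NLLB checkpoints usually also expect a forced BOS
# token for the target language, omitted here for brevity.
batch = tokenizer("How are you?", return_tensors="pt")
output = model.generate(**batch, max_new_tokens=20)
print(tokenizer.batch_decode(output, skip_special_tokens=True))
```
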
116 changes: 116 additions & 0 deletions src/adapters/models/m2m_100/mixin_m2m_100.py
@@ -0,0 +1,116 @@
from typing import Iterable, Optional, Tuple

import torch
import torch.nn as nn

from ...composition import adjust_tensors_for_parallel
from ...methods.bottleneck import BottleneckLayer
from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import (
EmbeddingAdaptersMixin,
EmbeddingAdaptersWrapperMixin,
InvertibleAdaptersMixin,
InvertibleAdaptersWrapperMixin,
ModelBaseAdaptersMixin,
)
from ...utils import patch_forward


class M2M100AttentionAdaptersMixin:
"""Adds adapters to the M2M100Attention module."""

def init_adapters(self, model_config, adapters_config):
# Wrap layers for LoRA
self.k_proj = LoRALinear.wrap(self.k_proj, "selfattn", model_config, adapters_config, attn_key="k")
self.v_proj = LoRALinear.wrap(self.v_proj, "selfattn", model_config, adapters_config, attn_key="v")
self.q_proj = LoRALinear.wrap(self.q_proj, "selfattn", model_config, adapters_config, attn_key="q")

self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
patch_forward(self)


class M2M100EncoderLayerAdaptersMixin:
"""Adds adapters to the M2M100EncoderLayer module."""

def init_adapters(self, model_config, adapters_config):
self.adapters_config = adapters_config
# Wrap layers for LoRA
self.fc1 = LoRALinear.wrap(self.fc1, "intermediate", model_config, adapters_config)
self.fc2 = LoRALinear.wrap(self.fc2, "output", model_config, adapters_config)

# Set attention layer location key for prefix tuning
self.self_attn.location_key = "encoder"
self.attention_adapters = BottleneckLayer("mh_adapter")
self.output_adapters = BottleneckLayer("output_adapter")

patch_forward(self)


class M2M100DecoderLayerAdaptersMixin(M2M100EncoderLayerAdaptersMixin):
"""Adds adapters to the M2M100DecoderLayer module."""

def init_adapters(self, model_config, adapters_config):
super().init_adapters(model_config, adapters_config)
# Set attention layer location key for prefix tuning
self.self_attn.location_key = "self"
self.encoder_attn.location_key = "cross"
self.cross_attention_adapters = BottleneckLayer("cross_adapter")


class M2M100EncoderAdaptersMixin(InvertibleAdaptersMixin):
"""Adds adapters to the M2M100Encoder module."""

pass


class M2M100DecoderAdaptersMixin:
"""Adds adapters to the M2M100Decoder module."""

def init_adapters(self, model_config, adapters_config):
patch_forward(self)

def forward(
self, input_ids: torch.LongTensor = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, **kwargs
):
(input_ids,) = adjust_tensors_for_parallel(encoder_hidden_states, input_ids)
return super().forward(input_ids=input_ids, encoder_hidden_states=encoder_hidden_states, **kwargs)


class M2M100ModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMixin, ModelBaseAdaptersMixin):
"""Adds adapters to the M2M100Model class."""

invertible_adapters_base_name = "encoder"
support_prompt_tuning = False

def init_adapters(self, model_config, adapters_config):
super().init_adapters(model_config, adapters_config)
self.encoder.layer_norm.register_forward_hook(self.post_embedding_forward)

def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
if hasattr(self, "encoder"):
for i, layer in enumerate(self.encoder.layers):
yield i, layer
for i, layer in enumerate(self.decoder.layers, start=len(self.encoder.layers)):
yield i, layer
else:
for i, layer in enumerate(self.decoder.layers):
yield i, layer

def post_embedding_forward(self, module, args, embedding_output):
embedding_output = self.invertible_adapters_forward(embedding_output)
# Prompt tuning not yet supported
return embedding_output


class M2M100DecoderWrapperAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelBaseAdaptersMixin):
"""Adds adapters to the M2M100DecoderWrapper module."""

def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.decoder.layers):
yield i, layer

def get_input_embeddings(self):
return self.decoder.get_input_embeddings()
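
Since iter_layers() enumerates encoder layers first and continues the count into the decoder, layer-indexed adapter options address both stacks with one combined index. A quick sketch of that numbering, assuming the mixin mapping added in models/__init__.py above; the tiny config is only for illustration:

```python
import adapters
from transformers import M2M100Config, M2M100Model

# A deliberately tiny config so the example runs quickly.
config = M2M100Config(
    vocab_size=128,
    d_model=64,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=128,
    decoder_ffn_dim=128,
)
model = M2M100Model(config)
adapters.init(model)  # applies M2M100ModelAdaptersMixin and friends via the mapping above

print([i for i, _ in model.iter_layers()])  # expected: [0, 1, 2, 3] (encoder 0-1, decoder 2-3)
```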