diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py index a917828e72..081ad22610 100644 --- a/src/adapters/__init__.py +++ b/src/adapters/__init__.py @@ -109,6 +109,7 @@ "models.gpt2": ["GPT2AdapterModel"], "models.gptj": ["GPTJAdapterModel"], "models.llama": ["LlamaAdapterModel"], + "models.m2m_100": ["M2M100AdapterModel"], "models.mbart": ["MBartAdapterModel"], "models.mistral": ["MistralAdapterModel"], "models.mt5": ["MT5AdapterModel"], @@ -220,6 +221,7 @@ from .models.gpt2 import GPT2AdapterModel from .models.gptj import GPTJAdapterModel from .models.llama import LlamaAdapterModel + from .models.m2m_100 import M2M100AdapterModel from .models.mbart import MBartAdapterModel from .models.mistral import MistralAdapterModel from .models.mt5 import MT5AdapterModel diff --git a/src/adapters/models/__init__.py b/src/adapters/models/__init__.py index 77f569835d..2985a159de 100644 --- a/src/adapters/models/__init__.py +++ b/src/adapters/models/__init__.py @@ -19,6 +19,7 @@ from .gpt2.mixin_gpt2 import GPT2ModelAdapterMixin from .gptj.mixin_gptj import GPTJMLPAdaptersMixin, GPTJModelAdapterMixin from .llama.mixin_llama import LlamaForQuestionAnsweringAdapterMixin, LlamaModelAdapterMixin +from .m2m_100.mixin_m2m_100 import M2M100DecoderAdaptersMixin, M2M100EncoderAdaptersMixin, M2M100ModelAdaptersMixin from .mistral.mixin_mistral import MistralModelAdapterMixin from .plbart.mixin_plbart import ( PLBartDecoderAdaptersMixin, @@ -70,6 +71,9 @@ "MBartDecoder": BartDecoderAdaptersMixin, "MBartDecoderWrapper": BartDecoderWrapperAdaptersMixin, "MBartModel": BartModelAdaptersMixin, + "M2M100Model": M2M100ModelAdaptersMixin, + "M2M100Encoder": M2M100EncoderAdaptersMixin, + "M2M100Decoder": M2M100DecoderAdaptersMixin, "MT5Block": T5BlockAdaptersMixin, "MT5Model": T5ModelAdaptersMixin, "MT5ForConditionalGeneration": T5ForCondiditionalGenerationWithHeadsMixin, diff --git a/src/adapters/models/auto/adapter_model.py b/src/adapters/models/auto/adapter_model.py index 6711752054..dd046a1aec 100644 --- a/src/adapters/models/auto/adapter_model.py +++ b/src/adapters/models/auto/adapter_model.py @@ -22,6 +22,7 @@ ("gpt2", "GPT2AdapterModel"), ("gptj", "GPTJAdapterModel"), ("llama", "LlamaAdapterModel"), + ("m2m_100", "M2M100AdapterModel"), ("mbart", "MBartAdapterModel"), ("mistral", "MistralAdapterModel"), ("mt5", "MT5AdapterModel"), diff --git a/src/adapters/models/m2m_100/__init__.py b/src/adapters/models/m2m_100/__init__.py new file mode 100644 index 0000000000..5322e7cd6c --- /dev/null +++ b/src/adapters/models/m2m_100/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2020 The Adapter-Hub Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING + +from transformers.utils import _LazyModule + + +_import_structure = { + "adapter_model": ["M2M100AdapterModel"], +} + + +if TYPE_CHECKING: + from .adapter_model import M2M100AdapterModel + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + ) diff --git a/src/adapters/models/m2m_100/adapter_model.py b/src/adapters/models/m2m_100/adapter_model.py new file mode 100644 index 0000000000..c250dfe377 --- /dev/null +++ b/src/adapters/models/m2m_100/adapter_model.py @@ -0,0 +1,158 @@ +import torch + +from transformers import GenerationMixin +from transformers.models.m2m_100.modeling_m2m_100 import ( + M2M_100_INPUTS_DOCSTRING, + M2M_100_START_DOCSTRING, + M2M100Config, + M2M100Model, + M2M100PreTrainedModel, + shift_tokens_right, +) +from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward + +from ...heads import ModelWithFlexibleHeadsAdaptersMixin +from ...model_mixin import EmbeddingAdaptersWrapperMixin +from ...wrappers import init + + +@add_start_docstrings( + "NLLB Model with the option to add multiple flexible prediction heads on the top.", M2M_100_START_DOCSTRING +) +class M2M100AdapterModel( + EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, M2M100PreTrainedModel, GenerationMixin +): + head_types = [ + "classification", + "multilabel_classification", + "question_answering", + "seq2seq_lm", + ] + + def __init__(self, config: M2M100Config, **kwargs): + super().__init__(config, **kwargs) + self.model = M2M100Model(config) + init(self.model) + + self._init_head_modules() + + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + @add_start_docstrings_to_model_forward(M2M_100_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + head=None, + output_adapter_gating_scores=False, + output_adapter_fusion_attentions=False, + **kwargs, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if "labels" in kwargs or "start_positions" in kwargs and "end_positions" in kwargs: + use_cache = False + + outputs, context = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + output_adapter_gating_scores=output_adapter_gating_scores, + output_adapter_fusion_attentions=output_adapter_fusion_attentions, + adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, + ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context + + head_outputs = self.forward_head( + outputs, + head_name=head, + attention_mask=attention_mask, + return_dict=return_dict, + get_cls_from_eos_tokens=True, + # `get_cls_from_eos_tokens` requires passing eos mask + eos_mask=input_ids.eq(self.config.eos_token_id) if input_ids is not None else None, + **kwargs, + ) + + return head_outputs + + # Copied from M2M100ForConditionalGeneration + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past is used + if past is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "input_ids": None, # encoder_outputs is defined. 
input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + # Copied from M2M100ForConditionalGeneration + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + # Copied from M2M100ForConditionalGeneration + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past diff --git a/src/adapters/models/m2m_100/mixin_m2m_100.py b/src/adapters/models/m2m_100/mixin_m2m_100.py new file mode 100644 index 0000000000..92c8602612 --- /dev/null +++ b/src/adapters/models/m2m_100/mixin_m2m_100.py @@ -0,0 +1,116 @@ +from typing import Iterable, Optional, Tuple + +import torch +import torch.nn as nn + +from ...composition import adjust_tensors_for_parallel +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import LoRALinear +from ...methods.prefix_tuning import PrefixTuningLayer +from ...model_mixin import ( + EmbeddingAdaptersMixin, + EmbeddingAdaptersWrapperMixin, + InvertibleAdaptersMixin, + InvertibleAdaptersWrapperMixin, + ModelBaseAdaptersMixin, +) +from ...utils import patch_forward + + +class M2M100AttentionAdaptersMixin: + """Adds adapters to the M2M100Attention module.""" + + def init_adapters(self, model_config, adapters_config): + # Wrap layers for LoRA + self.k_proj = LoRALinear.wrap(self.k_proj, "selfattn", model_config, adapters_config, attn_key="k") + self.v_proj = LoRALinear.wrap(self.v_proj, "selfattn", model_config, adapters_config, attn_key="v") + self.q_proj = LoRALinear.wrap(self.q_proj, "selfattn", model_config, adapters_config, attn_key="q") + + self.prefix_tuning = PrefixTuningLayer( + self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config + ) + patch_forward(self) + + +class M2M100EncoderLayerAdaptersMixin: + """Adds adapters to the M2M100EncoderLayer module.""" + + def init_adapters(self, model_config, adapters_config): + self.adapters_config = adapters_config + # Wrap layers for LoRA + self.fc1 = LoRALinear.wrap(self.fc1, "intermediate", model_config, adapters_config) + self.fc2 = LoRALinear.wrap(self.fc2, "output", model_config, adapters_config) + + # Set attention layer location key for prefix tuning + self.self_attn.location_key = "encoder" + self.attention_adapters = BottleneckLayer("mh_adapter") + self.output_adapters = BottleneckLayer("output_adapter") + + patch_forward(self) + + +class M2M100DecoderLayerAdaptersMixin(M2M100EncoderLayerAdaptersMixin): + """Adds adapters to the M2M100DecoderLayer module.""" + + def init_adapters(self, model_config, adapters_config): + super().init_adapters(model_config, adapters_config) + # Set attention layer location key for prefix tuning + self.self_attn.location_key = "self" + self.encoder_attn.location_key = "cross" + self.cross_attention_adapters = BottleneckLayer("cross_adapter") + + +class M2M100EncoderAdaptersMixin(InvertibleAdaptersMixin): + """Adds adapters to the M2M100Encoder module.""" + + pass + + +class M2M100DecoderAdaptersMixin: + 
"""Adds adapters to the M2M100Decoder module.""" + + def init_adapters(self, model_config, adapters_config): + patch_forward(self) + + def forward( + self, input_ids: torch.LongTensor = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, **kwargs + ): + (input_ids,) = adjust_tensors_for_parallel(encoder_hidden_states, input_ids) + return super().forward(input_ids=input_ids, encoder_hidden_states=encoder_hidden_states, **kwargs) + + +class M2M100ModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMixin, ModelBaseAdaptersMixin): + """Adds adapters to the M2M100Model class.""" + + invertible_adapters_base_name = "encoder" + support_prompt_tuning = False + + def init_adapters(self, model_config, adapters_config): + super().init_adapters(model_config, adapters_config) + self.encoder.layer_norm.register_forward_hook(self.post_embedding_forward) + + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + if hasattr(self, "encoder"): + for i, layer in enumerate(self.encoder.layers): + yield i, layer + for i, layer in enumerate(self.decoder.layers, start=len(self.encoder.layers)): + yield i, layer + else: + for i, layer in enumerate(self.decoder.layers): + yield i, layer + + def post_embedding_forward(self, module, args, embedding_output): + embedding_output = self.invertible_adapters_forward(embedding_output) + # Prompt tuning not yet supported + return embedding_output + + +class M2M100DecoderWrapperAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelBaseAdaptersMixin): + """Adds adapters to the M2M100DecoderWrapper module.""" + + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + for i, layer in enumerate(self.decoder.layers): + yield i, layer + + def get_input_embeddings(self): + return self.decoder.get_input_embeddings() diff --git a/src/adapters/models/m2m_100/modeling_m2m_100.py b/src/adapters/models/m2m_100/modeling_m2m_100.py new file mode 100644 index 0000000000..b1a8d04326 --- /dev/null +++ b/src/adapters/models/m2m_100/modeling_m2m_100.py @@ -0,0 +1,567 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.models.m2m_100.modeling_m2m_100 import ( + M2M100Attention, + M2M100DecoderLayer, + M2M100EncoderLayer, + M2M100FlashAttention2, + M2M100SdpaAttention, +) +from transformers.utils import logging + +from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel +from .mixin_m2m_100 import ( + M2M100AttentionAdaptersMixin, + M2M100DecoderLayerAdaptersMixin, + M2M100EncoderLayerAdaptersMixin, +) + + +logger = logging.get_logger(__name__) + + +class M2M100AttentionWithAdapters(M2M100AttentionAdaptersMixin, M2M100Attention): + """Multi-headed attention from 'Attention Is All You Need'.""" + + # Loosen constraint on batch_size to allow parallel adapter composition + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(tensor.shape[0], seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value states are provided, this layer is used as a cross-attention-layer for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + bsz = query_states.size(0) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class M2M100FlashAttention2WithAdapters(M2M100AttentionAdaptersMixin, M2M100FlashAttention2): + + # Loosen constraint on batch_size to allow parallel adapter composition + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(tensor.shape[0], seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # M2M100FlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError("M2M100FlashAttention2 attention does not support output_attentions") + + # if key_value_states are provided, this layer is used as cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + bsz = query_states.size(0) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=self.dropout if self.training else 0.0, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attention_use_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class M2M100SdpaAttentionWithAdapters(M2M100AttentionAdaptersMixin, M2M100SdpaAttention): + + # Loosen constraint on batch_size to allow parallel adapter composition + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(tensor.shape[0], seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch X Time X Channel""" + if output_attentions or layer_head_mask is not None: + logger.warning_once( + "M2M100Model is using M2M100SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" + ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. 
This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + key_value_states=key_value_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + # if the key_value states are provided, this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + bsz = query_states.size(0) + + query_states = self._shape(query_states, tgt_len, bsz) + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. 
+ is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, + # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + +class M2M100EncoderLayerWithAdapters(M2M100EncoderLayerAdaptersMixin, M2M100EncoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.FloatTensor, + layer_head_mask: torch.FloatTensor, + output_attentions: bool = False, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + adjust_tensors_for_parallel(hidden_states, attention_mask) + + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.attention_adapters(hidden_states, residual, None) + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.output_adapters(hidden_states, residual, None) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class M2M100DecoderLayerWithAdapters(M2M100DecoderLayerAdaptersMixin, M2M100DecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + adjust_tensors_for_parallel_(hidden_states, attention_mask, encoder_attention_mask) + + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.attention_adapters(hidden_states, residual, None) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.cross_attention_adapters(hidden_states, residual, None) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.output_adapters(hidden_states, residual, None) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs diff --git a/src/adapters/wrappers/configuration.py b/src/adapters/wrappers/configuration.py index 40dc421787..5d61e10af2 100644 --- a/src/adapters/wrappers/configuration.py +++ b/src/adapters/wrappers/configuration.py @@ -46,6 +46,12 @@ "hidden_dropout_prob": "dropout", "attention_probs_dropout_prob": "attention_dropout", }, + "m2m_100": { + "num_attention_heads": "encoder_attention_heads", + "hidden_size": "d_model", + "hidden_dropout_prob": "dropout", + "attention_probs_dropout_prob": "attention_dropout", + }, "plbart": { "num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model", diff --git a/tests/models/test_m2m_100.py b/tests/models/test_m2m_100.py new file mode 100644 index 0000000000..b81b52876b --- /dev/null +++ b/tests/models/test_m2m_100.py @@ -0,0 +1,12 @@ +# flake8: noqa: F403,F405 +from adapters import M2M100AdapterModel +from hf_transformers.tests.models.m2m_100.test_modeling_m2m_100 import * +from transformers.testing_utils 
import require_torch + +from .base import AdapterModelTesterMixin + + +@require_torch +class M2M100AdapterModelTest(AdapterModelTesterMixin, M2M100ModelTest): + all_model_classes = (M2M100AdapterModel,) + fx_compatible = False diff --git a/tests/test_m2m_100.py b/tests/test_m2m_100.py new file mode 100644 index 0000000000..a4d4934c31 --- /dev/null +++ b/tests/test_m2m_100.py @@ -0,0 +1,70 @@ +import unittest + +from tests.methods.test_config_union import ConfigUnionAdapterTest +from transformers import M2M100Config +from transformers.testing_utils import require_torch + +from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin +from .methods import ( + BottleneckAdapterTestMixin, + CompacterTestMixin, + IA3TestMixin, + LoRATestMixin, + PrefixTuningTestMixin, + ReftTestMixin, + UniPELTTestMixin, +) +from .test_adapter import AdapterTestBase, make_config +from .test_adapter_backward_compability import CompabilityTestMixin +from .test_adapter_conversion import ModelClassConversionTestMixin +from .test_adapter_embeddings import EmbeddingTestMixin +from .test_adapter_fusion_common import AdapterFusionModelTestMixin +from .test_adapter_heads import PredictionHeadModelTestMixin + + +class M2M100AdapterTestBase(AdapterTestBase): + config_class = M2M100Config + config = make_config( + M2M100Config, + vocab_size=256206, + d_model=16, + encoder_layers=2, + decoder_layers=2, + encoder_attention_heads=4, + decoder_attention_heads=4, + encoder_ffn_dim=4, + decoder_ffn_dim=4, + scale_embedding=False, + ) + tokenizer_name = "facebook/nllb-200-distilled-600M" + + +@require_torch +class M2M100AdapterTest( + BottleneckAdapterTestMixin, + CompacterTestMixin, + IA3TestMixin, + LoRATestMixin, + PrefixTuningTestMixin, + ReftTestMixin, + UniPELTTestMixin, + AdapterFusionModelTestMixin, + CompabilityTestMixin, + EmbeddingTestMixin, + PredictionHeadModelTestMixin, + ParallelAdapterInferenceTestMixin, + ParallelTrainingMixin, + ConfigUnionAdapterTest, + M2M100AdapterTestBase, + unittest.TestCase, +): + pass + + +@require_torch +class M2M100ClassConversionTest( + ModelClassConversionTestMixin, + M2M100AdapterTestBase, + unittest.TestCase, +): + pass
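
Note (not part of the patch): a minimal usage sketch of the model class this diff introduces. It assumes the flexible-head API used by the other seq2seq adapter models (e.g. `add_adapter`, `add_seq2seq_lm_head`, `train_adapter`) carries over unchanged to `M2M100AdapterModel`, and it reuses the NLLB checkpoint referenced in the test base above; adapter and head names are illustrative only.

    # Hedged sketch: assumes M2M100AdapterModel exposes the same flexible-head
    # API as the existing seq2seq adapter models (e.g. BartAdapterModel).
    from adapters import M2M100AdapterModel

    # Checkpoint taken from the tokenizer_name used in M2M100AdapterTestBase.
    model = M2M100AdapterModel.from_pretrained("facebook/nllb-200-distilled-600M")

    # Add a bottleneck adapter ("seq_bn" config) plus a seq2seq LM head,
    # then freeze the base model and activate the adapter for training.
    model.add_adapter("nllb_translation", config="seq_bn")
    model.add_seq2seq_lm_head("nllb_translation")
    model.train_adapter("nllb_translation")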