diff --git a/.gitignore b/.gitignore index 8b1478d39b..a6d5b296c5 100644 --- a/.gitignore +++ b/.gitignore @@ -176,4 +176,7 @@ scripts/git-strip-merge tests/backwards_compatibility/Ref_Out # backwards compatibility -model_outputs \ No newline at end of file +model_outputs + +# TODO: remove after mllama dev +explore_mllama \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index ccad3796df..3452fa3173 100644 --- a/setup.cfg +++ b/setup.cfg @@ -49,6 +49,9 @@ use_parentheses = True [flake8] ignore = E203, E501, E731, E741, W503, W605 max-line-length = 119 +per-file-ignores = + tests/test_methods/generator.py: F401, F403, F405 + tests/test_methods/test_*.py:F403,F405 [tool:pytest] doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS \ No newline at end of file diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py index b8424e0107..aa4241d0c3 100644 --- a/src/adapters/__init__.py +++ b/src/adapters/__init__.py @@ -120,6 +120,7 @@ "models.llama": ["LlamaAdapterModel"], "models.mbart": ["MBartAdapterModel"], "models.mistral": ["MistralAdapterModel"], + "models.mllama": ["MllamaAdapterModel"], "models.mt5": ["MT5AdapterModel"], "models.plbart": ["PLBartAdapterModel"], "models.roberta": ["RobertaAdapterModel"], @@ -236,6 +237,7 @@ from .models.llama import LlamaAdapterModel from .models.mbart import MBartAdapterModel from .models.mistral import MistralAdapterModel + from .models.mllama import MllamaAdapterModel from .models.mt5 import MT5AdapterModel from .models.plbart import PLBartAdapterModel from .models.roberta import RobertaAdapterModel diff --git a/src/adapters/head_utils.py b/src/adapters/head_utils.py index b0e2aeceb8..50c8ecea2a 100644 --- a/src/adapters/head_utils.py +++ b/src/adapters/head_utils.py @@ -788,6 +788,16 @@ }, "layers": ["proj_out"], }, + "MllamaForConditionalGeneration": { + "config": { + "head_type": "causal_lm", + "layers": 1, + "activation_function": None, + "layer_norm": False, + "bias": False, + }, + "layers": ["language_model.lm_head"], + }, } diff --git a/src/adapters/methods/prefix_tuning.py b/src/adapters/methods/prefix_tuning.py index 6ad4beda42..1c55f33e5a 100644 --- a/src/adapters/methods/prefix_tuning.py +++ b/src/adapters/methods/prefix_tuning.py @@ -176,6 +176,7 @@ def __init__( self.prefix_tunings = nn.ModuleDict() def indicate_prefix(self, prefix_name: str, location_key: str, **kwargs): + """Indicate that a Prefix Tuning module should be added to the indicated layer.""" if prefix_name not in self.prefix_counts: self.prefix_counts[prefix_name] = {location_key: {"count": 1, **kwargs}} elif location_key not in self.prefix_counts[prefix_name]: diff --git a/src/adapters/models/__init__.py b/src/adapters/models/__init__.py index 77f569835d..02f5d3751a 100644 --- a/src/adapters/models/__init__.py +++ b/src/adapters/models/__init__.py @@ -20,6 +20,13 @@ from .gptj.mixin_gptj import GPTJMLPAdaptersMixin, GPTJModelAdapterMixin from .llama.mixin_llama import LlamaForQuestionAnsweringAdapterMixin, LlamaModelAdapterMixin from .mistral.mixin_mistral import MistralModelAdapterMixin +from .mllama.mixin_mllama import ( + MllamaAdaptersMixin, + MllamaTextModelAdaptersMixin, + MllamaVisionEncoderAdaptersMixin, + MllamaVisionEncoderLayerAdaptersMixin, + MllamaVisionModelAdaptersMixin, +) from .plbart.mixin_plbart import ( PLBartDecoderAdaptersMixin, PLBartDecoderWrapperAdaptersMixin, @@ -109,4 +116,10 @@ "WhisperForAudioClassification": WhisperForAudioClassificationWithHeadsMixin, "LlamaForQuestionAnswering": LlamaForQuestionAnsweringAdapterMixin, 
"MistralModel": MistralModelAdapterMixin, + # Mulitmodal Llama + "MllamaModel": MllamaAdaptersMixin, + "MllamaVisionModel": MllamaVisionModelAdaptersMixin, + "MllamaTextModel": MllamaTextModelAdaptersMixin, + "MllamaVisionEncoder": MllamaVisionEncoderAdaptersMixin, + "MllamaVisionEncoderLayer": MllamaVisionEncoderLayerAdaptersMixin, } diff --git a/src/adapters/models/auto/adapter_model.py b/src/adapters/models/auto/adapter_model.py index 6711752054..9921b1f87e 100644 --- a/src/adapters/models/auto/adapter_model.py +++ b/src/adapters/models/auto/adapter_model.py @@ -24,6 +24,7 @@ ("llama", "LlamaAdapterModel"), ("mbart", "MBartAdapterModel"), ("mistral", "MistralAdapterModel"), + ("mllama", "MllamaAdapterModel"), ("mt5", "MT5AdapterModel"), ("plbart", "PLBartAdapterModel"), ("roberta", "RobertaAdapterModel"), diff --git a/src/adapters/models/mllama/__init__.py b/src/adapters/models/mllama/__init__.py new file mode 100644 index 0000000000..12ff0ddd99 --- /dev/null +++ b/src/adapters/models/mllama/__init__.py @@ -0,0 +1,39 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2020 The Adapter-Hub Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from transformers.utils import _LazyModule + + +_import_structure = { + "adapter_model": ["MllamaAdapterModel"], +} + + +if TYPE_CHECKING: + from .adapter_model import MllamaAdapterModel + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + ) diff --git a/src/adapters/models/mllama/adapter_model.py b/src/adapters/models/mllama/adapter_model.py new file mode 100644 index 0000000000..01056080c7 --- /dev/null +++ b/src/adapters/models/mllama/adapter_model.py @@ -0,0 +1,241 @@ +import logging +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn + +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.mllama.modeling_mllama import ( + MLLAMA_START_DOCSTRING, + MllamaPreTrainedModel, + MllamaTextModel, + MllamaVisionModel, + _prepare_cross_attention_mask, +) +from transformers.utils import add_start_docstrings + +from ...composition import adjust_tensors_for_parallel +from ...context import ForwardContext +from ...heads import ModelWithFlexibleHeadsAdaptersMixin +from ...model_mixin import EmbeddingAdaptersWrapperMixin +from ...wrappers import init + + +logger = logging.getLogger(__name__) + + +class MllamaModel(MllamaPreTrainedModel): + """ + Base MLLaMA model that provides the fundamental architecture combining vision and text. + This serves as the foundation for the specialized adapter model version. 
+ """ + + def __init__(self, config): + super().__init__(config) + self.vocab_size = config.text_config.vocab_size + self.hidden_size = config.text_config.hidden_size + self.max_num_tiles = config.vision_config.max_num_tiles + self.vision_output_dim = config.vision_config.vision_output_dim + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + + self.vision_model = MllamaVisionModel._from_config(config.vision_config) + self.language_model = MllamaTextModel._from_config(config.text_config) + self.multi_modal_projector = nn.Linear( + config.vision_config.vision_output_dim, + config.text_config.hidden_size, + bias=True, + ) + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def tie_weights(self): + return self.language_model.tie_weights() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + # Establish parameter values + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Check invalid argument combinations + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + if pixel_values is not None and cross_attention_states is not None: + raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously") + + # If image is provided compute cross_attention_states + if pixel_values is not None: + if aspect_ratio_ids is None: + raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided") + vision_outputs = self.vision_model( + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + cross_attention_states = vision_outputs[0] + cross_attention_states = 
self.multi_modal_projector(cross_attention_states).reshape( + -1, cross_attention_states.shape[-2], self.hidden_size + ) + + # Compute cross_attention_mask + if cross_attention_mask is not None: + cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask( + cross_attention_mask, + num_vision_tokens=self.vision_model.num_patches, + dtype=self.dtype, + ) + else: + full_text_row_masked_out_mask = None + if cross_attention_mask is not None and cache_position is not None: + cross_attention_mask = cross_attention_mask[:, :, cache_position] + full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] + + outputs = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + use_cache=use_cache, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + ) + + return outputs + + +@add_start_docstrings(MLLAMA_START_DOCSTRING) +class MllamaAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, MllamaPreTrainedModel): + + head_types = [ + "causal_lm", + ] + + def __init__(self, config): + super().__init__(config) + + self.model = MllamaModel(config) + init(self.model) + + self._init_head_modules() + self.post_init() + + @ForwardContext.wrap + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + head=None, + **kwargs, + ): + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + aspect_ratio_mask=aspect_ratio_mask, + aspect_ratio_ids=aspect_ratio_ids, + attention_mask=attention_mask, + cross_attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + batch_size = outputs[0].shape[0] + + if self.config.pad_token_id is None: + # TODO-AH: this may result in unexpected behavior for classification. Find a better way to do this? + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 + (sequence_lengths,) = adjust_tensors_for_parallel(outputs[0], sequence_lengths) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds`." + ) + + cls_logits = outputs[0][range(batch_size), sequence_lengths] + + outputs = self.forward_head( + outputs, + head_name=head, + cls_output=cls_logits, + attention_mask=attention_mask, + return_dict=return_dict, + **kwargs, + ) + + return outputs diff --git a/src/adapters/models/mllama/mixin_mllama.py b/src/adapters/models/mllama/mixin_mllama.py new file mode 100644 index 0000000000..30a5a7462d --- /dev/null +++ b/src/adapters/models/mllama/mixin_mllama.py @@ -0,0 +1,151 @@ +from typing import Iterable, Tuple + +import torch.nn as nn + +from ...composition import adjust_tensors_for_parallel_ +from ...methods.prefix_tuning import PrefixTuningPool +from ...methods.reft import ReftLayer, hook_fn +from ...model_mixin import ( + EmbeddingAdaptersMixin, + EmbeddingAdaptersWrapperMixin, + InvertibleAdaptersMixin, + InvertibleAdaptersWrapperMixin, + ModelBaseAdaptersMixin, +) +from ..clip.mixin_clip import CLIPAttentionAdaptersMixin, CLIPEncoderLayerAdaptersMixin +from ..llama.mixin_llama import LlamaDecoderLayerMixin + + +class MllamaVisionAttentionAdaptersMixin(CLIPAttentionAdaptersMixin): + """Mixin for adding adapter support to MLLaMA's vision attention module.""" + + +class MllamaTextCrossAttentionAdaptersMixin(CLIPAttentionAdaptersMixin): + """Mixin for adding adapter support to MLLaMA's cross-attention module.""" + + +class MllamaTextSelfAttentionAdaptersMixin(CLIPAttentionAdaptersMixin): + """Mixin for adding adapter support to MLLaMA's self-attention module.""" + + +class MllamaVisionEncoderLayerAdaptersMixin(CLIPEncoderLayerAdaptersMixin): + """Mixin for adding adapter support to MLLaMA's vision encoder layers.""" + + +class MllamaSelfAttentionDecoderLayerAdaptersMixin(LlamaDecoderLayerMixin): + """Mixin for adding adapter support to MLLaMA's self-attention decoder layers.""" + + +class MllamaCrossAttentionDecoderLayerAdaptersMixin(LlamaDecoderLayerMixin): + """Mixin for adding adapter support to MLLaMA's cross-attention decoder layers.""" + + +class MllamaVisionEncoderAdaptersMixin: + """Mixin for adding adapter support to MLLaMA's vision encoder module.""" + + def init_adapters(self, model_config, adapters_config): + # Set hook for parallel composition + for layer in self.layers: + self._set_layer_hook_for_parallel(layer) + + def _set_layer_hook_for_parallel(self, layer: nn.Module): + def hook(module, args, kwargs): + # Extract the hidden states from kwargs + if "hidden_state" in kwargs: + hidden_states = kwargs["hidden_state"] + attention_mask = kwargs.get("attention_mask") + if attention_mask is not None: + adjust_tensors_for_parallel_(hidden_states, attention_mask) + kwargs["hidden_state"] = hidden_states + kwargs["attention_mask"] = attention_mask + return args, kwargs + + layer.register_forward_pre_hook(hook, with_kwargs=True) + + +class MllamaVisionModelAdaptersMixin(ModelBaseAdaptersMixin): + """Adds adapters to the MllamaVisionModel class.""" + + support_prompt_tuning = False + + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + # Vision model has two encoders: + # 1. local transformer focusing on fine-grained, tile-level features + for i, layer in enumerate(self.transformer.layers): + yield i, layer + # 2.
global transformer operating on the output of the local transformer, integrating information across all tiles + for i, layer in enumerate(self.global_transformer.layers, start=len(self.transformer.layers)): + yield i, layer + + +class MllamaTextModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin): + """Adds adapters to the MllamaTextModel class.""" + + support_prompt_tuning = False + invertible_adapters_base_name = "language_model" + + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + for i, layer in enumerate(self.layers): + yield i, layer + + def init_adapters(self, model_config, adapters_config): + super().init_adapters(model_config, adapters_config) + + # Register hook for post embedding forward + self.embed_tokens.register_forward_hook(self.post_embedding_forward) + + def post_embedding_forward(self, module, args, embedding_output): + embedding_output = self.invertible_adapters_forward(embedding_output) + # Prompt tuning not yet supported + return embedding_output + + +class MllamaAdaptersMixin(EmbeddingAdaptersWrapperMixin, InvertibleAdaptersWrapperMixin, ModelBaseAdaptersMixin): + """ + Adds adapters to the MLLaMA model, handling both vision and text components. + """ + + invertible_adapters_base_name = "language_model" + support_prompt_tuning = False + + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + layer_idx = 0 + + for _, layer in self.vision_model.iter_layers(): + yield layer_idx, layer + layer_idx += 1 + + for _, layer in self.language_model.iter_layers(): + yield layer_idx, layer + layer_idx += 1 + + def _init_adapters_submodules(self, model_config, adapters_config): + """Initialize adapters in vision and language models separately.""" + # transformers naming inconsistency: the vision config names this parameter `attention_heads`, so mirror it as `num_attention_heads` for the adapter methods + model_config.vision_config.num_attention_heads = model_config.vision_config.attention_heads + + # Initialize vision model adapters + for module in self.vision_model.modules(): + if hasattr(module, "init_adapters"): + module.init_adapters(model_config.vision_config, adapters_config) + + # Initialize language model adapters + for module in self.language_model.modules(): + if hasattr(module, "init_adapters"): + module.init_adapters(model_config.text_config, adapters_config) + + def _default_init_adapter_methods(self, model_config, adapters_config): + # Patch for ReFT initialization + for _, layer in self.vision_model.iter_layers(): + if not hasattr(layer, "reft_layer"): + layer.reft_layer = ReftLayer("output", model_config.vision_config, adapters_config) + layer.register_forward_hook(hook_fn) + + for _, layer in self.language_model.iter_layers(): + if not hasattr(layer, "reft_layer"): + layer.reft_layer = ReftLayer("output", model_config.text_config, adapters_config) + layer.register_forward_hook(hook_fn) + + # Add prefix tuning + self.base_model.prefix_tuning = PrefixTuningPool(model_config, adapters_config) + diff --git a/src/adapters/models/mllama/modeling_mllama.py b/src/adapters/models/mllama/modeling_mllama.py new file mode 100644 index 0000000000..4f5b3feb0c --- /dev/null +++ b/src/adapters/models/mllama/modeling_mllama.py @@ -0,0 +1,616 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library.
It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from adapters.composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel +from transformers.cache_utils import Cache +from transformers.models.mllama.modeling_mllama import ( + MllamaCrossAttentionDecoderLayer, + MllamaSelfAttentionDecoderLayer, + MllamaTextCrossAttention, + MllamaTextCrossSdpaAttention, + MllamaTextSelfAttention, + MllamaTextSelfSdpaAttention, + MllamaVisionAttention, + MllamaVisionSdpaAttention, + apply_rotary_pos_emb, + repeat_kv, +) +from transformers.utils import logging + +from .mixin_mllama import ( + MllamaCrossAttentionDecoderLayerAdaptersMixin, + MllamaSelfAttentionDecoderLayerAdaptersMixin, + MllamaTextCrossAttentionAdaptersMixin, + MllamaTextSelfAttentionAdaptersMixin, + MllamaVisionAttentionAdaptersMixin, +) + + +logger = logging.get_logger(__name__) + + +class MllamaVisionAttentionWithAdapters(MllamaVisionAttentionAdaptersMixin, MllamaVisionAttention): + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = None, + ) -> torch.Tensor: + query = self.q_proj(hidden_state) + key = self.k_proj(hidden_state) + value = self.v_proj(hidden_state) + + batch_size, q_seq_len, _ = query.shape + _, kv_seq_len, _ = key.shape + + query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim).transpose(1, 2) + key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2) + value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2) + + # >>> START AH Changes <<< + query, key, value = match_attn_matrices_for_parallel(query, key, value) + (attention_mask,) = adjust_tensors_for_parallel(query, attention_mask) + # >>> END AH Changes <<< + + attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_seq_len, -1) + + output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return output, attn_weights + + +class MllamaVisionSdpaAttentionWithAdapters(MllamaVisionAttentionAdaptersMixin, MllamaVisionSdpaAttention): + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = 
None, + ) -> torch.Tensor: + if output_attentions: + logger.warning_once( + "MllamaModel is using MllamaVisionSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_state=hidden_state, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + query = self.q_proj(hidden_state) + key = self.k_proj(hidden_state) + value = self.v_proj(hidden_state) + + batch_size, q_seq_len, _ = query.shape + _, kv_seq_len, _ = key.shape + + query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim) + key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) + value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # >>> START AH Changes <<< + query, key, value = match_attn_matrices_for_parallel(query, key, value) + (attention_mask,) = adjust_tensors_for_parallel(query, attention_mask) + # >>> END AH Changes <<< + + attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_seq_len, -1) + + output = self.o_proj(attn_output) + + return output, None + + +class MllamaTextCrossAttentionWithAdapters(MllamaTextCrossAttentionAdaptersMixin, MllamaTextCrossAttention): + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = self.q_norm(query_states) + + if cross_attention_states is not None: + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + key_states = self.k_norm(key_states) + if past_key_value is not None: + # if we have a new image + new tokens, we only computed key_states on that new image + # we still update the cross key states, past_image, new_image. And use it! + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + elif cache_position[0] != 0: + key_states, value_states = ( + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], + ) + else: + raise ValueError( + "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" 
+ ) + + # >>> START AH Changes <<< + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + # >>> END AH Changes <<< + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MllamaTextCrossSdpaAttentionWithAdapters(MllamaTextCrossAttentionAdaptersMixin, MllamaTextCrossSdpaAttention): + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MllamaModel is using MllamaTextCrossSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + cross_attention_states=cross_attention_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = self.q_norm(query_states) + + if cross_attention_states is not None: + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # if we have a new image + new tokens, we only computed key_states on that new image + # we still update the cross key states, past_image, new_image. And use it! 
+ key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + elif cache_position[0] != 0: + key_states, value_states = ( + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], + ) + else: + raise ValueError( + "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" + ) + + # >>> START AH Changes <<< + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + # >>> END AH Changes <<< + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + key_states = self.k_norm(key_states) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if attention_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +class MllamaTextSelfAttentionWithAdapters(MllamaTextSelfAttentionAdaptersMixin, MllamaTextSelfAttention): + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + output_attentions: bool = False, + use_cache: bool = False, + past_key_value=None, + cache_position=None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # >>> START AH Changes <<< + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + # >>> END AH Changes <<< + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states 
= repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # >>> START AH Changes <<< + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + bsz = key_states.shape[0] + # >>> END AH Changes <<< + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MllamaTextSelfSdpaAttentionWithAdapters(MllamaTextSelfAttentionAdaptersMixin, MllamaTextSelfSdpaAttention): + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + output_attentions: bool = False, + use_cache: bool = False, + past_key_value=None, + cache_position=None, + **kwargs, + ): + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MllamaModel is using MllamaTextSelfSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # >>> START AH Changes <<< + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + # >>> END AH Changes <<< + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # >>> START AH Changes <<< + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + bsz = key_states.shape[0] + # >>> END AH Changes <<< + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + return attn_output, None, past_key_value + + +class MllamaSelfAttentionDecoderLayerWithAdapters( + MllamaSelfAttentionDecoderLayerAdaptersMixin, MllamaSelfAttentionDecoderLayer +): + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + # >>> START AH Changes <<< + hidden_states = self.attention_adapters(hidden_states, residual, None) + # >>> END AH Changes <<< + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + # >>> START AH Changes <<< + hidden_states = self.output_adapters(hidden_states, residual, None) + # >>> END AH Changes <<< + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class MllamaCrossAttentionDecoderLayerWithAdapters( + MllamaCrossAttentionDecoderLayer, MllamaCrossAttentionDecoderLayerAdaptersMixin +): + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: torch.Tensor, + cross_attention_mask: torch.Tensor, + attention_mask: torch.Tensor, + full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor]: + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, attn_weights, past_key_value = self.cross_attn( + hidden_states=hidden_states, + attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, + ) + # >>> START AH Changes <<< + hidden_states = self.attention_adapters(hidden_states, residual, None) + # >>> END AH Changes <<< + hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + # >>> START AH Changes <<< + hidden_states = self.output_adapters(hidden_states, residual, None) + # >>> END AH Changes <<< + + if full_text_row_masked_out_mask is not None: + hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore + hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + return outputs diff --git a/src/adapters/wrappers/configuration.py b/src/adapters/wrappers/configuration.py index 709bb54009..cc5e6c5195 100644 --- a/src/adapters/wrappers/configuration.py +++ b/src/adapters/wrappers/configuration.py @@ -68,6 +68,7 @@ "attention_probs_dropout_prob": "attention_dropout", }, "xlm_roberta": {}, + # TODO: add mllama } SUBMODEL_NAMES = {"clip": ["vision_config", "text_config"], "encoder-decoder": ["encoder", "decoder"]} diff --git 
a/tests/test_methods/test_on_mllama.py b/tests/test_methods/test_on_mllama.py new file mode 100644 index 0000000000..35ab474587 --- /dev/null +++ b/tests/test_methods/test_on_mllama.py @@ -0,0 +1,104 @@ +import os +from pathlib import Path + +import torch +from PIL import Image + +from transformers import MllamaImageProcessor +from transformers.models.mllama.configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig + +from .base import TextAdapterTestBase +from .generator import generate_method_tests + + +def from_text_vision_configs(config_class, text_config: MllamaTextConfig, vision_config: MllamaVisionConfig, **kwargs): + """ + Create a MllamaConfig instance from separate text and vision configs. + + This standalone function mimics the behavior of class methods like CLIPConfig.from_text_vision_configs, + but works without modifying the MllamaConfig class. + + Args: + config_class: The configuration class to instantiate (MllamaConfig) + text_config: The configuration for the text model + vision_config: The configuration for the vision model + **kwargs: Additional arguments to pass to the config constructor + + Returns: + An instance of config_class initialized with the text and vision configs + """ + return config_class(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) + + +class MllamaAdapterTestBase(TextAdapterTestBase): + + config = staticmethod( + lambda: from_text_vision_configs( + MllamaConfig, + MllamaTextConfig( + vocab_size=1000, # Minimal vocab size + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=2, + num_key_value_heads=2, + intermediate_size=32, + cross_attention_layers=[3], + bos_token_id=990, + eos_token_id=991, + pad_token_id=992, + max_position_embeddings=32, + rope_scaling={ + "rope_type": "default", + }, + ), + MllamaVisionConfig( + hidden_size=32, + num_hidden_layers=4, + num_global_layers=2, + num_attention_heads=2, + intermediate_size=32, + vision_output_dim=64, + intermediate_layers_indices=[3], + ), + ) + ) + tokenizer_name = "arnavgrg/mllama-11b-vision-lora" + input_shape = (1, 128) + + # Save runtime by computing the processed image once and reusing it for all tests + FIXTURES_DIR = Path(__file__).parent.parent / "fixtures" + + img_processor = MllamaImageProcessor(size={"height": 448, "width": 448}) + img = Image.open(os.path.join(FIXTURES_DIR, "tests_samples", "COCO", "000000039769.png")) + processed_img = img_processor(img, return_tensors="pt") + + def get_input_samples(self, vocab_size=1000, shape=(1, 128), config=None, dtype=torch.float, **kwargs): + shape = shape or self.input_shape + + # Text inputs + input_ids = self.build_rand_ids_tensor(shape, vocab_size) + + in_data = { + "input_ids": input_ids, + "pixel_values": self.processed_img["pixel_values"], + "aspect_ratio_ids": self.processed_img["aspect_ratio_ids"], + "aspect_ratio_mask": self.processed_img["aspect_ratio_mask"], + } + + if "num_labels" in kwargs: + in_data["labels"] = self.build_rand_ids_tensor(shape[:-1], vocab_size=kwargs["num_labels"]) + + return in_data + + +test_methods = generate_method_tests(MllamaAdapterTestBase) + +for test_class_name, test_class in test_methods.items(): + globals()[test_class_name] = test_class + + +""" resources: +https://github.com/AdrianBZG/llama-multimodal-vqa +https://huggingface.co/blog/llama32 +https://github.com/huggingface/huggingface-llama-recipes/blob/main/fine_tune/Llama-Vision%20FT.ipynb +""" diff --git a/tests/test_models/test_mllama_model.py b/tests/test_models/test_mllama_model.py 
new file mode 100644 index 0000000000..096fd4a7ae --- /dev/null +++ b/tests/test_models/test_mllama_model.py @@ -0,0 +1,12 @@ +# flake8: noqa: F403,F405 +from adapters import MllamaAdapterModel +from hf_transformers.tests.models.mllama.test_modeling_mllama import * +from transformers.testing_utils import require_torch + +from .base import AdapterModelTesterMixin + + +@require_torch +class MllamaAdapterModelTest(AdapterModelTesterMixin, MllamaForConditionalGenerationIntegrationTest): + all_model_classes = (MllamaAdapterModel,) + fx_compatible = False
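
Usage sketch (illustrative, not part of the diff): a minimal example of how the new MllamaAdapterModel, the causal LM head registered via head_types, and the mixin-wired adapter layers are intended to be used through the standard flex-head API of the adapters library. The checkpoint name is an assumption for illustration only; any Mllama checkpoint loadable by transformers should work, and "seq_bn" stands in for any supported adapter config.

from adapters import MllamaAdapterModel

# Hypothetical checkpoint name, used here for illustration only.
model = MllamaAdapterModel.from_pretrained("meta-llama/Llama-3.2-11B-Vision")

# Add a bottleneck adapter; iter_layers() in MllamaAdaptersMixin walks the vision
# encoder layers first and then the language model layers, so both towers are covered.
model.add_adapter("mllama_adapter", config="seq_bn")

# Register the causal LM prediction head enabled by head_types = ["causal_lm"].
model.add_causal_lm_head("mllama_adapter")

# Freeze the base model and activate the adapter + head for training.
model.train_adapter("mllama_adapter")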