diff --git a/README.md b/README.md
index b2a5112..5281b3c 100644
--- a/README.md
+++ b/README.md
@@ -173,6 +173,8 @@ We currently support a wide range of popular transformer models, including encod
 - [Roberta](https://huggingface.co/FacebookAI/xlm-roberta-base): FacebookAI's `xlm-roberta-base` and its variants
 
 ### Vision Models
+- 💡[**NEW**] [SmolVLM](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct): `SmolVLM` and its variants (256M, 500M, and 2B)
+- 💡[**NEW**] [Gemma3 4B](https://huggingface.co/google/gemma-3-4b-it): `gemma-3-4b-it`
 - [Cvt](https://huggingface.co/microsoft/cvt-13): Convolutional Vision Transformer
 - [Deit](https://huggingface.co/facebook/deit-base-distilled-patch16-224): Distilled Data-efficient Image Transformer (base-sized)
 - [Dit](https://huggingface.co/microsoft/dit-base-finetuned-rvlcdip): Document Image Transformer (base-sized)
diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py
index 93a0034..2515c4f 100644
--- a/optimum/executorch/modeling.py
+++ b/optimum/executorch/modeling.py
@@ -1248,8 +1248,19 @@ def forward(
                 self.encoder_name,
                 (multimodal_features,),
             )[0]
+            assert len(encoder_embeddings.shape) == 3
+
+            # Get a mask of where the placeholder image tokens are.
             encoder_token_mask = input_ids == self.encoder_token_id
+
+            # For encoders that break one image up into patches (e.g. SmolVLM), we can flatten the patch
+            # dimension like this. This does not work for multi-image inputs, which are not supported at the moment.
+            if encoder_embeddings.shape[0] != 1:
+                encoder_embeddings = encoder_embeddings.reshape(-1, encoder_embeddings.shape[-1])
+
+            # Merge the encoder embeddings into the rest of the token embeddings.
             token_embeddings[encoder_token_mask] = encoder_embeddings
+
             output = self.model.run_method(
                 "text_decoder",
                 (
@@ -1346,8 +1357,8 @@ def text_generation(
 
         # Sanity check
         if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.bos_token_id:
-            raise ValueError(
-                f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
+            logging.warning(
+                f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} is not the same as the model's bos_token_id={self.bos_token_id}."
             )
         if isinstance(self.tokenizer, PreTrainedTokenizer) and not verify_eos_tokens_in_pretrained_tokenizer(
             self.eos_token_id, self.tokenizer
diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py
index 77c4ef0..74a084b 100644
--- a/optimum/exporters/executorch/integrations.py
+++ b/optimum/exporters/executorch/integrations.py
@@ -22,6 +22,7 @@
 from transformers import (
     AutoConfig,
     AutoProcessor,
+    AutoTokenizer,
     PreTrainedModel,
     StaticCache,
     T5ForConditionalGeneration,
@@ -34,7 +35,7 @@
 
 from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache
 
-from .utils import apply_chat_template_with_fallback, save_config_to_constant_methods
+from .utils import apply_chat_template_with_fallback, process_conversation_inputs, save_config_to_constant_methods
 
 
 class VisionExportableModule(torch.nn.Module):
@@ -46,6 +47,7 @@ def prepare_export_inputs(self):
         # 1. Get export inputs
         model_id = self.model.config.name_or_path
         processor = AutoProcessor.from_pretrained(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
         sample_conversation_with_image = [
             {
                 "role": "user",
                 "content": [
@@ -54,12 +56,10 @@ def prepare_export_inputs(self):
                 ],
             },
         ]
-        processed_inputs = processor.apply_chat_template(
+        processed_inputs = process_conversation_inputs(
+            processor,
+            tokenizer,
             sample_conversation_with_image,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
         )
         if "pixel_values" not in processed_inputs:
             raise ValueError(
@@ -76,7 +76,9 @@ def forward(
         self,
         input_features: torch.FloatTensor,
     ):
-        image_embeds = self.model.get_image_features(input_features)
+        # Pass pixel_attention_mask=None to avoid data-dependent operations during export.
+        # The model will create a mask full of 1s internally if None is passed.
+        image_embeds = self.model.get_image_features(input_features, pixel_attention_mask=None)
         if isinstance(image_embeds, list):
             image_embeds = torch.stack(image_embeds)
         return image_embeds
@@ -386,7 +388,7 @@ def export(
                     "input_features": input_features,
                 },
                 dynamic_shapes=dynamic_shapes,
-                strict=True,
+                strict=False,
             )
             exported_programs[f"{self.modality}_encoder"] = encoder_exported_program
 
diff --git a/optimum/exporters/executorch/tasks/multimodal_text_to_text.py b/optimum/exporters/executorch/tasks/multimodal_text_to_text.py
index 88b40a3..c2182fe 100644
--- a/optimum/exporters/executorch/tasks/multimodal_text_to_text.py
+++ b/optimum/exporters/executorch/tasks/multimodal_text_to_text.py
@@ -151,6 +151,13 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
     if not (hasattr(config, "text_config")):
         raise ValueError(f"The model {model_name_or_path} does not have a `text_config`.")
 
+    config.use_export_friendly = True
+    config.text_config.use_export_friendly = True
+    if hasattr(config, "audio_config"):
+        config.audio_config.use_export_friendly = True
+    if hasattr(config, "vision_config"):
+        config.vision_config.use_export_friendly = True
+
     if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
         # NOTE: Avoid hitting the data-dependent control flow in _longrope_frequency_update.
         config.rope_scaling["type"] = "default"
@@ -180,8 +187,19 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
             "device": device,
         },
     )
-    decoder_name, audio_encoder_name, vision_encoder_name = _validate_multimodal_components(eager_model)
-    encoder_name = audio_encoder_name if audio_encoder_name else vision_encoder_name
+
+    # Most *ForConditionalGeneration models expose the text_model and the encoder model as top-level
+    # attributes. However, some instead wrap the base (non-conditional-generation) model as `self.model`,
+    # and it is that `self.model` which holds the text_model and encoder model attributes.
+    if hasattr(eager_model, "model"):
+        decoder_name, audio_encoder_name, vision_encoder_name = _validate_multimodal_components(eager_model.model)
+        # Set these as top level attributes.
+        setattr(eager_model, decoder_name, getattr(eager_model.model, decoder_name))
+        encoder_name = audio_encoder_name if audio_encoder_name else vision_encoder_name
+        setattr(eager_model, encoder_name, getattr(eager_model.model, encoder_name))
+    else:
+        decoder_name, audio_encoder_name, vision_encoder_name = _validate_multimodal_components(eager_model)
+        encoder_name = audio_encoder_name if audio_encoder_name else vision_encoder_name
 
     # Need to do this since apparently when nested modules (e.g. model.language_model) access the .property
     # config, it always comes from the generation_config.json file, not the `generation_config` override
diff --git a/optimum/exporters/executorch/utils.py b/optimum/exporters/executorch/utils.py
index c720656..19f4916 100644
--- a/optimum/exporters/executorch/utils.py
+++ b/optimum/exporters/executorch/utils.py
@@ -139,16 +139,12 @@ def process_conversation_inputs(
     input_conversation: List[Dict[str, Any]],
 ):
     """
-    Process input conversation for multimodal models.
-
-    This function handles the preprocessing of conversation inputs, with special handling for
-    GraniteSpeechProcessor which requires extracting and processing audio content from conversations
-    prior to feeding into the processor.
+    Process an input conversation into tensor inputs for multimodal models.
 
     Args:
         processor: The processor to use for input processing
         tokenizer: The tokenizer to use for text processing
-        input_conversation: List of conversation messages, may contain audio content
+        input_conversation: List of conversation messages
 
     Returns:
         Processed inputs ready for model consumption
@@ -190,6 +186,32 @@ def process_conversation_inputs(
         # Generate text prompt and process with audio
         prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
         inputs = processor(prompt, wav, return_tensors="pt")
+    elif isinstance(processor, transformers.SmolVLMProcessor):
+        from transformers.image_utils import load_image
+
+        conversation = copy.deepcopy(input_conversation)
+        images = []
+
+        # Extract image URLs from the conversation.
+        for message in conversation:
+            if isinstance(message.get("content"), list):
+                # Filter out image entries and collect their URLs.
+                image_urls = [item["url"] for item in message["content"] if item.get("type") == "image"]
+                images.extend([load_image(url) for url in image_urls])
+
+                # Remove image entries from the message content.
+                message["content"] = [item for item in message["content"] if item.get("type") != "image"]
+
+        # Apply the chat template to get the text prompt (tokenize=False so the processor receives a string).
+        prompt = apply_chat_template_with_fallback(
+            processor,
+            conversation,
+            add_generation_prompt=True,
+            tokenize=False,
+        )
+
+        # Process text and images together.
+        inputs = processor(text=prompt, images=images, return_tensors="pt")
     else:
         # Standard processing for other processors
         inputs = apply_chat_template_with_fallback(
diff --git a/tests/models/test_modeling_smolvlm.py b/tests/models/test_modeling_smolvlm.py
new file mode 100644
index 0000000..36b2f02
--- /dev/null
+++ b/tests/models/test_modeling_smolvlm.py
@@ -0,0 +1,93 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import logging
+import os
+import sys
+import unittest
+
+import pytest
+from transformers import AutoProcessor, AutoTokenizer
+from transformers.testing_utils import slow
+
+from optimum.executorch import ExecuTorchModelForMultiModalToText
+
+from ..utils import check_multimodal_output_quality
+
+
+is_linux_ci = sys.platform.startswith("linux") and os.environ.get("GITHUB_ACTIONS") == "true"
+
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+class ExecuTorchModelIntegrationTest(unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skipif(is_linux_ci, reason="OOM")
+    def test_smolvlm_with_custom_sdpa_kv_cache_8da4w_8we(self):
+        model_id = "HuggingFaceTB/SmolVLM-Instruct"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        processor = AutoProcessor.from_pretrained(model_id)
+        conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
+                    },
+                    {"type": "text", "text": "Can you describe this image?"},
+                ],
+            },
+        ]
+
+        model = ExecuTorchModelForMultiModalToText.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            task="multimodal-text-to-text",
+            use_custom_sdpa=True,
+            use_custom_kv_cache=True,
+            qlinear="8da4w",
+            qlinear_group_size=32,
+            # Can't quantize the encoder at the moment: its hidden dim of 4304 doesn't fit ExecuTorch's
+            # XNNPACK 32-group-size quantized kernels. See https://github.com/pytorch/executorch/issues/14221.
+            qembedding_config="8w",
+        )
+
+        # Generate
+        generated_text = model.text_generation(
+            processor=processor,
+            tokenizer=tokenizer,
+            input_conversation=conversation,
+            max_seq_len=64,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+        generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
+
+        del model
+        del tokenizer
+        gc.collect()
+
+        # The generated text should describe the Statue of Liberty image.
+        self.assertTrue("Statue" in generated_text)
+        self.assertTrue("Liberty" in generated_text)
+        self.assertTrue(
+            check_multimodal_output_quality(model_id, generated_tokens, conversation, max_perplexity_threshold=5)
+        )