2 changes: 2 additions & 0 deletions README.md
@@ -173,6 +173,8 @@ We currently support a wide range of popular transformer models, including encod
- [Roberta](https://huggingface.co/FacebookAI/xlm-roberta-base): FacebookAI's `xlm-roberta-base` and its variants

### Vision Models
- 💡[**NEW**] [SmolVLM](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct): `SmolVLM` and its variants (256M, 500M, and 2B)
- 💡[**NEW**] [Gemma3 4B](https://huggingface.co/google/gemma-3-4b-it): `gemma-3-4b-it`
- [Cvt](https://huggingface.co/microsoft/cvt-13): Convolutional Vision Transformer
- [Deit](https://huggingface.co/facebook/deit-base-distilled-patch16-224): Distilled Data-efficient Image Transformer (base-sized)
- [Dit](https://huggingface.co/microsoft/dit-base-finetuned-rvlcdip): Document Image Transformer (base-sized)
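For orientation, here is a minimal usage sketch for one of the newly listed vision models. The class and method names are taken from the test added later in this PR; the image URL is a placeholder and the export options are illustrative, not a definitive recipe.

```python
from transformers import AutoProcessor, AutoTokenizer

from optimum.executorch import ExecuTorchModelForMultiModalToText

model_id = "HuggingFaceTB/SmolVLM-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Export to ExecuTorch and load the resulting program (XNNPACK recipe, as in the test below).
model = ExecuTorchModelForMultiModalToText.from_pretrained(
    model_id,
    recipe="xnnpack",
    task="multimodal-text-to-text",
)

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://example.com/statue.jpg"},  # placeholder URL
            {"type": "text", "text": "Can you describe this image?"},
        ],
    },
]

generated_text = model.text_generation(
    processor=processor,
    tokenizer=tokenizer,
    input_conversation=conversation,
    max_seq_len=64,
)
print(generated_text)
```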
15 changes: 13 additions & 2 deletions optimum/executorch/modeling.py
@@ -1248,8 +1248,19 @@ def forward(
self.encoder_name,
(multimodal_features,),
)[0]
assert len(encoder_embeddings.shape) == 3

# Get mask for where placeholder image tokens are.
encoder_token_mask = input_ids == self.encoder_token_id

        # For encoders that split a single image into patches (e.g. SmolVLM), we can flatten the patch
        # dimension like this. Multi-image inputs cannot be flattened this way and are not supported at the moment.
if encoder_embeddings.shape[0] != 1:
encoder_embeddings = encoder_embeddings.reshape(-1, encoder_embeddings.shape[-1])

# Merge in the encoder embeddings into the rest of the embeddings.
token_embeddings[encoder_token_mask] = encoder_embeddings

output = self.model.run_method(
"text_decoder",
(
@@ -1346,8 +1357,8 @@ def text_generation(

# Sanity check
if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.bos_token_id:
raise ValueError(
f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
logging.warning(
f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} is not the same as the model's bos_token_id={self.bos_token_id}."
)
if isinstance(self.tokenizer, PreTrainedTokenizer) and not verify_eos_tokens_in_pretrained_tokenizer(
self.eos_token_id, self.tokenizer
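To make the embedding-merge step in this hunk concrete, here is a minimal standalone sketch with toy shapes and a made-up placeholder token id (99), showing how the boolean mask scatters the flattened encoder embeddings into the placeholder positions of the token embeddings:

```python
import torch

# Toy setup: batch of 1, 6 tokens, hidden size 4; token id 99 plays the role of the image placeholder.
input_ids = torch.tensor([[1, 99, 99, 5, 6, 7]])
token_embeddings = torch.zeros(1, 6, 4)   # (batch, seq_len, hidden) from the token embedding table
encoder_embeddings = torch.ones(2, 4)     # one row per placeholder token, already flattened to (num_tokens, hidden)

encoder_token_mask = input_ids == 99      # (1, 6) boolean mask of placeholder positions
token_embeddings[encoder_token_mask] = encoder_embeddings  # masked rows are overwritten with image embeddings

assert torch.equal(token_embeddings[0, 1], torch.ones(4))   # placeholder position now holds an image embedding
assert torch.equal(token_embeddings[0, 0], torch.zeros(4))  # text positions are untouched
```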
18 changes: 10 additions & 8 deletions optimum/exporters/executorch/integrations.py
@@ -22,6 +22,7 @@
from transformers import (
AutoConfig,
AutoProcessor,
AutoTokenizer,
PreTrainedModel,
StaticCache,
T5ForConditionalGeneration,
@@ -34,7 +35,7 @@

from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache

from .utils import apply_chat_template_with_fallback, save_config_to_constant_methods
from .utils import apply_chat_template_with_fallback, process_conversation_inputs, save_config_to_constant_methods


class VisionExportableModule(torch.nn.Module):
@@ -46,6 +47,7 @@ def prepare_export_inputs(self):
# 1. Get export inputs
model_id = self.model.config.name_or_path
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
sample_conversation_with_image = [
{
"role": "user",
@@ -54,12 +56,10 @@
],
},
]
processed_inputs = processor.apply_chat_template(
processed_inputs = process_conversation_inputs(
processor,
tokenizer,
sample_conversation_with_image,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
if "pixel_values" not in processed_inputs:
raise ValueError(
@@ -76,7 +76,9 @@ def forward(
self,
input_features: torch.FloatTensor,
):
image_embeds = self.model.get_image_features(input_features)
# Pass pixel_attention_mask=None to avoid data-dependent operations during export.
# The model will create a mask full of 1s internally if None is passed.
image_embeds = self.model.get_image_features(input_features, pixel_attention_mask=None)
if isinstance(image_embeds, list):
image_embeds = torch.stack(image_embeds)
return image_embeds
@@ -386,7 +388,7 @@ def export(
"input_features": input_features,
},
dynamic_shapes=dynamic_shapes,
strict=True,
strict=False,
)
exported_programs[f"{self.modality}_encoder"] = encoder_exported_program

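A rough sketch of how a vision encoder wrapper like the one above can be exported. The `strict=False` flag and the `pixel_attention_mask=None` rationale come from this diff; the dynamic-shape spec and names below are illustrative assumptions, not the exact exporter code.

```python
import torch
from torch.export import Dim, export


def export_vision_encoder(encoder_module: torch.nn.Module, pixel_values: torch.Tensor):
    # `encoder_module` is assumed to be a wrapper whose forward takes `input_features`
    # and returns image embeddings; `pixel_values` comes from the processed sample conversation.
    return export(
        encoder_module,
        args=(),
        kwargs={"input_features": pixel_values},
        # Illustrative spec: let the number of image patches vary between inputs.
        dynamic_shapes={"input_features": {0: Dim("num_patches", min=1, max=64)}},
        strict=False,  # non-strict tracing, as switched to in this diff
    )
```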
22 changes: 20 additions & 2 deletions optimum/exporters/executorch/tasks/multimodal_text_to_text.py
@@ -151,6 +151,13 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
if not (hasattr(config, "text_config")):
raise ValueError(f"The model {model_name_or_path} does not have a `text_config`.")

config.use_export_friendly = True
config.text_config.use_export_friendly = True
if hasattr(config, "audio_config"):
config.audio_config.use_export_friendly = True
if hasattr(config, "vision_config"):
config.vision_config.use_export_friendly = True

if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
# NOTE: Avoid hitting the data-dependent control flow in _longrope_frequency_update.
config.rope_scaling["type"] = "default"
@@ -180,8 +187,19 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
"device": device,
},
)
decoder_name, audio_encoder_name, vision_encoder_name = _validate_multimodal_components(eager_model)
encoder_name = audio_encoder_name if audio_encoder_name else vision_encoder_name

    # Most `<Model>ForConditionalGeneration` classes expose the text model and encoder models as top-level
    # attributes. However, some wrap the base model as `self.model = <Model>` (the base version, not the one
    # for conditional generation), and it is this `self.model` that holds the text model and encoder attributes.
if hasattr(eager_model, "model"):
decoder_name, audio_encoder_name, vision_encoder_name = _validate_multimodal_components(eager_model.model)
# Set these as top level attributes.
setattr(eager_model, decoder_name, getattr(eager_model.model, decoder_name))
encoder_name = audio_encoder_name if audio_encoder_name else vision_encoder_name
setattr(eager_model, encoder_name, getattr(eager_model.model, encoder_name))
else:
decoder_name, audio_encoder_name, vision_encoder_name = _validate_multimodal_components(eager_model)
encoder_name = audio_encoder_name if audio_encoder_name else vision_encoder_name

# Need to do this since apparently when nested modules (e.g. model.language_model) access the .property
# config, it always comes from the generation_config.json file, not the `generation_config` override
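A toy illustration of the attribute-hoisting above; the class and attribute names are made up for the example, but the `hasattr`/`setattr`/`getattr` pattern matches the diff:

```python
class TextDecoder: ...
class VisionEncoder: ...


class BaseMultimodalModel:
    def __init__(self):
        self.language_model = TextDecoder()
        self.vision_tower = VisionEncoder()


class MultimodalForConditionalGeneration:
    def __init__(self):
        # Submodules live one level down rather than as top-level attributes.
        self.model = BaseMultimodalModel()


eager_model = MultimodalForConditionalGeneration()
if hasattr(eager_model, "model"):
    # Hoist the nested submodules so export code can look them up by name on the top-level model.
    for name in ("language_model", "vision_tower"):
        setattr(eager_model, name, getattr(eager_model.model, name))

assert isinstance(eager_model.language_model, TextDecoder)
```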
36 changes: 30 additions & 6 deletions optimum/exporters/executorch/utils.py
@@ -139,16 +139,12 @@ def process_conversation_inputs(
input_conversation: List[Dict[str, Any]],
):
"""
Process input conversation for multimodal models.

This function handles the preprocessing of conversation inputs, with special handling for
GraniteSpeechProcessor which requires extracting and processing audio content from conversations
prior to feeding into the processor.
Process an input conversation into tensor inputs for multimodal models.

Args:
processor: The processor to use for input processing
tokenizer: The tokenizer to use for text processing
input_conversation: List of conversation messages, may contain audio content
input_conversation: List of conversation messages

Returns:
Processed inputs ready for model consumption
@@ -190,6 +186,34 @@
# Generate text prompt and process with audio
prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
inputs = processor(prompt, wav, return_tensors="pt")
elif isinstance(processor, transformers.SmolVLMProcessor):
from transformers.image_utils import load_image

conversation = copy.deepcopy(input_conversation)
images = []

# Extract image URLs from conversation
for message in conversation:
if isinstance(message.get("content"), list):
# Filter out image entries and collect URLs
image_urls = [item["url"] for item in message["content"] if item.get("type") == "image"]
images.extend([load_image(url) for url in image_urls])

# Remove image entries from content
message["content"] = [item for item in message["content"] if item.get("type") != "image"]

# Apply chat template to get text prompt
prompt = apply_chat_template_with_fallback(
processor,
conversation,
add_generation_prompt=True,
            tokenize=False,  # return the prompt string; tokenization happens in the processor call below
)

# Process with text and images
inputs = processor(text=prompt, images=images, return_tensors="pt")
else:
# Standard processing for other processors
inputs = apply_chat_template_with_fallback(
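A hedged usage sketch of the new SmolVLM branch in `process_conversation_inputs`; the image URL is a placeholder, and the output keys (`input_ids`, `pixel_values`) are assumed from the standard SmolVLM processor output:

```python
from transformers import AutoProcessor, AutoTokenizer

from optimum.exporters.executorch.utils import process_conversation_inputs

model_id = "HuggingFaceTB/SmolVLM-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://example.com/cat.jpg"},  # placeholder URL
            {"type": "text", "text": "Describe this image."},
        ],
    },
]

inputs = process_conversation_inputs(processor, tokenizer, conversation)
# Assumed keys: "input_ids" holds the prompt with image placeholder tokens,
# "pixel_values" holds the preprocessed image patches for the vision encoder.
print(inputs["input_ids"].shape, inputs["pixel_values"].shape)
```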
98 changes: 98 additions & 0 deletions tests/models/test_modeling_smolvlm.py
@@ -0,0 +1,98 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import logging
import os
import sys
import unittest

import pytest
from transformers import AutoProcessor, AutoTokenizer
from transformers.testing_utils import slow

from optimum.executorch import ExecuTorchModelForMultiModalToText

from ..utils import check_multimodal_output_quality


is_linux_ci = sys.platform.startswith("linux") and os.environ.get("GITHUB_ACTIONS") == "true"


os.environ["TOKENIZERS_PARALLELISM"] = "false"


class ExecuTorchModelIntegrationTest(unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@slow
@pytest.mark.run_slow
@pytest.mark.skipif(is_linux_ci, reason="OOM")
def test_smolvlm_with_custom_sdpa_kv_cache_8da4w_8we(self):
model_id = "HuggingFaceTB/SmolVLM-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
conversation = [
{
"role": "user",
"content": [
{
"type": "image",
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
},
{"type": "text", "text": "Can you describe this image?"},
],
},
]

model = ExecuTorchModelForMultiModalToText.from_pretrained(
            model_id,
recipe="xnnpack",
task="multimodal-text-to-text",
use_custom_sdpa=True,
use_custom_kv_cache=True,
qlinear="8da4w",
qlinear_group_size=32,
            # Can't quantize the encoder at the moment; its hidden dim of 4304 doesn't fit ExecuTorch's
            # XNNPACK 32-group-size quantized kernels. See https://github.com/pytorch/executorch/issues/14221.
qembedding_config="8w",
)

# Generate
generated_text = model.text_generation(
processor=processor,
tokenizer=tokenizer,
input_conversation=conversation,
max_seq_len=64,
)
logging.info(f"\nGenerated text:\n\t{generated_text}")
generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids

del model
del tokenizer
gc.collect()

        # Expected output is a short description of the image, mentioning the Statue of Liberty,
        # e.g. "The image shows the Statue of Liberty on Liberty Island in New York Bay."
self.assertTrue("Statue" in generated_text)
self.assertTrue("Liberty" in generated_text)
self.assertTrue(
check_multimodal_output_quality(model_id, generated_tokens, conversation, max_perplexity_threshold=5)
)