From 4b7520fa4f2f8bdf840589db169592aeec452460 Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Wed, 8 Oct 2025 11:43:22 -0700
Subject: [PATCH 1/5] Manual progress

---
 optimum/exporters/executorch/integrations.py | 19 ++++++----
 .../tasks/multimodal_text_to_text.py         | 15 ++++++--
 optimum/exporters/executorch/utils.py        | 36 +++++++++++++++----
 3 files changed, 56 insertions(+), 14 deletions(-)

diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py
index 77c4ef0..fe71f07 100644
--- a/optimum/exporters/executorch/integrations.py
+++ b/optimum/exporters/executorch/integrations.py
@@ -22,6 +22,7 @@
 from transformers import (
     AutoConfig,
     AutoProcessor,
+    AutoTokenizer,
     PreTrainedModel,
     StaticCache,
     T5ForConditionalGeneration,
@@ -34,7 +35,7 @@
 from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache
 
-from .utils import apply_chat_template_with_fallback, save_config_to_constant_methods
+from .utils import apply_chat_template_with_fallback, process_conversation_inputs, save_config_to_constant_methods
 
 
 class VisionExportableModule(torch.nn.Module):
@@ -46,6 +47,7 @@ def prepare_export_inputs(self):
         # 1. Get export inputs
         model_id = self.model.config.name_or_path
         processor = AutoProcessor.from_pretrained(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
         sample_conversation_with_image = [
             {
                 "role": "user",
                 "content": [
                 ],
             },
         ]
-        processed_inputs = processor.apply_chat_template(
+        processed_inputs = process_conversation_inputs(
+            processor,
+            tokenizer,
             sample_conversation_with_image,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
         )
+        # processed_inputs = processor.apply_chat_template(
+        #     sample_conversation_with_image,
+        #     add_generation_prompt=True,
+        #     tokenize=True,
+        #     return_dict=True,
+        #     return_tensors="pt",
+        # )
         if "pixel_values" not in processed_inputs:
             raise ValueError(
                 f"Unable to obtain sample audio encoder inputs for export for {model_id} - the processor did not return formatted inputs with the 'pixel_values' key: {processed_inputs}"
diff --git a/optimum/exporters/executorch/tasks/multimodal_text_to_text.py b/optimum/exporters/executorch/tasks/multimodal_text_to_text.py
index 88b40a3..9c765f7 100644
--- a/optimum/exporters/executorch/tasks/multimodal_text_to_text.py
+++ b/optimum/exporters/executorch/tasks/multimodal_text_to_text.py
@@ -180,8 +180,19 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
             "device": device,
         },
     )
-    decoder_name, audio_encoder_name, vision_encoder_name = _validate_multimodal_components(eager_model)
-    encoder_name = audio_encoder_name if audio_encoder_name else vision_encoder_name
+
+    # Most *ForConditionalGeneration models have the text_model and encoder models as direct attributes;
+    # however, some only expose `self.model` (the base model, not the conditional-generation wrapper),
+    # and that `self.model` is what holds the text_model and encoder model attributes.
+    if hasattr(eager_model, "model"):
+        decoder_name, audio_encoder_name, vision_encoder_name = _validate_multimodal_components(eager_model.model)
+        # Set these as top level attributes.
+        setattr(eager_model, decoder_name, getattr(eager_model.model, decoder_name))
+        encoder_name = audio_encoder_name if audio_encoder_name else vision_encoder_name
+        setattr(eager_model, encoder_name, getattr(eager_model.model, encoder_name))
+    else:
+        decoder_name, audio_encoder_name, vision_encoder_name = _validate_multimodal_components(eager_model)
+        encoder_name = audio_encoder_name if audio_encoder_name else vision_encoder_name
 
     # Need to do this since apparently when nested modules (e.g. model.language_model) access the .property
     # config, it always comes from the generation_config.json file, not the `generation_config` override
diff --git a/optimum/exporters/executorch/utils.py b/optimum/exporters/executorch/utils.py
index c720656..19f4916 100644
--- a/optimum/exporters/executorch/utils.py
+++ b/optimum/exporters/executorch/utils.py
@@ -139,16 +139,12 @@ def process_conversation_inputs(
     input_conversation: List[Dict[str, Any]],
 ):
     """
-    Process input conversation for multimodal models.
-
-    This function handles the preprocessing of conversation inputs, with special handling for
-    GraniteSpeechProcessor which requires extracting and processing audio content from conversations
-    prior to feeding into the processor.
+    Process an input conversation into tensor inputs for multimodal models.
 
     Args:
         processor: The processor to use for input processing
         tokenizer: The tokenizer to use for text processing
-        input_conversation: List of conversation messages, may contain audio content
+        input_conversation: List of conversation messages
 
     Returns:
         Processed inputs ready for model consumption
@@ -190,6 +186,34 @@ def process_conversation_inputs(
         # Generate text prompt and process with audio
         prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
         inputs = processor(prompt, wav, return_tensors="pt")
+    elif isinstance(processor, transformers.SmolVLMProcessor):
+        from transformers.image_utils import load_image
+
+        conversation = copy.deepcopy(input_conversation)
+        images = []
+
+        # Extract image URLs from conversation
+        for message in conversation:
+            if isinstance(message.get("content"), list):
+                # Filter out image entries and collect URLs
+                image_urls = [item["url"] for item in message["content"] if item.get("type") == "image"]
+                images.extend([load_image(url) for url in image_urls])
+
+                # Remove image entries from content
+                message["content"] = [item for item in message["content"] if item.get("type") != "image"]
+
+        # Apply chat template to get text prompt
+        prompt = apply_chat_template_with_fallback(
+            processor,
+            conversation,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        )
+
+        # Process with text and images
+        inputs = processor(text=prompt, images=images, return_tensors="pt")
     else:
         # Standard processing for other processors
         inputs = apply_chat_template_with_fallback(

From b35eccb6720d64eb3b9abadb4437e5c72e7c00cd Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Fri, 10 Oct 2025 05:54:06 -0700
Subject: [PATCH 2/5] Working, encoder not delegated to XNNPack

---
 optimum/executorch/modeling.py               |  4 +-
 optimum/exporters/executorch/integrations.py | 55 +++++++++++++++++---
 2 files changed, 49 insertions(+), 10 deletions(-)

diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py
index 93a0034..6cfd405 100644
--- a/optimum/executorch/modeling.py
+++ b/optimum/executorch/modeling.py
@@ -1346,8 +1346,8 @@ def text_generation(
         # Sanity check
         if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.bos_token_id:
-            raise ValueError(
-                f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
+            logging.warning(
+                f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} is not the same as the model's bos_token_id={self.bos_token_id}."
             )
         if isinstance(self.tokenizer, PreTrainedTokenizer) and not verify_eos_tokens_in_pretrained_tokenizer(
             self.eos_token_id, self.tokenizer
diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py
index fe71f07..e2dbae2 100644
--- a/optimum/exporters/executorch/integrations.py
+++ b/optimum/exporters/executorch/integrations.py
@@ -37,12 +37,56 @@
 from .utils import apply_chat_template_with_fallback, process_conversation_inputs, save_config_to_constant_methods
 
 
+def _patch_idefics3_vision_embeddings_for_export(vision_model):
+    """
+    Patch Idefics3VisionEmbeddings to make it export-friendly by removing data-dependent operations.
+    This assumes batch_size=1 and a full attention mask (all 1s).
+    """
+    import types
+
+    def export_friendly_forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
+        batch_size, _, max_im_h, max_im_w = pixel_values.shape
+
+        patch_embeds = self.patch_embedding(pixel_values)
+        embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+        nb_patches_h = max_im_h // self.patch_size
+        nb_patches_w = max_im_w // self.patch_size
+        N = self.num_patches_per_side
+
+        # For export, we assume full attention mask and compute position IDs statically.
+        # This avoids the data-dependent loop over batch dimension.
+        h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=torch.long)
+        w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=torch.long)
+
+        # This replaces bucketize(x, boundaries=[1/N, 2/N, ...], right=True) ≈ floor(x * N), which
+        # we don't have a kernel for at the moment.
+        bucket_coords_h = (h_indices * N) // nb_patches_h
+        bucket_coords_w = (w_indices * N) // nb_patches_w
+
+        bucket_coords_h = torch.clamp(bucket_coords_h, max=N - 1)
+        bucket_coords_w = torch.clamp(bucket_coords_w, max=N - 1)
+
+        pos_ids = (bucket_coords_h[:, None] * N + bucket_coords_w[None, :]).reshape(-1)
+        position_ids = pos_ids.unsqueeze(0).expand(batch_size, -1)
+        embeddings = embeddings + self.position_embedding(position_ids)
+        return embeddings
+
+    # Patch the forward method.
+    vision_model.embeddings.forward = types.MethodType(export_friendly_forward, vision_model.embeddings)
+
+
 class VisionExportableModule(torch.nn.Module):
     def __init__(self, model: torch.nn.Module):
         super().__init__()
         self.model = model
 
+        # Patch Idefics3 vision embeddings if needed
+        if hasattr(model, 'model') and hasattr(model.model, 'vision_model'):
+            model_type = getattr(model.config, 'model_type', '')
+            if 'idefics3' in model_type.lower():
+                _patch_idefics3_vision_embeddings_for_export(model.model.vision_model)
+
     def prepare_export_inputs(self):
         # 1. Get export inputs
         model_id = self.model.config.name_or_path
@@ -61,13 +105,6 @@ def prepare_export_inputs(self):
             tokenizer,
             sample_conversation_with_image,
         )
-        # processed_inputs = processor.apply_chat_template(
-        #     sample_conversation_with_image,
-        #     add_generation_prompt=True,
-        #     tokenize=True,
-        #     return_dict=True,
-        #     return_tensors="pt",
-        # )
         if "pixel_values" not in processed_inputs:
             raise ValueError(
                 f"Unable to obtain sample audio encoder inputs for export for {model_id} - the processor did not return formatted inputs with the 'pixel_values' key: {processed_inputs}"
@@ -83,7 +120,9 @@ def forward(
         self,
         input_features: torch.FloatTensor,
     ):
-        image_embeds = self.model.get_image_features(input_features)
+        # Pass pixel_attention_mask=None to avoid data-dependent operations during export.
+        # The model will create a mask full of 1s internally if None is passed.
+        image_embeds = self.model.get_image_features(input_features, pixel_attention_mask=None)
         if isinstance(image_embeds, list):
             image_embeds = torch.stack(image_embeds)
         return image_embeds

From f0cae63526b2eca6d83a6129f3a82384294ce4fd Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Wed, 15 Oct 2025 09:32:40 -0700
Subject: [PATCH 3/5] Working

---
 optimum/executorch/modeling.py               | 11 +++
 optimum/exporters/executorch/integrations.py | 21 +++--
 tests/models/test_modeling_smolvlm.py        | 98 ++++++++++++++++++++
 3 files changed, 122 insertions(+), 8 deletions(-)
 create mode 100644 tests/models/test_modeling_smolvlm.py

diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py
index 6cfd405..2515c4f 100644
--- a/optimum/executorch/modeling.py
+++ b/optimum/executorch/modeling.py
@@ -1248,8 +1248,19 @@ def forward(
             self.encoder_name,
             (multimodal_features,),
         )[0]
+        assert len(encoder_embeddings.shape) == 3
+
+        # Get mask for where placeholder image tokens are.
         encoder_token_mask = input_ids == self.encoder_token_id
+
+        # For encoders that break up one image into patches (e.g. SmolVLM), we are able to flatten them like this.
+        # However we are unable to do this for multi-image, which is not supported at the moment.
+        if encoder_embeddings.shape[0] != 1:
+            encoder_embeddings = encoder_embeddings.reshape(-1, encoder_embeddings.shape[-1])
+
+        # Merge in the encoder embeddings into the rest of the embeddings.
         token_embeddings[encoder_token_mask] = encoder_embeddings
+
         output = self.model.run_method(
             "text_decoder",
             (
diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py
index e2dbae2..c5aef6f 100644
--- a/optimum/exporters/executorch/integrations.py
+++ b/optimum/exporters/executorch/integrations.py
@@ -37,14 +37,19 @@
 from .utils import apply_chat_template_with_fallback, process_conversation_inputs, save_config_to_constant_methods
 
 
+
 def _patch_idefics3_vision_embeddings_for_export(vision_model):
     """
-    Patch Idefics3VisionEmbeddings to make it export-friendly by removing data-dependent operations.
-    This assumes batch_size=1 and a full attention mask (all 1s).
+    Patch Idefics3VisionEmbeddings to make it:
+    - Export-friendly by removing data-dependent operations (forces assumption of image input
+      batch_size = 1, and thus a full attention mask).
+    - Not use aten.bucketize, which has no available decompositions or kernels in ExecuTorch.
""" import types - def export_friendly_forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor: + def export_friendly_forward( + self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor + ) -> torch.Tensor: batch_size, _, max_im_h, max_im_w = pixel_values.shape patch_embeds = self.patch_embedding(pixel_values) @@ -81,10 +86,10 @@ def __init__(self, model: torch.nn.Module): super().__init__() self.model = model - # Patch Idefics3 vision embeddings if needed - if hasattr(model, 'model') and hasattr(model.model, 'vision_model'): - model_type = getattr(model.config, 'model_type', '') - if 'idefics3' in model_type.lower(): + # Patch Idefics3 vision embeddings to make it exportable. + if hasattr(model, "model") and hasattr(model.model, "vision_model"): + model_type = getattr(model.config, "model_type", "") + if "idefics3" in model_type.lower(): _patch_idefics3_vision_embeddings_for_export(model.model.vision_model) def prepare_export_inputs(self): @@ -432,7 +437,7 @@ def export( "input_features": input_features, }, dynamic_shapes=dynamic_shapes, - strict=True, + strict=False, ) exported_programs[f"{self.modality}_encoder"] = encoder_exported_program diff --git a/tests/models/test_modeling_smolvlm.py b/tests/models/test_modeling_smolvlm.py new file mode 100644 index 0000000..36b2f02 --- /dev/null +++ b/tests/models/test_modeling_smolvlm.py @@ -0,0 +1,98 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import gc
+import logging
+import os
+import sys
+import unittest
+
+import pytest
+from transformers import AutoProcessor, AutoTokenizer
+from transformers.testing_utils import slow
+
+from optimum.executorch import ExecuTorchModelForMultiModalToText
+
+from ..utils import check_multimodal_output_quality
+
+
+is_linux_ci = sys.platform.startswith("linux") and os.environ.get("GITHUB_ACTIONS") == "true"
+
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+class ExecuTorchModelIntegrationTest(unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skipif(is_linux_ci, reason="OOM")
+    def test_smolvlm_with_custom_sdpa_kv_cache_8da4w_8we(self):
+        model_id = "HuggingFaceTB/SmolVLM-Instruct"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        processor = AutoProcessor.from_pretrained(model_id)
+        conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
+                    },
+                    {"type": "text", "text": "Can you describe this image?"},
+                ],
+            },
+        ]
+
+        model = ExecuTorchModelForMultiModalToText.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            task="multimodal-text-to-text",
+            use_custom_sdpa=True,
+            use_custom_kv_cache=True,
+            qlinear="8da4w",
+            qlinear_group_size=32,
+            # Can't quantize the encoder at the moment, hidden dim of 4304 doesn't fit ExecuTorch's
+            # XNNPack 32-group size quantized kernels. See https://github.com/pytorch/executorch/issues/14221.
+            qembedding_config="8w",
+        )
+
+        # Generate
+        generated_text = model.text_generation(
+            processor=processor,
+            tokenizer=tokenizer,
+            input_conversation=conversation,
+            max_seq_len=64,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+        generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
+
+        del model
+        del tokenizer
+        gc.collect()
+
+        # The generated text should describe the Statue of Liberty image, e.g. mention
+        # the statue standing on Liberty Island in New York Bay.
+        self.assertTrue("Statue" in generated_text)
+        self.assertTrue("Liberty" in generated_text)
+        self.assertTrue(
+            check_multimodal_output_quality(model_id, generated_tokens, conversation, max_perplexity_threshold=5)
+        )

From ec0079eb8e9fbd0d514bd100d141a03c8cfd8478 Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Wed, 15 Oct 2025 09:43:58 -0700
Subject: [PATCH 4/5] Update README

---
 README.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/README.md b/README.md
index b2a5112..5281b3c 100644
--- a/README.md
+++ b/README.md
@@ -173,6 +173,8 @@ We currently support a wide range of popular transformer models, including encod
 - [Roberta](https://huggingface.co/FacebookAI/xlm-roberta-base): FacebookAI's `xlm-roberta-base` and its variants
 
 ### Vision Models
+- 💡[**NEW**] [SmolVLM](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct): `SmolVLM` and its variants (256M, 500M, and 2B)
+- 💡[**NEW**] [Gemma3 4B](https://huggingface.co/google/gemma-3-4b-it): `gemma-3-4b-it`
 - [Cvt](https://huggingface.co/microsoft/cvt-13): Convolutional Vision Transformer
 - [Deit](https://huggingface.co/facebook/deit-base-distilled-patch16-224): Distilled Data-efficient Image Transformer (base-sized)
 - [Dit](https://huggingface.co/microsoft/dit-base-finetuned-rvlcdip): Document Image Transformer (base-sized)

From f847a8f38c0e683a6adf6989988f4f22a724951c Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Wed, 15 Oct 2025 15:23:15 -0700
Subject: [PATCH 5/5] Uses transformers use_export_friendly

---
 optimum/exporters/executorch/integrations.py | 49 -------------------
 .../tasks/multimodal_text_to_text.py         |  7 +++
 2 files changed, 7 insertions(+), 49 deletions(-)

diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py
index c5aef6f..74a084b 100644
--- a/optimum/exporters/executorch/integrations.py
+++ b/optimum/exporters/executorch/integrations.py
@@ -38,60 +38,11 @@
 from .utils import apply_chat_template_with_fallback, process_conversation_inputs, save_config_to_constant_methods
 
 
-def _patch_idefics3_vision_embeddings_for_export(vision_model):
-    """
-    Patch Idefics3VisionEmbeddings to make it:
-    - Export-friendly by removing data-dependent operations (forces assumption of image input
-      batch_size = 1, and thus a full attention mask).
-    - Not use aten.bucketize, which has no available decompositions or kernels in ExecuTorch.
-    """
-    import types
-
-    def export_friendly_forward(
-        self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
-    ) -> torch.Tensor:
-        batch_size, _, max_im_h, max_im_w = pixel_values.shape
-
-        patch_embeds = self.patch_embedding(pixel_values)
-        embeddings = patch_embeds.flatten(2).transpose(1, 2)
-
-        nb_patches_h = max_im_h // self.patch_size
-        nb_patches_w = max_im_w // self.patch_size
-        N = self.num_patches_per_side
-
-        # For export, we assume full attention mask and compute position IDs statically.
-        # This avoids the data-dependent loop over batch dimension.
-        h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=torch.long)
-        w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=torch.long)
-
-        # This replaces bucketize(x, boundaries=[1/N, 2/N, ...], right=True) ≈ floor(x * N), which
-        # we don't have a kernel for at the moment.
-        bucket_coords_h = (h_indices * N) // nb_patches_h
-        bucket_coords_w = (w_indices * N) // nb_patches_w
-
-        bucket_coords_h = torch.clamp(bucket_coords_h, max=N - 1)
-        bucket_coords_w = torch.clamp(bucket_coords_w, max=N - 1)
-
-        pos_ids = (bucket_coords_h[:, None] * N + bucket_coords_w[None, :]).reshape(-1)
-        position_ids = pos_ids.unsqueeze(0).expand(batch_size, -1)
-        embeddings = embeddings + self.position_embedding(position_ids)
-        return embeddings
-
-    # Patch the forward method.
-    vision_model.embeddings.forward = types.MethodType(export_friendly_forward, vision_model.embeddings)
-
-
 class VisionExportableModule(torch.nn.Module):
     def __init__(self, model: torch.nn.Module):
         super().__init__()
         self.model = model
 
-        # Patch Idefics3 vision embeddings to make it exportable.
-        if hasattr(model, "model") and hasattr(model.model, "vision_model"):
-            model_type = getattr(model.config, "model_type", "")
-            if "idefics3" in model_type.lower():
-                _patch_idefics3_vision_embeddings_for_export(model.model.vision_model)
-
     def prepare_export_inputs(self):
         # 1. Get export inputs
         model_id = self.model.config.name_or_path
diff --git a/optimum/exporters/executorch/tasks/multimodal_text_to_text.py b/optimum/exporters/executorch/tasks/multimodal_text_to_text.py
index 9c765f7..c2182fe 100644
--- a/optimum/exporters/executorch/tasks/multimodal_text_to_text.py
+++ b/optimum/exporters/executorch/tasks/multimodal_text_to_text.py
@@ -151,6 +151,13 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
     if not (hasattr(config, "text_config")):
         raise ValueError(f"The model {model_name_or_path} does not have a `text_config`.")
 
+    config.use_export_friendly = True
+    config.text_config.use_export_friendly = True
+    if hasattr(config, "audio_config"):
+        config.audio_config.use_export_friendly = True
+    if hasattr(config, "vision_config"):
+        config.vision_config.use_export_friendly = True
+
     if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
         # NOTE: Avoid hitting the data-dependent control flow in _longrope_frequency_update.
         config.rope_scaling["type"] = "default"
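
For reference, the end-to-end flow this series enables can be condensed into a minimal sketch. The model ID, export options, and conversation format below are taken directly from the `test_smolvlm_with_custom_sdpa_kv_cache_8da4w_8we` test added in PATCH 3/5; the surrounding script structure (and the final `print`) is illustrative only and not part of the patches.

```python
from transformers import AutoProcessor, AutoTokenizer

from optimum.executorch import ExecuTorchModelForMultiModalToText

model_id = "HuggingFaceTB/SmolVLM-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Export SmolVLM to ExecuTorch with the XNNPACK recipe, custom SDPA, custom KV cache,
# 8da4w linear quantization, and 8-bit embedding quantization, as exercised by the test.
model = ExecuTorchModelForMultiModalToText.from_pretrained(
    model_id,
    recipe="xnnpack",
    task="multimodal-text-to-text",
    use_custom_sdpa=True,
    use_custom_kv_cache=True,
    qlinear="8da4w",
    qlinear_group_size=32,
    qembedding_config="8w",
)

# Single-image conversation; multi-image inputs are not supported by this series.
conversation = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
            },
            {"type": "text", "text": "Can you describe this image?"},
        ],
    },
]

# Run multimodal generation against the exported program.
generated_text = model.text_generation(
    processor=processor,
    tokenizer=tokenizer,
    input_conversation=conversation,
    max_seq_len=64,
)
print(generated_text)
```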