diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py
index 23e6819a..27ef8425 100644
--- a/optimum/exporters/executorch/integrations.py
+++ b/optimum/exporters/executorch/integrations.py
@@ -496,3 +496,146 @@ def generate(self, prompt_token_ids, max_new_tokens):
                 break

        return generated_ids
+
+
+class ImageTextToTextExportableModule(torch.nn.Module):
+    """
+    A wrapper module designed to make an image-text-to-text model exportable with `torch.export`.
+    This module ensures that the exported model is compatible with ExecuTorch.
+    """
+
+    def __init__(self, model, use_custom_kv_cache=False, use_custom_sdpa=False):
+        super().__init__()
+        self.model = model
+        self.config = model.config
+        self.use_custom_kv_cache = use_custom_kv_cache
+        self.use_custom_sdpa = use_custom_sdpa
+        # Local import to avoid a circular dependency between integrations and utils.
+        from .utils import save_config_to_constant_methods
+
+        self.metadata = save_config_to_constant_methods(model.config.text_config, model.generation_config)
+        logging.info(f"Metadata to be recorded in PTE: {self.metadata}")
+
+    def _prepare_vision_embedding_export_inputs(self):
+        """
+        Prepare example inputs and configurations for exporting the vision encoder.
+
+        Returns:
+            pixel_values (torch.Tensor): Example pixel values tensor.
+            dynamic_shapes (dict or None): Dynamic shape specifications for export.
+            strict (bool): Whether to use strict export mode.
+        """
+        image_size = self.config.vision_config.image_size
+        pixel_values = torch.rand((1, 3, image_size, image_size))
+        dynamic_shapes = None
+        strict = False
+
+        return pixel_values, dynamic_shapes, strict
+
+    def _prepare_text_embedding_export_inputs(self):
+        """
+        Prepare example inputs and configurations for exporting the text decoder.
+
+        Returns:
+            inputs_embeds (torch.Tensor): Example input embeddings tensor.
+            cache_position (torch.Tensor): Example cache position tensor.
+            dynamic_shapes (dict or None): Dynamic shape specifications for export.
+            strict (bool): Whether to use strict export mode.
+        """
+        # Prepare inputs with dynamic shapes
+        seq_length = 3  # Sequence length > 1 to avoid specialization issues
+        hidden_size = self.config.text_config.hidden_size
+        example_inputs_embeds = torch.zeros((1, seq_length, hidden_size), dtype=torch.float32)
+        example_cache_position = torch.arange(seq_length, dtype=torch.long)
+        max_seq_len = self.metadata.get("get_max_seq_len")
+        sliding_window = self.metadata.get("sliding_window", float("inf"))
+        max_dim = min(max_seq_len, sliding_window) - 1
+        seq_len_dim = torch.export.Dim("seq_length_dim", max=max_dim)
+        dynamic_shapes = {
+            "inputs_embeds": {1: seq_len_dim},
+            "cache_position": {0: seq_len_dim},
+        }
+        strict = False
+
+        return example_inputs_embeds, example_cache_position, dynamic_shapes, strict
+
+    def export(self) -> Dict[str, ExportedProgram]:
+        """
+        Export both the vision encoder and text decoder components.
+
+        Returns:
+            Dict[str, ExportedProgram]: Dictionary containing exported programs for the vision and text components.
+ """ + # Export vision encoder + pixel_values, vision_dynamic_shapes, vision_strict = self._prepare_vision_embedding_export_inputs() + logging.info( + f"Exporting vision encoder using pixel_values({pixel_values.shape}), dynamic_shapes={vision_dynamic_shapes}, strict={vision_strict}" + ) + + # Create vision encoder wrapper + vision_encoder = VisionEncoderExportableModule(self.model) + with torch.no_grad(): + vision_exported_program = vision_encoder.export(pixel_values)["model"] + + # Export text decoder + inputs_embeds, cache_position, text_dynamic_shapes, text_strict = self._prepare_text_embedding_export_inputs() + logging.info( + f"Exporting text decoder using inputs_embeds({inputs_embeds.shape}), cache_position({cache_position.shape}), dynamic_shapes={text_dynamic_shapes}, strict={text_strict}" + ) + + # Use the enhanced transformers integration for multimodal support + if is_transformers_version(">", "4.52.0"): + from transformers.integrations.executorch import ( + TorchExportableModuleForImageTextLM, + ) + + exportable_module = TorchExportableModuleForImageTextLM( + self.model.language_model, + max_batch_size=1, + max_cache_len=self.metadata.get("get_max_seq_len"), + ) + self._register_attention_mask_for_4_53(exportable_module) + + if self.use_custom_kv_cache: + from optimum.executorch.attentions.custom_kv_cache import ( + replace_with_et_custom_kv_cache, + ) + + replace_with_et_custom_kv_cache( + exportable_module.model, + self.model.language_model.config, + self.model.generation_config, + self.model.dtype, + ) + + with torch.no_grad(): + text_exported_program = exportable_module.export(inputs_embeds, cache_position, text_dynamic_shapes, text_strict) + else: + raise ValueError("Image-text-to-text export requires transformers > 4.52.0") + + return { + "vision_encoder": vision_exported_program, + "text_decoder": text_exported_program + } + + def _register_attention_mask_for_4_53(self, exportable_module: torch.nn.Module): + """Register attention mask for transformers >= 4.53.0""" + if is_transformers_version(">=", "4.53.0.dev0"): + from transformers.integrations.executorch import sdpa_mask_without_vmap + from transformers.masking_utils import AttentionMaskInterface + from transformers.modeling_utils import AttentionInterface + + _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module) + if self.use_custom_sdpa: + if self.use_custom_kv_cache: + AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache) + AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap) + # Manually set the attention implementation to custom_sdpa_ring_kv_cache + # This handles both regular sdpa and one for sliding window/local attention + exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache" + else: + # Manually set the attention implementation to custom_sdpa_ring_kv_cache + # This handles both regular sdpa and one for sliding window/local attention + exportable_module.model.model.config._attn_implementation = "custom_sdpa" diff --git a/optimum/exporters/executorch/tasks/__init__.py b/optimum/exporters/executorch/tasks/__init__.py index 0f7c3be3..6c09eb19 100644 --- a/optimum/exporters/executorch/tasks/__init__.py +++ b/optimum/exporters/executorch/tasks/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import causal_lm, image_classification, masked_lm, seq2seq_lm +from . 
diff --git a/optimum/exporters/executorch/tasks/__init__.py b/optimum/exporters/executorch/tasks/__init__.py
index 0f7c3be3..6c09eb19 100644
--- a/optimum/exporters/executorch/tasks/__init__.py
+++ b/optimum/exporters/executorch/tasks/__init__.py
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from . import causal_lm, image_classification, masked_lm, seq2seq_lm
+from . import causal_lm, image_classification, image_text_to_text, masked_lm, seq2seq_lm
diff --git a/optimum/exporters/executorch/tasks/image_text_to_text.py b/optimum/exporters/executorch/tasks/image_text_to_text.py
new file mode 100644
index 00000000..918b8c47
--- /dev/null
+++ b/optimum/exporters/executorch/tasks/image_text_to_text.py
@@ -0,0 +1,152 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import torchao
+from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig
+
+from ..integrations import ImageTextToTextExportableModule
+from ..quantization import quantize_model_
+from ..task_registry import register_task
+
+
+# NOTE: It's important to map the registered task name to the pipeline name in
+# https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py.
+# This will streamline using inferred task names and make exporting models to Hugging Face pipelines easier.
+@register_task("image-text-to-text")
+def load_image_text_to_text_model(model_name_or_path: str, **kwargs) -> ImageTextToTextExportableModule:
+    """
+    Loads a multimodal causal language model for image-text-to-text generation and registers it under the
+    task 'image-text-to-text' using Hugging Face's `AutoModelForCausalLM`.
+
+    Args:
+        model_name_or_path (str):
+            Model ID on huggingface.co or path on disk to the model repository to export. For example:
+            `model_name_or_path="google/gemma-3-4b-it"` or `model_name_or_path="/path/to/model_folder"`.
+        **kwargs:
+            Additional configuration options for the model:
+                - dtype (str, optional):
+                    Data type for model weights (default: "float32").
+                    Options include "float16" and "bfloat16".
+                - attn_implementation (str, optional):
+                    Attention mechanism implementation (default: "sdpa").
+                - cache_implementation (str, optional):
+                    Cache management strategy (default: "static").
+                - max_length (int, optional):
+                    Maximum sequence length for generation (default: 2048).
+
+    Returns:
+        ImageTextToTextExportableModule:
+            An instance of `ImageTextToTextExportableModule` for exporting and lowering to ExecuTorch.
+    """
+    device = "cpu"
+    batch_size = 1
+    dtype = kwargs.get("dtype", "float32")
+    use_custom_sdpa = kwargs.get("use_custom_sdpa", False)
+    use_custom_kv_cache = kwargs.get("use_custom_kv_cache", False)
+    attn_implementation = kwargs.get("attn_implementation", "custom_sdpa" if use_custom_sdpa else "sdpa")
+    cache_implementation = kwargs.get("cache_implementation", "static")
+    use_custom_sdpa = use_custom_sdpa or attn_implementation == "custom_sdpa"
+    max_length = kwargs.get("max_length", 2048)
+    config = kwargs.get("config") or AutoConfig.from_pretrained(model_name_or_path)
+
+    # Make sure the config has both text_config and vision_config:
+    if not hasattr(config, "text_config") or not hasattr(config, "vision_config"):
+        raise ValueError(
+            f"The model {model_name_or_path} does not have a `text_config` or `vision_config` attribute in its config. "
+            "This is required for image-text-to-text models."
+        )
" + "This is required for image-text-to-text models." + ) + + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + # NOTE: To make the model exportable we need to set the rope scaling to default to avoid hitting + # the data-dependent control flow in _longrope_frequency_update. Alternatively, users should rewrite + # that function to avoid the data-dependent control flow. + config.rope_scaling["type"] = "default" + + if hasattr(config, "use_cache") and config.use_cache is False: + config.use_cache = True + + def _load_eager_pretrained( + model_name_or_path, + device, + dtype, + config, + attn_implementation, + cache_implementation, + batch_size, + max_length, + ): + eager_model = AutoModelForCausalLM.from_pretrained( + model_name_or_path, + device_map=device, + torch_dtype=dtype, + config=config, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_length, + }, + ), + ) + return eager_model + + try: + eager_model = _load_eager_pretrained( + model_name_or_path, + device, + dtype, + config, + attn_implementation, + cache_implementation, + batch_size, + max_length, + ) + except ValueError as e: + if "torch.nn.functional.scaled_dot_product_attention" in str(e): + logging.info("⚠ SDPA attention not supported, falling back to eager implementation") + attn_implementation = "eager" + eager_model = _load_eager_pretrained( + model_name_or_path, + device, + dtype, + config, + attn_implementation, + cache_implementation, + batch_size, + max_length, + ) + else: + raise + + # Make sure model has language_model as well as vision_tower: + if not hasattr(eager_model, "language_model") or not hasattr(eager_model, "vision_tower"): + raise ValueError( + f"The model {model_name_or_path} does not have a `language_model` or `vision_tower` attribute. " + "This is required for image-text-to-text models." + ) + + for param in eager_model.parameters(): + # Must disable gradient for quantized checkpoint + if isinstance(param, torchao.utils.TorchAOBaseTensor): + param.requires_grad = False + + qlinear_config = kwargs.get("qlinear", None) + qembedding_config = kwargs.get("qembedding", None) + quantize_model_(eager_model, qlinear_config=qlinear_config, qembedding_config=qembedding_config) + + return ImageTextToTextExportableModule(eager_model, use_custom_kv_cache, use_custom_sdpa) diff --git a/tests/models/test_modeling_image_text_to_text.py b/tests/models/test_modeling_image_text_to_text.py new file mode 100644 index 00000000..5fd3e206 --- /dev/null +++ b/tests/models/test_modeling_image_text_to_text.py @@ -0,0 +1,196 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/models/test_modeling_image_text_to_text.py b/tests/models/test_modeling_image_text_to_text.py
new file mode 100644
index 00000000..5fd3e206
--- /dev/null
+++ b/tests/models/test_modeling_image_text_to_text.py
@@ -0,0 +1,196 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import unittest
+from unittest.mock import Mock, patch
+
+import pytest
+import torch
+from transformers.testing_utils import slow
+
+from optimum.exporters.executorch.integrations import ImageTextToTextExportableModule
+from optimum.utils.import_utils import is_transformers_version
+
+
+is_linux_ci = sys.platform.startswith("linux") and os.environ.get("GITHUB_ACTIONS") == "true"
+
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+@pytest.mark.skipif(
+    is_transformers_version("<", "4.52.0.dev0"),
+    reason="Only available on transformers >= 4.52.0.dev0",
+)
+class ImageTextToTextExportTest(unittest.TestCase):
+    def setUp(self):
+        # Mock multimodal model configuration
+        self.mock_config = Mock()
+        self.mock_config.text_config = Mock()
+        self.mock_config.text_config.use_cache = True
+        self.mock_config.text_config.hidden_size = 768
+        self.mock_config.text_config.num_hidden_layers = 12
+        self.mock_config.vision_config = Mock()
+        self.mock_config.vision_config.image_size = 224
+
+        # Mock generation config
+        self.mock_generation_config = Mock()
+        self.mock_generation_config.use_cache = True
+        self.mock_generation_config.cache_implementation = "static"
+        self.mock_generation_config.max_length = 2048
+        self.mock_generation_config.cache_config = {
+            "batch_size": 1,
+            "max_cache_len": 2048,
+        }
+
+        # Mock model
+        self.mock_model = Mock()
+        self.mock_model.config = self.mock_config
+        self.mock_model.generation_config = self.mock_generation_config
+        self.mock_model.device = torch.device("cpu")
+        self.mock_model.dtype = torch.float32
+
+        # Mock language model and vision tower
+        self.mock_model.language_model = Mock()
+        self.mock_model.vision_tower = Mock()
+
+    def test_image_text_to_text_module_initialization(self):
+        """Test that ImageTextToTextExportableModule initializes correctly."""
+        with patch("optimum.exporters.executorch.integrations.save_config_to_constant_methods") as mock_save:
+            mock_save.return_value = {"get_max_seq_len": 2048}
+
+            module = ImageTextToTextExportableModule(self.mock_model)
+
+            self.assertEqual(module.model, self.mock_model)
+            self.assertEqual(module.config, self.mock_config)
+            self.assertFalse(module.use_custom_kv_cache)
+            self.assertFalse(module.use_custom_sdpa)
+            mock_save.assert_called_once_with(self.mock_config.text_config, self.mock_generation_config)
+
+    def test_vision_embedding_export_inputs_preparation(self):
+        """Test vision embedding export inputs preparation."""
+        with patch("optimum.exporters.executorch.integrations.save_config_to_constant_methods") as mock_save:
+            mock_save.return_value = {"get_max_seq_len": 2048}
+
+            module = ImageTextToTextExportableModule(self.mock_model)
+            pixel_values, dynamic_shapes, strict = module._prepare_vision_embedding_export_inputs()
+
+            self.assertEqual(pixel_values.shape, (1, 3, 224, 224))  # batch, channels, height, width
+            self.assertIsNone(dynamic_shapes)
+            self.assertFalse(strict)
+
+    def test_text_embedding_export_inputs_preparation(self):
+        """Test text embedding export inputs preparation."""
+        with patch("optimum.exporters.executorch.integrations.save_config_to_constant_methods") as mock_save:
+            mock_save.return_value = {"get_max_seq_len": 2048, "sliding_window": float("inf")}
+
+            module = ImageTextToTextExportableModule(self.mock_model)
+            inputs_embeds, cache_position, dynamic_shapes, strict = module._prepare_text_embedding_export_inputs()
+
+            self.assertEqual(inputs_embeds.shape, (1, 3, 768))  # batch, seq_len, hidden_size
+            self.assertEqual(cache_position.shape, (3,))  # seq_len
+            self.assertIn("inputs_embeds", dynamic_shapes)
+            self.assertIn("cache_position", dynamic_shapes)
+            self.assertFalse(strict)
+
+    def test_export_method_structure(self):
+        """Test that the export method returns the expected structure."""
+        with patch("optimum.exporters.executorch.integrations.save_config_to_constant_methods") as mock_save:
+            with patch("optimum.exporters.executorch.integrations.VisionEncoderExportableModule") as mock_vision:
+                with patch("optimum.exporters.executorch.integrations.is_transformers_version") as mock_version:
+                    mock_save.return_value = {"get_max_seq_len": 2048, "sliding_window": float("inf")}
+                    mock_version.return_value = True
+
+                    # Mock vision encoder export
+                    mock_vision_instance = Mock()
+                    mock_vision_instance.export.return_value = {"model": Mock()}
+                    mock_vision.return_value = mock_vision_instance
+
+                    # Mock transformers module
+                    with patch(
+                        "transformers.integrations.executorch.TorchExportableModuleForImageTextLM"
+                    ) as mock_text_module:
+                        mock_text_instance = Mock()
+                        mock_text_instance.export.return_value = Mock()
+                        mock_text_module.return_value = mock_text_instance
+
+                        module = ImageTextToTextExportableModule(self.mock_model)
+                        result = module.export()
+
+                        # Verify structure
+                        self.assertIn("vision_encoder", result)
+                        self.assertIn("text_decoder", result)
+
+                        # Verify calls
+                        mock_vision.assert_called_once_with(self.mock_model)
+                        mock_text_module.assert_called_once()
+
+    def test_validation_errors(self):
+        """Test validation errors for invalid configurations."""
+        # Config without a text_config attribute; spec ensures hasattr() is False
+        bad_config = Mock(spec=["vision_config"])
+        bad_config.vision_config = Mock()
+
+        bad_model = Mock()
+        bad_model.config = bad_config
+
+        with patch("optimum.exporters.executorch.tasks.image_text_to_text.AutoConfig") as mock_auto_config:
+            mock_auto_config.from_pretrained.return_value = bad_config
+
+            from optimum.exporters.executorch.tasks.image_text_to_text import load_image_text_to_text_model
+
+            with self.assertRaises(ValueError) as context:
+                load_image_text_to_text_model("test_model")
+
+            self.assertIn("text_config", str(context.exception))
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skipif(is_linux_ci, reason="OOM on linux runner")
+    def test_cli_export_integration(self):
+        """Test CLI integration for the image-text-to-text task."""
+        # This would test the actual CLI command but requires a real model.
+        # For now, just test that the task is registered correctly.
+        from optimum.exporters.executorch.task_registry import discover_tasks, task_registry
+
+        # Discover tasks to populate the registry
+        discover_tasks()
+
+        self.assertIn("image-text-to-text", task_registry)
+
+    def test_transformers_version_requirement(self):
+        """Test that export requires a recent enough transformers version."""
+        with patch("optimum.exporters.executorch.integrations.save_config_to_constant_methods") as mock_save:
+            with patch("optimum.exporters.executorch.integrations.is_transformers_version") as mock_version:
+                mock_save.return_value = {"get_max_seq_len": 2048}
+                mock_version.return_value = False  # Simulate an old transformers version
+
+                module = ImageTextToTextExportableModule(self.mock_model)
+
+                with self.assertRaises(ValueError) as context:
+                    module.export()
+
+                self.assertIn("transformers > 4.52.0", str(context.exception))
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
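Reviewer note on running the new tests: the mocked unit tests should pass with a plain `pytest tests/models/test_modeling_image_text_to_text.py`. `test_cli_export_integration` is additionally gated by the `@slow` decorator from `transformers.testing_utils` (enabled by setting `RUN_SLOW=1`) and the `run_slow` pytest marker, and is skipped on the Linux CI runner to avoid OOM.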