143 changes: 143 additions & 0 deletions optimum/exporters/executorch/integrations.py
@@ -496,3 +496,146 @@ def generate(self, prompt_token_ids, max_new_tokens):
break

return generated_ids


class ImageTextToTextExportableModule(torch.nn.Module):
"""
A wrapper module designed to make an image-text-to-text model exportable with `torch.export`.
This module ensures that the exported model is compatible with ExecuTorch.
"""

def __init__(self, model, use_custom_kv_cache=False, use_custom_sdpa=False):
super().__init__()
self.model = model
self.config = model.config
self.use_custom_kv_cache = use_custom_kv_cache
self.use_custom_sdpa = use_custom_sdpa
from .utils import save_config_to_constant_methods
self.metadata = save_config_to_constant_methods(
model.config.text_config, model.generation_config
)
logging.info(f"Metadata to be recorded in PTE: {self.metadata}")

def _prepare_vision_embedding_export_inputs(self):
"""
Prepare example inputs and configurations for export.

Returns:
pixel_values (torch.Tensor): Example pixel values tensor.
dynamic_shapes (dict or None): Dynamic shape specifications for export.
strict (bool): Whether to use strict export mode.
"""
image_size = self.config.vision_config.image_size
pixel_values = torch.rand((1, 3, image_size, image_size))
dynamic_shapes = None
strict = False

return pixel_values, dynamic_shapes, strict

def _prepare_text_embedding_export_inputs(self):
"""
Prepare example inputs and configurations for export.

Returns:
inputs_embeds (torch.Tensor): Example inputs embeddings tensor.
cache_position (torch.Tensor): Example cache position tensor.
dynamic_shapes (dict or None): Dynamic shape specifications for export.
strict (bool): Whether to use strict export mode.
"""
# Prepare inputs with dynamic shapes
seq_length = 3 # Sequence length > 1 to avoid specialization issues
hidden_size = self.config.text_config.hidden_size
example_inputs_embeds = torch.zeros((1, seq_length, hidden_size), dtype=torch.float32)
example_cache_position = torch.arange(seq_length, dtype=torch.long)
max_seq_len = self.metadata.get("get_max_seq_len")
sliding_window = self.metadata.get("sliding_window", float("inf"))
max_dim = min(max_seq_len, sliding_window) - 1
seq_len_dim = torch.export.Dim("seq_length_dim", max=max_dim)
dynamic_shapes = {
"inputs_embeds": {1: seq_len_dim},
"cache_position": {0: seq_len_dim},
}
strict = False

return example_inputs_embeds, example_cache_position, dynamic_shapes, strict

def export(
self,
) -> Dict[str, ExportedProgram]:
"""
Export both the vision encoder and text decoder components.

Returns:
Dict[str, ExportedProgram]: Dictionary containing exported programs for vision and text components.
"""
# Export vision encoder
pixel_values, vision_dynamic_shapes, vision_strict = self._prepare_vision_embedding_export_inputs()
logging.info(
f"Exporting vision encoder using pixel_values({pixel_values.shape}), dynamic_shapes={vision_dynamic_shapes}, strict={vision_strict}"
)

# Create vision encoder wrapper
vision_encoder = VisionEncoderExportableModule(self.model)
with torch.no_grad():
vision_exported_program = vision_encoder.export(pixel_values)["model"]

# Export text decoder
inputs_embeds, cache_position, text_dynamic_shapes, text_strict = self._prepare_text_embedding_export_inputs()
logging.info(
f"Exporting text decoder using inputs_embeds({inputs_embeds.shape}), cache_position({cache_position.shape}), dynamic_shapes={text_dynamic_shapes}, strict={text_strict}"
)

# Use the enhanced transformers integration for multimodal support
if is_transformers_version(">", "4.52.0"):
from transformers.integrations.executorch import (
TorchExportableModuleForImageTextLM,
)

exportable_module = TorchExportableModuleForImageTextLM(
self.model.language_model,
max_batch_size=1,
max_cache_len=self.metadata.get("get_max_seq_len"),
)
self._register_attention_mask_for_4_53(exportable_module)

if self.use_custom_kv_cache:
from optimum.executorch.attentions.custom_kv_cache import (
replace_with_et_custom_kv_cache,
)

replace_with_et_custom_kv_cache(
exportable_module.model,
self.model.language_model.config,
self.model.generation_config,
self.model.dtype,
)

with torch.no_grad():
text_exported_program = exportable_module.export(inputs_embeds, cache_position, text_dynamic_shapes, text_strict)
else:
raise ValueError("Image-text-to-text export requires transformers > 4.52.0")

return {
"vision_encoder": vision_exported_program,
"text_decoder": text_exported_program
}

def _register_attention_mask_for_4_53(self, exportable_module: torch.nn.Module):
"""Register attention mask for transformers >= 4.53.0"""
if is_transformers_version(">=", "4.53.0.dev0"):
from transformers.integrations.executorch import sdpa_mask_without_vmap
from transformers.masking_utils import AttentionMaskInterface
from transformers.modeling_utils import AttentionInterface

_custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module)
if self.use_custom_sdpa:
if self.use_custom_kv_cache:
AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache)
AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap)
# Manually set the attention implementation to custom_sdpa_ring_kv_cache
# This handles both regular sdpa and one for sliding window/local attention
exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache"
else:
                    # Manually set the attention implementation to custom_sdpa
                    # (custom SDPA without the ring/sliding-window KV cache variant)
                    exportable_module.model.model.config._attn_implementation = "custom_sdpa"
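A minimal usage sketch of the new wrapper, assuming a multimodal checkpoint that exposes both `language_model` and `vision_tower` submodules; the checkpoint name and keyword values below are illustrative, not asserted by this diff:

```python
# Hedged sketch: drive ImageTextToTextExportableModule end to end on an example checkpoint.
from transformers import AutoModelForCausalLM

from optimum.exporters.executorch.integrations import ImageTextToTextExportableModule

# "google/gemma-3-4b-it" is the example checkpoint referenced in this PR's docstrings;
# any image-text-to-text model with `language_model` and `vision_tower` should fit the wrapper.
eager_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-3-4b-it",
    torch_dtype="float32",
    attn_implementation="sdpa",
)

wrapper = ImageTextToTextExportableModule(
    eager_model,
    use_custom_kv_cache=False,
    use_custom_sdpa=False,
)

# Returns {"vision_encoder": ExportedProgram, "text_decoder": ExportedProgram}.
exported = wrapper.export()
print(exported.keys())
```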
2 changes: 1 addition & 1 deletion optimum/exporters/executorch/tasks/__init__.py
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from . import causal_lm, image_classification, masked_lm, seq2seq_lm
from . import causal_lm, image_classification, image_text_to_text, masked_lm, seq2seq_lm
152 changes: 152 additions & 0 deletions optimum/exporters/executorch/tasks/image_text_to_text.py
@@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import torchao
from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig

from ..integrations import ImageTextToTextExportableModule
from ..quantization import quantize_model_
from ..task_registry import register_task


# NOTE: It's important to map the registered task name to the pipeline name in https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py.
# This will streamline using inferred task names and make exporting models to Hugging Face pipelines easier.
@register_task("image-text-to-text")
def load_image_text_to_text_model(model_name_or_path: str, **kwargs) -> ImageTextToTextExportableModule:
"""
    Loads a multimodal language model for image-text-to-text generation and registers it under the task
    'image-text-to-text' using Hugging Face's AutoModelForCausalLM.

Args:
model_name_or_path (str):
Model ID on huggingface.co or path on disk to the model repository to export. For example:
`model_name_or_path="google/gemma-3-4b-it"` or `model_name_or_path="/path/to/model_folder`
**kwargs:
Additional configuration options for the model:
- dtype (str, optional):
Data type for model weights (default: "float32").
Options include "float16" and "bfloat16".
- attn_implementation (str, optional):
Attention mechanism implementation (default: "sdpa").
- cache_implementation (str, optional):
Cache management strategy (default: "static").
- max_length (int, optional):
Maximum sequence length for generation (default: 2048).

Returns:
ImageTextToTextExportableModule:
An instance of `ImageTextToTextExportableModule` for exporting and lowering to ExecuTorch.
"""
device = "cpu"
batch_size = 1
dtype = kwargs.get("dtype", "float32")
use_custom_sdpa = kwargs.get("use_custom_sdpa", False)
use_custom_kv_cache = kwargs.get("use_custom_kv_cache", False)
attn_implementation = kwargs.get("attn_implementation", "custom_sdpa" if use_custom_sdpa else "sdpa")
cache_implementation = kwargs.get("cache_implementation", "static")
use_custom_sdpa = use_custom_sdpa or attn_implementation == "custom_sdpa"
max_length = kwargs.get("max_length", 2048)
config = kwargs.get("config") or AutoConfig.from_pretrained(model_name_or_path)

# Make sure config has text_config and vision_config:
if not hasattr(config, "text_config") or not hasattr(config, "vision_config"):
raise ValueError(
f"The model {model_name_or_path} does not have a `text_config` or `vision_config` attribute in its config. "
"This is required for image-text-to-text models."
)

if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
# NOTE: To make the model exportable we need to set the rope scaling to default to avoid hitting
# the data-dependent control flow in _longrope_frequency_update. Alternatively, users should rewrite
# that function to avoid the data-dependent control flow.
config.rope_scaling["type"] = "default"

if hasattr(config, "use_cache") and config.use_cache is False:
config.use_cache = True

def _load_eager_pretrained(
model_name_or_path,
device,
dtype,
config,
attn_implementation,
cache_implementation,
batch_size,
max_length,
):
eager_model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
device_map=device,
torch_dtype=dtype,
config=config,
attn_implementation=attn_implementation,
generation_config=GenerationConfig(
use_cache=True,
cache_implementation=cache_implementation,
max_length=max_length,
cache_config={
"batch_size": batch_size,
"max_cache_len": max_length,
},
),
)
return eager_model

try:
eager_model = _load_eager_pretrained(
model_name_or_path,
device,
dtype,
config,
attn_implementation,
cache_implementation,
batch_size,
max_length,
)
except ValueError as e:
if "torch.nn.functional.scaled_dot_product_attention" in str(e):
logging.info("⚠ SDPA attention not supported, falling back to eager implementation")
attn_implementation = "eager"
eager_model = _load_eager_pretrained(
model_name_or_path,
device,
dtype,
config,
attn_implementation,
cache_implementation,
batch_size,
max_length,
)
else:
raise

# Make sure model has language_model as well as vision_tower:
if not hasattr(eager_model, "language_model") or not hasattr(eager_model, "vision_tower"):
raise ValueError(
f"The model {model_name_or_path} does not have a `language_model` or `vision_tower` attribute. "
"This is required for image-text-to-text models."
)

for param in eager_model.parameters():
# Must disable gradient for quantized checkpoint
if isinstance(param, torchao.utils.TorchAOBaseTensor):
param.requires_grad = False

qlinear_config = kwargs.get("qlinear", None)
qembedding_config = kwargs.get("qembedding", None)
quantize_model_(eager_model, qlinear_config=qlinear_config, qembedding_config=qembedding_config)

return ImageTextToTextExportableModule(eager_model, use_custom_kv_cache, use_custom_sdpa)
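For reference, a hedged sketch of calling the registered task loader directly; the keyword values are illustrative, and in normal use the function is reached through the `image-text-to-text` task registry during export:

```python
# Hedged sketch: load and export via the task function added in this file.
from optimum.exporters.executorch.tasks.image_text_to_text import load_image_text_to_text_model

module = load_image_text_to_text_model(
    "google/gemma-3-4b-it",      # example checkpoint from the docstring above
    dtype="bfloat16",
    use_custom_sdpa=True,
    use_custom_kv_cache=True,
    max_length=2048,
)

# The returned ImageTextToTextExportableModule exports both components in one call.
exported_programs = module.export()  # {"vision_encoder": ..., "text_decoder": ...}
```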