diff --git a/install_dev.py b/install_dev.py index 5cf313ff..6fa07cf3 100644 --- a/install_dev.py +++ b/install_dev.py @@ -5,21 +5,21 @@ def install_torch_nightly_deps(): """Install torch related dependencies from pinned nightly""" - EXECUTORCH_NIGHTLY_VERSION = "dev20250625" - TORCHAO_NIGHTLY_VERSION = "dev20250620" + EXECUTORCH_NIGHTLY_VERSION = "dev20250807" + TORCHAO_NIGHTLY_VERSION = "dev20250807" # Torch nightly is aligned with pinned nightly in https://github.com/pytorch/executorch/blob/main/install_requirements.py#L74 - TORCH_NIGHTLY_VERSION = "dev20250601" + TORCH_NIGHTLY_VERSION = "dev20250807" subprocess.check_call( [ sys.executable, "-m", "pip", "install", - f"executorch==0.7.0.{EXECUTORCH_NIGHTLY_VERSION}", - f"torch==2.8.0.{TORCH_NIGHTLY_VERSION}", - f"torchvision==0.23.0.{TORCH_NIGHTLY_VERSION}", + f"executorch==0.8.0.{EXECUTORCH_NIGHTLY_VERSION}", + f"torch==2.9.0.{TORCH_NIGHTLY_VERSION}", + f"torchvision==0.24.0.{TORCH_NIGHTLY_VERSION}", f"torchaudio==2.8.0.{TORCH_NIGHTLY_VERSION}", - f"torchao==0.12.0.{TORCHAO_NIGHTLY_VERSION}", + f"torchao==0.13.0.{TORCHAO_NIGHTLY_VERSION}", "--extra-index-url", "https://download.pytorch.org/whl/nightly/cpu", ] diff --git a/optimum/executorch/__init__.py b/optimum/executorch/__init__.py index 07b72a6b..77fc3944 100644 --- a/optimum/executorch/__init__.py +++ b/optimum/executorch/__init__.py @@ -21,6 +21,7 @@ "modeling": [ "ExecuTorchModelForCausalLM", "ExecuTorchModelForImageClassification", + "ExecuTorchModelForMultimodalCausalLM", "ExecuTorchModelForMaskedLM", "ExecuTorchModelForSeq2SeqLM", "ExecuTorchModelForSpeechSeq2Seq", @@ -31,6 +32,7 @@ from .modeling import ( ExecuTorchModelForCausalLM, ExecuTorchModelForImageClassification, + ExecuTorchModelForMultimodalCausalLM, ExecuTorchModelForMaskedLM, ExecuTorchModelForSeq2SeqLM, ExecuTorchModelForSpeechSeq2Seq, diff --git a/optimum/executorch/attentions/custom_sdpa.py b/optimum/executorch/attentions/custom_sdpa.py index 59476c74..31202d47 100644 --- a/optimum/executorch/attentions/custom_sdpa.py +++ b/optimum/executorch/attentions/custom_sdpa.py @@ -59,7 +59,7 @@ def custom_sdpa_with_start_pos_forward( # Calculate the input pos from attention mask. # Branch out for float vs bool mask # assert attention_mask.dim() == 2, f"attention_mask must be a 2D matrix." 
- attention_mask = attention_mask.reshape(-1, max_seq_len) + attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) first_row_mask = attention_mask[0, :] # [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3 start_pos = torch.argmin(first_row_mask.to(torch.long)).item() - 1 diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py index ca3d3413..98e81923 100644 --- a/optimum/executorch/modeling.py +++ b/optimum/executorch/modeling.py @@ -25,6 +25,7 @@ from huggingface_hub import hf_hub_download from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from transformers import ( + AutoConfig, AutoModelForCausalLM, AutoModelForImageClassification, AutoModelForMaskedLM, @@ -238,9 +239,12 @@ def _export( ) -> Dict[str, "ExecuTorchModule"]: task = kwargs.pop("task", None) if task is not None: - logger.warning(f"task was provided and set to {task} but not used, will be ignored") - inferred_task = TasksManager.infer_task_from_model(cls.auto_model_class) - logging.info(f"Inferred task from model class: {inferred_task}") + logger.warning(f"task was provided and set to {task}") + elif hasattr(cls, "task"): + task = cls.task + else: + task = TasksManager.infer_task_from_model(cls.auto_model_class) + logging.info(f"Inferred task from model class: {task}") save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) @@ -249,7 +253,7 @@ def _export( executorch_progs = main_export( model_name_or_path=model_id, output_dir=save_dir_path, - task=inferred_task, + task=task, recipe=recipe, config=config, subfolder=subfolder, @@ -309,6 +313,8 @@ def from_pretrained( model_dir = os.path.join(cached_model_dir, "snapshots", _revision) else: model_dir = model_id + if not config: + config = AutoConfig.from_pretrained(model_id) pte_files = find_files_matching_pattern( model_dir, @@ -1082,3 +1088,144 @@ def transcribe( self.stats.on_inference_end() self.stats.print_report() return self.tokenizer.decode(generated_tokens, skip_special_tokens=True) + + +class ExecuTorchModelForMultimodalCausalLM(ExecuTorchModelBase): + """ + ExecuTorch model for CausalLM with multimodal capability. + + Although the auto_model_class is `AutoModelForCausalLM`, the same as `ExecuTorchModelForCausalLM`, this model is specifically designed for + multimodal-text-to-text tasks. This class provides an interface for loading, running, and generating outputs from a vision-language model + or an audio-language model optimized for ExecuTorch Runtime. It includes utilities for exporting and loading pre-trained models compatible + with ExecuTorch runtime. + + Attributes: + auto_model_class (`Type`): + Associated Transformers class, `AutoModelForCausalLM`. + model (`ExecuTorchModule`): + The loaded ExecuTorch model. + """ + + auto_model_class = AutoModelForCausalLM + + task = "multimodal-text-to-text" + + def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"): + super().__init__(models, config) + if not hasattr(self, "model"): + raise AttributeError("Expected attribute 'model' not found in the instance.") + + # Make sure config contains vision_config and text_config, otherwise raise an error + # TODO(jackzhxng): check for audio config as well. + if not hasattr(config, "vision_config") or not hasattr(config, "text_config"): + raise ValueError( + "The configuration must contain 'vision_config' and 'text_config' attributes for image-text-to-text task."
+ ) + metadata = self.model.method_names() + logging.debug(f"Load all static methods: {metadata}") + if "use_kv_cache" in metadata: + self.use_kv_cache = self.model.run_method("use_kv_cache")[0] + if "get_max_seq_len" in metadata: + self.max_cache_size = self.model.run_method("get_max_seq_len")[0] + if "get_max_batch_size" in metadata: + self.max_batch_size = self.model.run_method("get_max_batch_size")[0] + if "get_dtype" in metadata: + self.dtype = self.model.run_method("get_dtype")[0] + if "get_bos_id" in metadata: + self.bos_token_id = self.model.run_method("get_bos_id")[0] + for key in ("get_eos_id", "get_eos_ids"): + if key in metadata: + self.eos_token_ids = self.model.run_method(key) + break + if "get_vocab_size" in metadata: + self.vocab_size = self.model.run_method("get_vocab_size")[0] + if "use_sdpa_with_kv_cache" in metadata: + self.use_sdpa_with_kv_cache = self.model.run_method("use_sdpa_with_kv_cache")[0] + + def forward( + self, + cache_position: torch.LongTensor, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + """ + Forward pass of the model, which is compatible with the ExecuTorch runtime for LLM. Here we are assuming pixel_values only represents a single image. + TODO(jackzhxng): Support `input_features` for audio modality. + Args: + input_ids (`torch.Tensor`): Tensor representing current input token id to the model. + pixel_values (`torch.Tensor`): Tensor representing image input to the model. + cache_position (`torch.Tensor`): Tensor representing current input position in the cache. + + Returns: + torch.Tensor: Logits output from the model. + """ + if (input_ids is None) and (pixel_values is None): + raise ValueError("You must specify at least one of input_ids or pixel_values") + self.stats.on_model_execution_start() + + inputs_embeds = self.model.run_method("token_embedding", (input_ids,))[0] + + if pixel_values is not None: + image_features = self.model.run_method("image_encoder", (pixel_values,))[0] + + if input_ids is None: + special_image_mask = ( + inputs_embeds + == self.model.run_method( + "token_embedding", + (torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device),), + )[0] + ) + else: + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + logits = self.model.run_method("text_model", (cache_position, inputs_embeds))[0] + self.stats.on_model_execution_end() + return logits + + def generate( + self, + tokenizer: "PreTrainedTokenizer", + input_ids: torch.LongTensor, + pixel_values: Optional[torch.FloatTensor] = None, + max_new_tokens: int = 100, + ): + # Sanity check + + if max_new_tokens <= 0: + raise ValueError(f"max_new_tokens must be greater than 0, got {max_new_tokens}.") + elif max_new_tokens > self.max_cache_size: + logging.warning( + f"max_new_tokens={max_new_tokens} is larger than max_cache_size={self.max_cache_size}. The number of generated tokens will be capped at max_cache_size."
+ ) + max_new_tokens = self.max_cache_size + + # Prefill + logits = self.forward( + input_ids=input_ids, + pixel_values=pixel_values, + cache_position=torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device), + ) + + tokens = [] + + token = torch.argmax(logits[:, -1, :], dim=-1).item() + tokens.append(token) + i = 1 + while i < max_new_tokens: + # Generate next token + logits = self.forward( + input_ids=torch.tensor([token], dtype=torch.long, device=input_ids.device).unsqueeze(0), + cache_position=torch.tensor([input_ids.size(1) + i - 1], dtype=torch.long, device=input_ids.device), + ) + token = torch.argmax(logits[:, -1, :], dim=-1).item() + tokens.append(token) + + if token in self.eos_token_ids: + break + i += 1 + + return tokenizer.decode(tokens, skip_special_tokens=True) diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py index 23e6819a..81d3f742 100644 --- a/optimum/exporters/executorch/integrations.py +++ b/optimum/exporters/executorch/integrations.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Dict +from typing import Dict, Optional import torch from packaging.version import parse @@ -26,6 +26,10 @@ T5ForConditionalGeneration, WhisperForConditionalGeneration, ) +from transformers.cache_utils import HybridCache +from transformers.integrations.executorch import sdpa_mask_without_vmap +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS from transformers.generation.configuration_utils import GenerationConfig from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache @@ -34,6 +38,386 @@ from .utils import save_config_to_constant_methods + +class TorchExportableModuleWithHybridCache(torch.nn.Module): + """ + A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`, + specifically for a decoder-only LM with `HybridCache`. This module ensures that the + exported model is compatible with further lowering and execution in `ExecuTorch`. + """ + + def __init__( + self, + model: PreTrainedModel, + max_batch_size: int = 1, + max_cache_len: int = 4096, + ): + """ + Initializes the exportable module with `HybridCache`. + Args: + model (`PreTrainedModel`): The pretrained model to wrap. + max_batch_size (int): Maximum batch size for the cache. + max_cache_len (int): Maximum sequence length for the cache. + Raises: + AssertionError: If the model doesn't have the expected configuration for HybridCache.
+ """ + super().__init__() + self.model = model + + # Verify the model is configured for HybridCache + if not self.model.config.text_config.use_cache: + raise AssertionError("Model must have caching enabled") + + # Initialize the HybridCache + self.cache = HybridCache( + config=self.model.config.text_config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=self.model.device, + dtype=self.model.dtype, + ) + + # Register all key and value cache tensors as buffers + for i in range(len(self.cache.key_cache)): + self.register_buffer(f"key_cache_{i}", self.cache.key_cache[i], persistent=False) + self.register_buffer(f"value_cache_{i}", self.cache.value_cache[i], persistent=False) + + def forward( + self, + *, + cache_position: torch.Tensor, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + """ + Forward pass of the module, which is compatible with the ExecuTorch llm runner. + Args: + cache_position (`torch.Tensor`): Tensor representing current input position in the cache. + input_ids (`torch.Tensor`): Tensor representing current input token id to the module. + inputs_embeds (`torch.Tensor`): Optional tensor representing input embeddings. + Returns: + torch.Tensor: Logits output from the model. + """ + batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + + # Generate position_ids from cache_position + position_ids = cache_position.unsqueeze(0).expand(batch_size, -1) + + # Forward pass with the model + outputs = self.model( + input_ids=input_ids, + attention_mask=None, + position_ids=position_ids, + past_key_values=self.cache, + use_cache=True, + cache_position=cache_position, + inputs_embeds=inputs_embeds, + ) + + # Return only the logits to simplify the export + return outputs.logits + + +class TorchExportableModuleForImageTextLM(torch.nn.Module): + """ + A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`, + specifically for image-text LM with cache. This module ensures that the + exported model is compatible with further lowering and execution in `ExecuTorch`. + """ + + def __init__( + self, + model: PreTrainedModel, + max_batch_size: int = 1, + max_cache_len: int = 4096, + ): + """ + Initializes the exportable module with `HybridCache`. + Args: + model (`PreTrainedModel`): The pretrained model to wrap. + max_batch_size (int): Maximum batch size for the cache. + max_cache_len (int): Maximum sequence length for the cache. + Raises: + ValueError: If the model is configured with a unsupported cache implementation. + """ + super().__init__() + + if not hasattr(model.config.text_config, "use_cache") or model.config.text_config.use_cache is False: + raise ValueError("The model must have caching enabled to be performant.") + + if ( + hasattr(model.config.text_config, "layer_types") + and getattr(model.config.text_config, "sliding_window", None) is not None + ): + self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len) + else: + # If `layer_types` is not specified explicitly in the config or `sliding_window` is null, + # there is only 1 type of layers, so export will use `StaticCache` by default. 
+ raise NotImplementedError("Using `StaticCache` for exporting image-text LM is not implemented yet.") + # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable + ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) + ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) + self.model.model.config._attn_implementation = "sdpa_without_vmap" + + def forward( + self, + input_embeds: torch.Tensor, + cache_position: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass of the module, which is compatible with the ExecuTorch llm runner. + Args: + input_embeds (`torch.Tensor`): Tensor representing current input embeddings to the module. + cache_position (`torch.Tensor`): Tensor representing current input position in the cache. + Returns: + torch.Tensor: Logits output from the model. + """ + return self.model.forward(input_embeds, cache_position) + + def export( + self, + inputs_embeds: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, + dynamic_shapes: Optional[dict] = None, + strict: Optional[bool] = None, + ) -> torch.export.ExportedProgram: + """ + Export the wrapped module using `torch.export`. + Args: + input_embeds (`Optional[torch.Tensor]`): + Tensor representing current input embeddings to the module. If not provided, a default tensor will be used. + cache_position (`Optional[torch.Tensor]`): + Tensor representing current input position in the cache. If not provided, a default tensor will be used. + dynamic_shapes (`Optional[dict]`): + Dynamic shapes to use for export if specified. + strict(`Optional[bool]`): + Flag to instruct `torch.export` to use `torchdynamo`. + """ + seq_length = 3 + + if dynamic_shapes is None: + seq_len_dim = torch.export.Dim("seq_length_dim", max=seq_length) + dynamic_shapes = { + "inputs_embeds": {1: seq_len_dim}, + "cache_position": {0: seq_len_dim}, + } + + exported_program = torch.export.export( + self.model, + args=(), + kwargs={"cache_position": cache_position, "inputs_embeds": inputs_embeds}, + dynamic_shapes=dynamic_shapes, + strict=strict if strict is not None else True, + ) + return exported_program + + +class ImageEncoderExportableModule(torch.nn.Module): + """ + A wrapper module designed to make a vision encoder-only model exportable with `torch.export`. + This module ensures that the exported model is compatible with ExecuTorch. + """ + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, pixel_values): + """ + Projects the last hidden state from the vision model into language model space. + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + vision_outputs = self.model.vision_tower(pixel_values=pixel_values).last_hidden_state + image_features = self.model.multi_modal_projector(vision_outputs) + return image_features + + +class MultimodalTextToTextExportableModule(torch.nn.Module): + """ + A wrapper module designed to make an image-text-to-text model exportable with `torch.export`. + This module ensures that the exported model is compatible with ExecuTorch. 
+ """ + + def __init__(self, model, use_custom_kv_cache=False, use_custom_sdpa=False): + super().__init__() + self.model = model + self.config = model.config + self.use_custom_kv_cache = use_custom_kv_cache + self.use_custom_sdpa = use_custom_sdpa + self.metadata = save_config_to_constant_methods(model.config.text_config, model.generation_config) + logging.info(f"Metadata to be recorded in PTE: {self.metadata}") + + def _prepare_vision_embedding_export_inputs(self): + """ + Prepare example inputs and configurations for export. + Returns: + pixel_values (torch.Tensor): Example pixel values tensor. + dynamic_shapes (dict or None): Dynamic shape specifications for export. + strict (bool): Whether to use strict export mode. + """ + image_size = self.config.vision_config.image_size + pixel_values = torch.rand((1, 3, image_size, image_size)) + dynamic_shapes = None + strict = False + + return pixel_values, dynamic_shapes, strict + + def _prepare_text_embedding_export_inputs(self): + """ + Prepare example inputs and configurations for export. + Returns: + input_ids (torch.Tensor): Example input IDs tensor. + dynamic_shapes (dict or None): Dynamic shape specifications for export. + strict (bool): Whether to use strict export mode. + """ + # Prepare inputs with dynamic shapes + seq_length = 3 # Sequence length > 1 to avoid specialization issues + example_input_ids = torch.zeros((1, seq_length), dtype=torch.long) + max_seq_len = self.metadata.get("get_max_seq_len") + sliding_window = self.metadata.get("sliding_window", float("inf")) + max_dim = min(max_seq_len, sliding_window) - 1 + seq_len_dim = torch.export.Dim("seq_length_dim", max=max_dim) + dynamic_shapes = { + "input_ids": {1: seq_len_dim}, + } + strict = parse(torch.__version__) != parse("2.7.0") # Workaround for PyTorch bug #150994 + return example_input_ids, dynamic_shapes, strict + + def _prepare_decoder_only_export_inputs(self): + """ + Prepare example inputs and configurations for export. + Returns: + inputs_embeds (torch.Tensor): Example input embeddings tensor. + cache_position (torch.Tensor): Example cache position tensor. + dynamic_shapes (dict or None): Dynamic shape specifications for export. + strict (bool): Whether to use strict export mode. 
+ """ + + # Prepare inputs with dynamic shapes + seq_length = 3 + example_inputs_embeds = torch.zeros((1, seq_length, self.config.text_config.hidden_size), dtype=torch.float) + example_cache_position = torch.arange(seq_length, dtype=torch.long) + max_seq_len = self.metadata.get("get_max_seq_len") + sliding_window = self.metadata.get("sliding_window", float("inf")) + max_dim = min(max_seq_len, sliding_window) - 1 + seq_len_dim = torch.export.Dim("seq_length_dim", max=max_dim) + dynamic_shapes = { + "inputs_embeds": {1: seq_len_dim}, + "cache_position": {0: seq_len_dim}, + } + strict = parse(torch.__version__) != parse("2.7.0") # Workaround for PyTorch bug #150994 + return example_inputs_embeds, example_cache_position, dynamic_shapes, strict + + def _register_attention_mask_for_4_53(self, exportable_module: torch.nn.Module): + from transformers.integrations.executorch import sdpa_mask_without_vmap + from transformers.masking_utils import AttentionMaskInterface + from transformers.modeling_utils import AttentionInterface + + _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module) + if self.use_custom_sdpa: + if self.use_custom_kv_cache: + AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache) + AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap) + # Manually set the attention implementation to custom_sdpa_ring_kv_cache + # This handles both regular sdpa and one for sliding window/local attention + exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache" + else: + # Manually set the attention implementation to custom_sdpa_ring_kv_cache + # This handles both regular sdpa and one for sliding window/local attention + exportable_module.model.model.config._attn_implementation = "custom_sdpa" + + def export( + self, + ) -> Dict[str, ExportedProgram]: + + exportable_module = TorchExportableModuleForImageTextLM( + self.model, + max_batch_size=1, + max_cache_len=self.metadata.get("get_max_seq_len"), + ) + self._register_attention_mask_for_4_53(exportable_module) + + if self.use_custom_kv_cache: + from optimum.executorch.attentions.custom_kv_cache import ( + replace_with_et_custom_kv_cache, + ) + + replace_with_et_custom_kv_cache( + exportable_module.model, + self.model.config.text_config, + self.model.generation_config, + self.model.dtype, + ) + + with torch.no_grad(): + inputs_embeds, cache_position, dynamic_shapes, strict = self._prepare_decoder_only_export_inputs() + logging.info( + f"Exporting decoder using inputs_embeds({inputs_embeds.shape}), cache_position({cache_position.shape})={cache_position}, dynamic_shapes={dynamic_shapes}, strict={strict}" + ) + exported_program = exportable_module.export(inputs_embeds, cache_position, dynamic_shapes, strict) + # Apply RemoveTransposes pass to remove + # any back-to-back transpose ops that are not needed + # e.g. output of update_cache is transposed and + # input to custom_sdpa is transposed. 
+ from executorch.extension.llm.export.export_passes import ( + RemoveRedundantTransposes, + ) + + mutated_gm = RemoveRedundantTransposes()(exported_program.module())[0] + exported_program = torch.export.export( + mutated_gm, + args=(), + kwargs={ + "cache_position": cache_position, + "inputs_embeds": inputs_embeds, + }, + dynamic_shapes=dynamic_shapes, + strict=strict, + ) + + # Export token embeddings + input_ids, dynamic_shapes, strict = self._prepare_text_embedding_export_inputs() + logging.info( + f"Exporting token embeddings using input_ids({input_ids.shape}), dynamic_shapes={dynamic_shapes}, strict={strict}" + ) + + token_embeddings_exported_program = torch.export.export( + exportable_module.model.model.language_model.get_input_embeddings(), + args=(input_ids,), + kwargs={}, + dynamic_shapes=dynamic_shapes, + strict=strict, + ) + + exported_programs = { + "text_model": exported_program, + "token_embedding": token_embeddings_exported_program, + } + + if hasattr(exportable_module.model.model, "vision_tower"): + # Export vision embeddings + pixel_values, dynamic_shapes, strict = self._prepare_vision_embedding_export_inputs() + logging.info( + f"Exporting vision embeddings using pixel_values({pixel_values.shape}), dynamic_shapes={dynamic_shapes}, strict={strict}" + ) + # Setting the _attn_implementation to "sdpa_without_vmap" for vision encoder + exportable_module.model.model.vision_tower.config._attn_implementation = "sdpa_without_vmap" + vision_encoder = ImageEncoderExportableModule(exportable_module.model.model) + vision_embeddings_exported_program = torch.export.export( + vision_encoder, + args=(pixel_values,), + kwargs={}, + dynamic_shapes=dynamic_shapes, + strict=strict, + ) + exported_programs["image_encoder"] = vision_embeddings_exported_program + # These keys need to match with the runner + return exported_programs + + class CausalLMExportableModule(torch.nn.Module): """ A wrapper module designed to make a Causal LM model exportable with `torch.export`. diff --git a/optimum/exporters/executorch/recipes/xnnpack.py b/optimum/exporters/executorch/recipes/xnnpack.py index 0bd7b374..08982bba 100644 --- a/optimum/exporters/executorch/recipes/xnnpack.py +++ b/optimum/exporters/executorch/recipes/xnnpack.py @@ -28,19 +28,24 @@ ExecutorchProgram, to_edge_transform_and_lower, ) +from executorch.exir.passes import MemoryPlanningPass + from optimum.executorch.passes.remove_padding_idx_embedding_pass import RemovePaddingIdxEmbeddingPass from ..integrations import ( CausalLMExportableModule, MaskedLMExportableModule, Seq2SeqLMExportableModule, + MultimodalTextToTextExportableModule, ) from ..recipe_registry import register_recipe @register_recipe("xnnpack") def export_to_executorch_with_xnnpack( - model: Union[CausalLMExportableModule, MaskedLMExportableModule, Seq2SeqLMExportableModule], + model: Union[ + CausalLMExportableModule, MaskedLMExportableModule, Seq2SeqLMExportableModule, MultimodalTextToTextExportableModule + ], **kwargs, ): """ @@ -49,7 +54,7 @@ def export_to_executorch_with_xnnpack( This function also write metadata required by the ExecuTorch runtime to the model. Args: - model (Union[CausalLMExportableModule, MaskedLMExportableModule, Seq2SeqLMExportableModule]): + model (Union[CausalLMExportableModule, MaskedLMExportableModule, Seq2SeqLMExportableModule, MultimodalTextToTextExportableModule]): The PyTorch model to be exported to ExecuTorch. **kwargs: Additional keyword arguments for recipe-specific configurations, e.g.
export using different example inputs, or different compile/bechend configs. @@ -64,36 +69,32 @@ def _lower_to_executorch( exported_programs: Dict[str, ExportedProgram], metadata=None, ) -> Dict[str, ExecutorchProgram]: - et_progs = {} backend_config_dict = { "extract_delegate_segments": True, + "memory_planning_pass": MemoryPlanningPass(alloc_graph_input=False), } if parse(executorch_version.__version__).base_version > "0.6.0": backend_config_dict["do_quant_fusion_and_const_prop"] = True - - for pte_name, exported_program in exported_programs.items(): - logging.debug(f"\nExported program for {pte_name}.pte: {exported_program}") - et_progs[pte_name] = to_edge_transform_and_lower( - exported_program, - partitioner=[XnnpackPartitioner()], - compile_config=EdgeCompileConfig( - _check_ir_validity=False, - _skip_dim_order=True, - ), - constant_methods=metadata, - transform_passes=[RemovePaddingIdxEmbeddingPass()], - ).to_executorch( - config=ExecutorchBackendConfig(**backend_config_dict), - ) - logging.debug( - f"\nExecuTorch program for {pte_name}.pte: {et_progs[pte_name].exported_program().graph_module}" - ) - delegation_info = get_delegation_info(et_progs[pte_name].exported_program().graph_module) - logging.debug(f"\nDelegation info Summary for {pte_name}.pte: {delegation_info.get_summary()}") - logging.debug( - f"\nDelegation info for {pte_name}.pte: {tabulate(delegation_info.get_operator_delegation_dataframe(), headers='keys', tablefmt='fancy_grid')}" - ) - return et_progs + pte_name = model.model.config.model_type + logging.debug(f"\nExported program for {pte_name}.pte: {exported_programs}") + et_prog = to_edge_transform_and_lower( + exported_programs, + partitioner=[XnnpackPartitioner()], + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + constant_methods=metadata, + transform_passes=[RemovePaddingIdxEmbeddingPass()], + ).to_executorch( + config=ExecutorchBackendConfig(**backend_config_dict), + ) + delegation_info = get_delegation_info(et_prog.exported_program(list(exported_programs.keys())[0]).graph_module) + logging.debug(f"\nDelegation info Summary for {pte_name}.pte: {delegation_info.get_summary()}") + logging.debug( + f"\nDelegation info for {pte_name}.pte: {tabulate(delegation_info.get_operator_delegation_dataframe(), headers='keys', tablefmt='fancy_grid')}" + ) + return {pte_name: et_prog} exported_progs = model.export() diff --git a/optimum/exporters/executorch/tasks/multimodal_text_to_text.py b/optimum/exporters/executorch/tasks/multimodal_text_to_text.py new file mode 100644 index 00000000..c197d304 --- /dev/null +++ b/optimum/exporters/executorch/tasks/multimodal_text_to_text.py @@ -0,0 +1,153 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging + +import torch +import torchao +from packaging.version import parse +from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig + +from ..integrations import MultimodalTextToTextExportableModule +from ..task_registry import register_task + + +# NOTE: It’s important to map the registered task name to the pipeline name in https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py. +# This will streamline using inferred task names and make exporting models to Hugging Face pipelines easier. +@register_task("image-text-to-text") +def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs): + """ + Loads a causal language model for image-text-to-text generation and registers it under the task + 'image-text-to-text' using Hugging Face's AutoModelForCausalLM. + + Args: + model_name_or_path (str): + Model ID on huggingface.co or path on disk to the model repository to export. For example: + `model_name_or_path="google/gemma-3-4b-it"` or `model_name_or_path="/path/to/model_folder"` + **kwargs: + Additional configuration options for the model: + - dtype (str, optional): + Data type for model weights (default: "float32"). + Options include "float16" and "bfloat16". + - attn_implementation (str, optional): + Attention mechanism implementation (default: "sdpa"). + - cache_implementation (str, optional): + Cache management strategy (default: "static"). + - max_length (int, optional): + Maximum sequence length for generation (default: 2048). + + Returns: + MultimodalTextToTextExportableModule: + An instance of `MultimodalTextToTextExportableModule` for exporting and lowering to ExecuTorch. + """ + device = "cpu" + batch_size = 1 + dtype = kwargs.get("dtype", "float32") + use_custom_sdpa = kwargs.get("use_custom_sdpa", False) + use_custom_kv_cache = kwargs.get("use_custom_kv_cache", False) + attn_implementation = kwargs.get("attn_implementation", "custom_sdpa" if use_custom_sdpa else "sdpa") + cache_implementation = kwargs.get("cache_implementation", "static") + use_custom_sdpa = use_custom_sdpa or attn_implementation == "custom_sdpa" + max_length = kwargs.get("max_length", 2048) + config = kwargs.get("config") or AutoConfig.from_pretrained(model_name_or_path) + + # Make sure config has text_config and vision_config: + if not hasattr(config, "text_config") or not hasattr(config, "vision_config"): + raise ValueError( + f"The model {model_name_or_path} does not have a `text_config` or `vision_config` attribute in its config. " + "This is required for image-text-to-text models." + ) + + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + # NOTE: To make the model exportable we need to set the rope scaling to default to avoid hitting + # the data-dependent control flow in _longrope_frequency_update. Alternatively, users should rewrite + # that function to avoid the data-dependent control flow.
+ config.rope_scaling["type"] = "default" + + if hasattr(config, "use_cache") and config.use_cache is False: + config.use_cache = True + + eager_model = AutoModelForCausalLM.from_pretrained( + model_name_or_path, + device_map=device, + torch_dtype=dtype, + config=config, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_length, + }, + ), + ) + + # Make sure model has language_model as well as vision_tower: + if not hasattr(eager_model, "language_model") or not hasattr(eager_model, "vision_tower"): + raise ValueError( + f"The model {model_name_or_path} does not have a `language_model` or `vision_tower` attribute. " + "This is required for image-text-to-text models." + ) + + for param in eager_model.parameters(): + # Must disable gradient for quantized checkpoint + if isinstance(param, torchao.utils.TorchAOBaseTensor): + param.requires_grad = False + + # TODO: Move quantization recipe out for better composability. + # TODO: Should switch to `TorchAoConfig` once the quant issue on final lm_head layer is fixed. + qlinear_config = kwargs.get("qlinear", None) + qembedding_config = kwargs.get("qembedding", None) + if qlinear_config or qembedding_config: + # TODO: Update torchao to use 0.11.0 once released + if parse(torchao.__version__) < parse("0.11.0.dev0"): + raise RuntimeError("Quantization 8da4w requires torchao >= 0.11.0. Please upgrade torchao.") + + from torchao.quantization.granularity import PerAxis, PerGroup + from torchao.quantization.quant_api import ( + Int8DynamicActivationIntxWeightConfig, + IntxWeightOnlyConfig, + quantize_, + ) + from torchao.utils import unwrap_tensor_subclass + + if qembedding_config: + logging.info("Quantizing embedding layers.") + # TODO: Should switch to `AOPerModuleConfig` once fix for tied weights is available. 
+ embedding_config = IntxWeightOnlyConfig( + weight_dtype=torch.int8, + granularity=PerAxis(0), + ) + quantize_( + eager_model, + embedding_config, + lambda m, fqn: isinstance(m, torch.nn.Embedding), + ) + + if qlinear_config: + logging.info("Quantizing linear layers.") + linear_config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, + weight_granularity=PerGroup(32), + ) + quantize_( + eager_model.language_model, + linear_config, + ) + + unwrap_tensor_subclass(eager_model) + + return MultimodalTextToTextExportableModule(eager_model, use_custom_kv_cache, use_custom_sdpa) diff --git a/tests/models/test_modeling_gemma3.py b/tests/models/test_modeling_gemma3.py index 77947c8a..4830f433 100644 --- a/tests/models/test_modeling_gemma3.py +++ b/tests/models/test_modeling_gemma3.py @@ -22,14 +22,15 @@ import unittest import pytest +import torch import torchao import transformers from executorch.extension.pybindings.portable_lib import ExecuTorchModule from packaging.version import parse -from transformers import AutoTokenizer +from transformers import AutoTokenizer, AutoProcessor from transformers.testing_utils import slow -from optimum.executorch import ExecuTorchModelForCausalLM +from optimum.executorch import ExecuTorchModelForCausalLM, ExecuTorchModelForMultimodalCausalLM from optimum.utils.import_utils import is_transformers_version from ..utils import check_causal_lm_output_quality @@ -267,3 +268,61 @@ def test_gemma3_text_generation_with_custom_sdpa_kv_cache_8da4w_8we(self): gc.collect() self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens)) + + @slow + @pytest.mark.run_slow + @pytest.mark.skipif( + parse(transformers.__version__) < parse("4.53.0.dev0") or parse(torchao.__version__) < parse("0.11.0"), + reason="Only available on transformers >= 4.53.0.dev0 and torchao >= 0.11.0", + ) + @pytest.mark.skipif(is_linux_ci, reason="OOM on linux runner") + def test_gemma3_image_text_to_text_generation_with_custom_sdpa_kv_cache_8da4w_8we(self): + + model_id = "google/gemma-3-4b-it" + + model = ExecuTorchModelForMultimodalCausalLM.from_pretrained( + model_id, + recipe="xnnpack", + task="image-text-to-text", + export=True, + use_custom_sdpa=True, + use_custom_kv_cache=True, + qlinear=True, + qembedding_config=True, + ) + + # Generate + image_url = "https://llava-vl.github.io/static/images/view.jpg" + conversation = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + {"type": "image", "url": image_url}, + { + "type": "text", + "text": "What are the things I should be cautious about when I visit here?", + }, + ], + }, + ] + processor = AutoProcessor.from_pretrained(model_id) + inputs = processor.apply_chat_template( + conversation, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + output = model.generate( + AutoTokenizer.from_pretrained(model_id), + input_ids=inputs["input_ids"], + pixel_values=inputs["pixel_values"], + max_new_tokens=50, + ) + self.assertEqual( + output, + """Okay, let's analyze the image and discuss potential cautions for visiting this location. + +Based on the picture, we're looking at a serene lakeside scene with a wooden pier extending into the water. Here's a breakdown of what you""", + )
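For reviewers, a minimal usage sketch of the new API, mirroring the Gemma3 test above; the model id, image URL, and prompt are illustrative, and the quantization kwargs exercised in the test (qlinear, qembedding) are left out here since they are optional:

from transformers import AutoProcessor, AutoTokenizer

from optimum.executorch import ExecuTorchModelForMultimodalCausalLM

model_id = "google/gemma-3-4b-it"

# Export with the XNNPACK recipe; the lowered .pte exposes the "token_embedding",
# "image_encoder", and "text_model" methods consumed by forward()/generate().
model = ExecuTorchModelForMultimodalCausalLM.from_pretrained(
    model_id,
    recipe="xnnpack",
    task="image-text-to-text",
    export=True,
    use_custom_sdpa=True,
    use_custom_kv_cache=True,
)

# Build multimodal inputs with the HF processor: the chat template tokenizes the
# prompt and returns pixel_values for the referenced image.
processor = AutoProcessor.from_pretrained(model_id)
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
            {"type": "text", "text": "Describe this image."},
        ],
    },
]
inputs = processor.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)

# generate() prefills with input_ids + pixel_values, then decodes greedily token by
# token until an EOS id or max_new_tokens, and returns the decoded string.
text = model.generate(
    AutoTokenizer.from_pretrained(model_id),
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    max_new_tokens=50,
)
print(text)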