diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 7cfd0ac5fc6..cbdbaa7ee3b 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -797,6 +797,8 @@ jobs: --etdump_path ${OUTPUT_DIR}/etdump.etdp \ --tsv_path ${TSV_PATH} + echo "::group::Run Multimodal Tests" + python3 -m unittest extension/llm/optimum/test/test_modeling_gemma3.py echo "::endgroup::" diff --git a/extension/llm/optimum/README.md b/extension/llm/optimum/README.md new file mode 100644 index 00000000000..e3c95f25b4d --- /dev/null +++ b/extension/llm/optimum/README.md @@ -0,0 +1,73 @@ +# ExecuTorch Optimum Module + +This module provides integration utilities for exporting and optimizing transformer models for ExecuTorch runtime. It contains specialized wrapper classes and utilities to make pre-trained models from Hugging Face Transformers compatible with `torch.export` and ExecuTorch execution. A lot of code is forked from `optimum-executorch` and adopted from `transformers`. We put it in ExecuTorch so that we can fast iterate on the stack. Eventually we want to upstream changes to `transformers` and `optimum-executorch`. + +## Overview + +The optimum module bridges the gap between Hugging Face Transformers models and ExecuTorch by providing: + +- Exportable wrapper modules for different model types +- Custom cache implementations for efficient inference +- Utilities for model configuration and optimization +- Integration with ExecuTorch's custom operators + +## Key Components + +### Exportable Modules + +#### `TorchExportableModuleWithHybridCache` +A wrapper module that makes decoder-only language models exportable with `torch.export` using `HybridCache`. This is a forked version of [`TorchExportableModuleForDecoderOnlyLM`](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py#L391) with some modifications to support `inputs_embeds`. + +**Note**: This class should be upstreamed to transformers. We keep it here so that we can iterate quickly. + +#### `TorchExportableModuleForImageTextLM` +A wrapper for text decoder model in a vision-language model. It is very similar to [`TorchExportableModuleForDecoderOnlyLM`](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py#L30) but instead of taking `input_ids` this module takes `inputs_embeds`. This is because we want to be able to take both token embeddings and image embeddings as inputs. + +**Note**: This class should be upstreamed to transformers. We keep it here so that we can iterate quickly. + +#### `ImageEncoderExportableModule` +A wrapper for vision encoder models that projects vision features to language model space. Commonly implemented as `get_image_features()` in HuggingFace transformers. For example: [`Gemma3Model.get_image_features()`](https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/modeling_gemma3.py#L794). + +#### `ImageTextToTextExportableModule` +A wrapper of `torch.nn.Module` for `image-text-to-text` task. Provides `export()` API that generates an `ExportedProgram`. It will be consumed by `xnnpack.py` recipe to generate ExecuTorch program. + +### Custom Implementations +These are mostly copied from `optimum-executorch`. We put them here so that they can be reused by `integrations.py` and `xnnpack.py` recipe. 
+ +- **Custom KV Cache**: Optimized key-value cache implementations for ExecuTorch +- **Custom SDPA**: Scaled Dot-Product Attention optimizations +- **XNNPACK Integration**: Lower to XNNPACK backend for optimized inference on CPU + +### Utilities + +- Configuration saving and constant method generation +- Model metadata extraction +- Export helper functions + +## Usage + +```python +from transformers import PretrainedConfig +from executorch.extension.llm.optimum.image_text_to_text import load_image_text_to_text_model +from executorch.extension.llm.optimum.xnnpack import export_to_executorch_with_xnnpack +from executorch.extension.llm.optimum.modeling import ExecuTorchModelForImageTextToTextCausalLM + +model_id = "google/gemma-3-4b-it" + +module = load_image_text_to_text_model( + model_id, + use_custom_sdpa=True, + use_custom_kv_cache=True, + qlinear=True, + qembedding=True, +) +model = export_to_executorch_with_xnnpack(module) +et_model = ExecuTorchModelForImageTextToTextCausalLM(model, PretrainedConfig.from_pretrained(model_id)) +``` + +## Testing + +Run tests with: +```bash +python -m pytest extension/llm/optimum/test/ +``` diff --git a/extension/llm/optimum/__init__.py b/extension/llm/optimum/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/extension/llm/optimum/custom_kv_cache.py b/extension/llm/optimum/custom_kv_cache.py new file mode 100644 index 00000000000..538772df43d --- /dev/null +++ b/extension/llm/optimum/custom_kv_cache.py @@ -0,0 +1,421 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, Optional, Tuple, Union + +import torch + + +# If transformers is not installed, raise an ImportError +try: + from transformers.cache_utils import HybridCache, StaticCache +except ImportError: + raise ImportError( + "transformers is not installed. Please install it to use Static/HybridCache." + ) + +try: + from executorch.examples.models.llama.source_transformation.custom_kv_cache import ( + CustomKVCache, + CustomRingKVCache, + ) +except ImportError: + raise ImportError( + "ExecutorTorch is not installed. Please install it to use Custom Cache." + ) + + +class ETCustomStaticCache(StaticCache): + """ + Custom KV Cache implementation for ExecutorTorch that inherits from Hugging Face's StaticCache + but uses custom operations for cache updates similar to ExecutorTorch's CustomStaticCache. 
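+
+    Example (illustrative sketch, not taken from this PR; `model`, `key_states`,
+    `value_states`, and `seq_len` are placeholders for a decoder-only setup):
+
+        cache = ETCustomStaticCache(
+            config=model.config,
+            max_batch_size=1,
+            max_cache_len=1024,
+            dtype=torch.float32,
+        )
+        # Cache positions are passed through `cache_kwargs`, exactly as the
+        # Hugging Face StaticCache API expects.
+        k_out, v_out = cache.update(
+            key_states,  # [batch, n_kv_heads, seq_len, head_dim]
+            value_states,  # [batch, n_kv_heads, seq_len, head_dim]
+            layer_idx=0,
+            cache_kwargs={"cache_position": torch.arange(seq_len)},
+        )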
+ """ + + def __init__( + self, + config, + max_batch_size: int, + max_cache_len: Optional[int] = None, + device: Union[torch.device, str, None] = None, + dtype: torch.dtype = torch.float32, + layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, + ): + super().__init__( + config=config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=device, + dtype=dtype, + layer_device_map=layer_device_map, + ) + + # make sure layer_device_map is none + assert layer_device_map is None + assert device is None or device == "cpu", "Device must be None or 'cpu'" + + # Create a list of CustomKVCache instances, one per layer + self.kv_cache = torch.nn.ModuleList() + for _ in range(config.num_hidden_layers): + layer_cache = CustomKVCache( + max_batch_size=self.max_batch_size, + max_context_length=self.max_cache_len, + n_heads=self.num_key_value_heads, + head_dim=self.head_dim, + dtype=dtype, + ) + self.kv_cache.append(layer_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx` + using ExecutorTorch's CustomKVCache. + + Args: + key_states (`torch.Tensor`): + The new key states to cache. Shape: [batch_size, n_heads, seq_len, head_dim] + value_states (`torch.Tensor`): + The new value states to cache. Shape: [batch_size, n_heads, seq_len, head_dim] + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache update. + + Returns: + A tuple containing the updated key and value states. + """ + assert cache_kwargs is not None + + # Get cache position from cache_kwargs (used by StaticCache) + cache_position = cache_kwargs.get("cache_position") + assert cache_position is not None + assert isinstance(cache_position, torch.Tensor) + + # Get the CustomKVCache instance for this layer + layer_cache = self.kv_cache[layer_idx] + + # Use the CustomKVCache's update method + # CustomKVCache expects input_pos, k_val, v_val and handles the transpose internally + k_out, v_out = layer_cache.update( + input_pos=cache_position, + k_val=key_states, + v_val=value_states, + ) + + return k_out, v_out + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # Occupied cache == any slot in the 2nd dim (sequence length) holds a non-zero value + # This is different from StaticCache which checks the 3rd dim + if layer_idx is None: + layer_idx = 0 + return (self.kv_cache[layer_idx].k_cache[0, :, 0].any(dim=-1)).sum() + + @classmethod + def from_legacy_cache( + cls, + config, + legacy_cache, + max_cache_len=None, + device=None, + dtype=None, + ): + """ + Create an ETCustomStaticCache from a legacy cache implementation. 
+ + Args: + config: The model configuration + legacy_cache: The legacy cache implementation + max_cache_len: The maximum cache length + device: The device for the new cache + dtype: The data type for the new cache + + Returns: + A new ETCustomStaticCache instance + """ + assert hasattr(legacy_cache, "k_cache") and hasattr(legacy_cache, "v_cache") + # Extract dimensions from the legacy cache + assert len(legacy_cache.k_cache.shape) == 4 + if legacy_cache.k_cache.shape[1] == legacy_cache.n_heads: + # Shape is [batch_size, n_heads, seq_len, head_dim] + max_batch_size = legacy_cache.k_cache.shape[0] + else: + # Shape is [batch_size, seq_len, n_heads, head_dim] + max_batch_size = legacy_cache.k_cache.shape[0] + + # Use the legacy cache's device and dtype if not specified + if device is None and hasattr(legacy_cache, "device"): + device = legacy_cache.device + elif device is None and hasattr(legacy_cache.k_cache, "device"): + device = legacy_cache.k_cache.device + + if dtype is None and hasattr(legacy_cache, "dtype"): + dtype = legacy_cache.dtype + elif dtype is None and hasattr(legacy_cache.k_cache, "dtype"): + dtype = legacy_cache.k_cache.dtype + + assert device is None or device == "cpu" + assert dtype is None or dtype == torch.float32 + + # Use the legacy cache's max_seq_len if max_cache_len is not specified + if max_cache_len is None and hasattr(legacy_cache, "max_seq_len"): + max_cache_len = legacy_cache.max_seq_len + elif max_cache_len is None and hasattr(legacy_cache, "max_cache_len"): + max_cache_len = legacy_cache.max_cache_len + + return cls( + config=config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=device, + dtype=dtype, + ) + + +# Need to figure out if I have to inherit from HybridCache or StaticCache +class ETCustomHybridCache(HybridCache): + """ + Custom Hybrid KV Cache implementation for ExecutorTorch that inherits from Hugging Face's HybridCache + but uses ExecutorTorch's CustomKVCache for global layers and CustomRingKVCache for sliding window layers. 
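+
+    Example (illustrative sketch, not taken from this PR; `model` is a placeholder
+    for a loaded model whose text config declares sliding-window layers):
+
+        cache = ETCustomHybridCache(
+            config=model.config.text_config,
+            max_batch_size=1,
+            max_cache_len=4096,
+            dtype=torch.float32,
+        )
+        # Sliding-window layers are backed by CustomRingKVCache, global layers by
+        # CustomKVCache; either can be retrieved with get_layer_cache(layer_idx).
+        layer_cache = cache.get_layer_cache(0)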
+ """ + + def __init__( + self, + config, + max_batch_size: int, + max_cache_len: Optional[int] = None, + device: Union[torch.device, str, None] = None, + dtype: torch.dtype = torch.float32, + layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, + ): + super().__init__( + config=config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=device, + dtype=dtype, + layer_device_map=layer_device_map, + ) + + # make sure layer_device_map is none + assert layer_device_map is None + assert device is None or device == "cpu", "Device must be None or 'cpu'" + + self.cache_position = None + # Create a list of cache instances, one per layer + # Use CustomKVCache for global layers and CustomRingKVCache for sliding window layers + self.kv_cache = torch.nn.ModuleList() + for layer_idx in range(config.num_hidden_layers): + # newer version of transfomer has is_sliding defined + # for HybridCache + if self.is_sliding[layer_idx]: + # This is a sliding window layer + layer_cache = CustomRingKVCache( + max_batch_size=self.max_batch_size, + max_context_length=self.sliding_window_len, + n_heads=self.num_key_value_heads, + head_dim=self.head_dim, + dtype=dtype, + ) + else: + layer_cache = CustomKVCache( + max_batch_size=self.max_batch_size, + max_context_length=self.max_cache_len, + n_heads=self.num_key_value_heads, + head_dim=self.head_dim, + dtype=dtype, + ) + self.kv_cache.append(layer_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx` + using ExecutorTorch's CustomKVCache or CustomRingKVCache depending on the layer type. + + Args: + key_states (`torch.Tensor`): + The new key states to cache. Shape: [batch_size, n_heads, seq_len, head_dim] + value_states (`torch.Tensor`): + The new value states to cache. Shape: [batch_size, n_heads, seq_len, head_dim] + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache update. + + Returns: + A tuple containing the updated key and value states. + """ + assert cache_kwargs is not None + + # Get cache position from cache_kwargs (used by HybridCache) + cache_position = cache_kwargs.get("cache_position") + assert cache_position is not None + assert isinstance(cache_position, torch.Tensor) + self.cache_position = cache_position + + # Get the cache instance for this layer (either CustomKVCache or CustomRingKVCache) + layer_cache = self.kv_cache[layer_idx] + + # Use the cache's update method + # Both CustomKVCache and CustomRingKVCache have the same update interface + k_out, v_out = layer_cache.update( + input_pos=cache_position, + k_val=key_states, + v_val=value_states, + ) + + return k_out, v_out + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + if layer_idx is None: + layer_idx = 0 + + # For CustomRingKVCache, we need to handle the sequence length differently + layer_cache = self.kv_cache[layer_idx] + if self.is_sliding[layer_idx]: + # CustomRingKVCache cache_position_manager which + # maintains cache position for each slot in the kv cache + # we return the max position + 1 to indicate max position + # seen so far. 
Not sure if thats the correct interpretation + # of sequence length + return layer_cache.cache_positions_manager.cache_positions.max().item() + 1 + return (layer_cache.k_cache[0, :, 0].any(dim=-1)).sum() + + def get_layer_cache(self, layer_idx: int): + """ + Get the cache for a specific layer. This method is dynamo-traceable. + + Args: + layer_idx (int): The layer index + + Returns: + The cache instance for the specified layer (CustomKVCache or CustomRingKVCache) + """ + return self.kv_cache[layer_idx] + + +def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype): + """ + Replace all KV caches in the module with ETCustomStaticCache. + This modifies the model in place. + + Args: + module: The module to modify + config: The model configuration + + Returns: + The modified module + """ + # Recursively replace KV caches + return _replace_with_et_custom_kv_cache( + module, config, generation_config, cache_dtype + ) + + +def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype): + """ + Helper function to recursively replace KV caches in the module. + + Args: + module: The module to modify + config: The model configuration + + Returns: + The modified module + """ + # Check if module has static_cache (TorchExportableModuleWithStaticCache) + if hasattr(module, "static_cache"): + assert isinstance( + module.static_cache, StaticCache + ), f"Expected StaticCache, got {type(module.static_cache)}" + + # TODO: Add replace_cache to exported module + # in transformer's executorch.py + if getattr(module, "replace_cache", None) is not None: + static_cache = ETCustomStaticCache( + config=config, + max_batch_size=generation_config.cache_config.batch_size, + max_cache_len=generation_config.cache_config.max_cache_len, + device=generation_config.cache_config.device, + dtype=cache_dtype, + ) + module.replace_cache(static_cache) + else: + module.static_cache = ETCustomStaticCache( + config=config, + max_batch_size=generation_config.cache_config.batch_size, + max_cache_len=generation_config.cache_config.max_cache_len, + device=generation_config.cache_config.device, + dtype=cache_dtype, + ) + # Dont know why we need to this even though + # CustomKVCache registers the attributes + for i in range(len(module.static_cache.kv_cache)): + setattr( + module, f"key_cache_{i}", module.static_cache.kv_cache[i].k_cache + ) + setattr( + module, f"value_cache_{i}", module.static_cache.kv_cache[i].v_cache + ) + + # Check if module has cache (TorchExportableModuleWithHybridCache) + elif hasattr(module, "cache"): + assert isinstance( + module.cache, HybridCache + ), f"Expected HybridCache, got {type(module.cache)}" + + # Replace with ETCustomHybridCache + if getattr(module, "replace_cache", None) is not None: + hybrid_cache = ETCustomHybridCache( + config=config, + max_batch_size=generation_config.cache_config.batch_size, + max_cache_len=generation_config.cache_config.max_cache_len, + device=generation_config.cache_config.device, + dtype=cache_dtype, + ) + module.replace_cache(hybrid_cache) + else: + module.cache = ETCustomHybridCache( + config=config, + max_batch_size=generation_config.cache_config.batch_size, + max_cache_len=generation_config.cache_config.max_cache_len, + device=generation_config.cache_config.device, + dtype=cache_dtype, + ) + # Register cache attributes for each layer + for i in range(len(module.cache.kv_cache)): + setattr(module, f"key_cache_{i}", module.cache.kv_cache[i].k_cache) + setattr(module, f"value_cache_{i}", module.cache.kv_cache[i].v_cache) + if 
module.cache.is_sliding[i]: + # Register cache_positions as buffer for sliding window layers + # This prevents it from being traced as a constant + module.register_buffer( + f"cache_positions_{i}", + module.cache.kv_cache[ + i + ].cache_positions_manager.cache_positions, + persistent=False, + ) + else: + raise ValueError( + "Module must have either 'static_cache' (TorchExportableModuleWithStaticCache) " + "or 'cache' (TorchExportableModuleWithHybridCache) attribute" + ) + + return module diff --git a/extension/llm/optimum/custom_sdpa.py b/extension/llm/optimum/custom_sdpa.py new file mode 100644 index 00000000000..eb28069d3e4 --- /dev/null +++ b/extension/llm/optimum/custom_sdpa.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Callable, Optional, Tuple, Union + +import torch +from executorch.extension.llm.custom_ops.custom_ops import custom_sdpa # noqa + + +def custom_sdpa_with_start_pos_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], # noqa + scaling: Optional[float] = None, + softcap: Optional[float] = None, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, None]: + # FA2 uses non-transposed inputs + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Convert the hell out of the inputs to fp32 and back + input_dtype = query.dtype + query = query.to(torch.float32) + key = key.to(torch.float32) + value = value.to(torch.float32) + + # Ignore the causal flag from kwargs but use the one in module + kwargs.pop("is_causal", None) + assert module.is_causal, "Current variant supports only causal attention" + + is_causal = module.is_causal + if kwargs.get("is_sliding", False): + is_causal = False + attn_mask = attention_mask + # start_pos is not important when using mask + # instead of doing causal attention + start_pos = 0 + else: + attn_mask = None + # Calculate the input pos from attention mask. + # Branch out for float vs bool mask + # assert attention_mask.dim() == 2, f"attention_mask must be a 2D matrix." + attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) + first_row_mask = attention_mask[0, :] + # [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3 + start_pos = torch.argmin(first_row_mask.to(torch.long)).item() - 1 + + output = torch.ops.llama.custom_sdpa( + query, + key, + value, + start_pos=start_pos, + attn_mask=attn_mask, + drpout_p=0.0, + is_causal=is_causal, + scale=scaling, + ) + return output.to(input_dtype), None + + +def get_custom_sdpa_for_ring_kv_cache( + exportable_module: torch.nn.Module, +) -> Callable: + # lazy importing to avoid version dependent class definition + from executorch import version + + try: + from executorch.examples.models.llama.source_transformation.custom_kv_cache import ( + CustomRingKVCache, + ) + except ImportError: + raise ImportError( + f"CustomRingKVCache not available in version {version.__version__} of ExecuTorch." 
+ ) + + def _custom_sdpa_for_ring_kv_cache( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], # noqa + scaling: Optional[float] = None, + softcap: Optional[float] = None, + head_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, None]: + is_sliding = getattr(module, "is_sliding", False) + if is_sliding: + # lazy import to avoid being in the optimum import path + # for et <= 0.6.0 version + from optimum.executorch.attentions.custom_kv_cache import ( + ETCustomHybridCache, + ) + + layer_idx = module.layer_idx + assert ( + layer_idx is not None + ), "layer_idx is not set for sliding window attention." + hybrid_cache = exportable_module.model.cache + assert isinstance( + hybrid_cache, ETCustomHybridCache + ), f"Expected HybridCache, got {type(hybrid_cache)}" + ring_cache = hybrid_cache.get_layer_cache(layer_idx) + assert isinstance( + ring_cache, CustomRingKVCache + ), f"Expected CustomRingKVCache, got {type(ring_cache)}" + input_pos = hybrid_cache.cache_position[0].item() + seqlen = query.shape[2] + attention_mask = ring_cache.create_causal_mask_for_ring_buffer( + input_pos, seqlen + ) + kwargs.update({"is_sliding": True}) + return custom_sdpa_with_start_pos_forward( + module, + query, + key, + value, + attention_mask, + scaling, + softcap, + head_mask, + **kwargs, + ) + else: + return custom_sdpa_with_start_pos_forward( + module, + query, + key, + value, + attention_mask, + scaling, + softcap, + head_mask, + **kwargs, + ) + + return _custom_sdpa_for_ring_kv_cache diff --git a/extension/llm/optimum/image_text_to_text.py b/extension/llm/optimum/image_text_to_text.py new file mode 100644 index 00000000000..d8d44a2d1a2 --- /dev/null +++ b/extension/llm/optimum/image_text_to_text.py @@ -0,0 +1,142 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import logging + +import torch +import torchao +from packaging.version import parse +from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig + +from executorch.extension.llm.optimum.integrations import ImageTextToTextExportableModule + + +# NOTE: It’s important to map the registered task name to the pipeline name in https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py. +# This will streamline using inferred task names and make exporting models to Hugging Face pipelines easier. +def load_image_text_to_text_model(model_name_or_path: str, **kwargs): + """ + Loads a causal language model for image-to-text generation and registers it under the task + 'image-text-to-text' using Hugging Face's AutoModelForCausalLM. + + Args: + model_name_or_path (str): + Model ID on huggingface.co or path on disk to the model repository to export. For example: + `model_name_or_path="google/gemma-3-4b-it"` or `model_name_or_path="/path/to/model_folder` + **kwargs: + Additional configuration options for the model: + - dtype (str, optional): + Data type for model weights (default: "float32"). + Options include "float16" and "bfloat16". + - attn_implementation (str, optional): + Attention mechanism implementation (default: "sdpa"). + - cache_implementation (str, optional): + Cache management strategy (default: "static"). + - max_length (int, optional): + Maximum sequence length for generation (default: 2048). 
+ + Returns: + ImageTextToTextExportableModule: + An instance of `ImageTextToTextExportableModule` for exporting and lowering to ExecuTorch. + """ + device = "cpu" + batch_size = 1 + dtype = kwargs.get("dtype", "float32") + use_custom_sdpa = kwargs.get("use_custom_sdpa", False) + use_custom_kv_cache = kwargs.get("use_custom_kv_cache", False) + attn_implementation = kwargs.get("attn_implementation", "custom_sdpa" if use_custom_sdpa else "sdpa") + cache_implementation = kwargs.get("cache_implementation", "static") + use_custom_sdpa = use_custom_sdpa or attn_implementation == "custom_sdpa" + max_length = kwargs.get("max_length", 2048) + config = kwargs.get("config") or AutoConfig.from_pretrained(model_name_or_path) + + # Make sure config has text_config and vision_config: + if not hasattr(config, "text_config") or not hasattr(config, "vision_config"): + raise ValueError( + f"The model {model_name_or_path} does not have a `text_config` or `vision_config` attribute in its config. " + "This is required for image-text-to-text models." + ) + + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + # NOTE: To make the model exportable we need to set the rope scaling to default to avoid hitting + # the data-dependent control flow in _longrope_frequency_update. Alternatively, users should rewrite + # that function to avoid the data-dependent control flow. + config.rope_scaling["type"] = "default" + + if hasattr(config, "use_cache") and config.use_cache is False: + config.use_cache = True + + eager_model = AutoModelForCausalLM.from_pretrained( + model_name_or_path, + device_map=device, + torch_dtype=dtype, + config=config, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_length, + }, + ), + ) + + # Make sure model has language_model as well as vision_tower: + if not hasattr(eager_model, "language_model") or not hasattr(eager_model, "vision_tower"): + raise ValueError( + f"The model {model_name_or_path} does not have a `language_model` or `vision_tower` attribute. " + "This is required for image-text-to-text models." + ) + + for param in eager_model.parameters(): + # Must disable gradient for quantized checkpoint + if isinstance(param, torchao.utils.TorchAOBaseTensor): + param.requires_grad = False + + # TODO: Move quantization recipe out for better composability. + # TODO: Should switch to `TorchAoConfig` once the quant issue on final lm_head layer is fixed. + qlinear_config = kwargs.get("qlinear", None) + qembedding_config = kwargs.get("qembedding", None) + if qlinear_config or qembedding_config: + # TODO: Update torchao to use 0.11.0 once released + if parse(torchao.__version__) < parse("0.11.0.dev0"): + raise RuntimeError("Quantization 8da4w requires torchao >= 0.11.0. Please upgrade torchao.") + + from torchao.quantization.granularity import PerAxis, PerGroup + from torchao.quantization.quant_api import ( + Int8DynamicActivationIntxWeightConfig, + IntxWeightOnlyConfig, + quantize_, + ) + from torchao.utils import unwrap_tensor_subclass + + if qembedding_config: + logging.info("Quantizing embedding layers.") + # TODO: Should switch to `AOPerModuleConfig` once fix for tied weights is available. 
+ embedding_config = IntxWeightOnlyConfig( + weight_dtype=torch.int8, + granularity=PerAxis(0), + ) + quantize_( + eager_model, + embedding_config, + lambda m, fqn: isinstance(m, torch.nn.Embedding), + ) + + if qlinear_config: + logging.info("Quantizing linear layers.") + linear_config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, + weight_granularity=PerGroup(32), + ) + quantize_( + eager_model.language_model, + linear_config, + ) + + unwrap_tensor_subclass(eager_model) + + return ImageTextToTextExportableModule(eager_model, use_custom_kv_cache, use_custom_sdpa) diff --git a/extension/llm/optimum/integrations.py b/extension/llm/optimum/integrations.py new file mode 100644 index 00000000000..b199ea20e63 --- /dev/null +++ b/extension/llm/optimum/integrations.py @@ -0,0 +1,460 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +from typing import Dict, Optional + +import torch + +from executorch.extension.llm.optimum.custom_sdpa import ( + get_custom_sdpa_for_ring_kv_cache, +) + +from executorch.extension.llm.optimum.utils import save_config_to_constant_methods +from packaging.version import parse +from torch.export import ExportedProgram +from transformers import PreTrainedModel +from transformers.cache_utils import HybridCache +from transformers.integrations.executorch import sdpa_mask_without_vmap +from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + +class TorchExportableModuleWithHybridCache(torch.nn.Module): + """ + A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`, + specifically for decoder-only LM to `HybridCache`. This module ensures that the + exported model is compatible with further lowering and execution in `ExecuTorch`. + """ + + def __init__( + self, + model: PreTrainedModel, + max_batch_size: int = 1, + max_cache_len: int = 4096, + ): + """ + Initializes the exportable module with `HybridCache`. + + Args: + model (`PreTrainedModel`): The pretrained model to wrap. + max_batch_size (int): Maximum batch size for the cache. + max_cache_len (int): Maximum sequence length for the cache. + + Raises: + AssertionError: If the model doesn't have the expected configuration for HybridCache. + """ + super().__init__() + self.model = model + + # Verify the model is configured for HybridCache + if not self.model.config.text_config.use_cache: + raise AssertionError("Model must have caching enabled") + + # Initialize the HybridCache + self.cache = HybridCache( + config=self.model.config.text_config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=self.model.device, + dtype=self.model.dtype, + ) + + # Register all key and value cache tensors as buffers + for i in range(len(self.cache.key_cache)): + self.register_buffer( + f"key_cache_{i}", self.cache.key_cache[i], persistent=False + ) + self.register_buffer( + f"value_cache_{i}", self.cache.value_cache[i], persistent=False + ) + + def forward( + self, + *, + cache_position: torch.Tensor, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + """ + Forward pass of the module, which is compatible with the ExecuTorch llm runner. + + Args: + cache_position (`torch.Tensor`): Tensor representing current input position in the cache. 
+ input_ids (`torch.Tensor`): Tensor representing current input token id to the module. + inputs_embeds (`torch.Tensor`): Optional tensor representing input embeddings. + + Returns: + torch.Tensor: Logits output from the model. + """ + batch_size = ( + input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + ) + + # Generate position_ids from cache_position + position_ids = cache_position.unsqueeze(0).expand(batch_size, -1) + + # Forward pass with the model + outputs = self.model( + input_ids=input_ids, + attention_mask=None, + position_ids=position_ids, + past_key_values=self.cache, + use_cache=True, + cache_position=cache_position, + inputs_embeds=inputs_embeds, + ) + + # Return only the logits to simplify the export + return outputs.logits + + +class TorchExportableModuleForImageTextLM(torch.nn.Module): + """ + A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`, + specifically for image-text LM with cache. This module ensures that the + exported model is compatible with further lowering and execution in `ExecuTorch`. + """ + + def __init__( + self, + model: PreTrainedModel, + max_batch_size: int = 1, + max_cache_len: int = 4096, + ): + """ + Initializes the exportable module with `HybridCache`. + + Args: + model (`PreTrainedModel`): The pretrained model to wrap. + max_batch_size (int): Maximum batch size for the cache. + max_cache_len (int): Maximum sequence length for the cache. + + Raises: + ValueError: If the model is configured with a unsupported cache implementation. + """ + super().__init__() + + if ( + not hasattr(model.config.text_config, "use_cache") + or model.config.text_config.use_cache is False + ): + raise ValueError("The model must have caching enabled to be performant.") + + if ( + hasattr(model.config.text_config, "layer_types") + and getattr(model.config.text_config, "sliding_window", None) is not None + ): + self.model = TorchExportableModuleWithHybridCache( + model, max_batch_size, max_cache_len + ) + else: + # If `layer_types` is not specified explicitly in the config or `sliding_window` is null, + # there is only 1 type of layers, so export will use `StaticCache` by default. + raise NotImplementedError( + "Using `StaticCache` for exporting image-text LM is not implemented yet." + ) + # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable + ALL_MASK_ATTENTION_FUNCTIONS.register( + "sdpa_without_vmap", sdpa_mask_without_vmap + ) + ALL_ATTENTION_FUNCTIONS.register( + "sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"] + ) + self.model.model.config._attn_implementation = "sdpa_without_vmap" + + def forward( + self, + input_embeds: torch.Tensor, + cache_position: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass of the module, which is compatible with the ExecuTorch llm runner. + + Args: + input_embeds (`torch.Tensor`): Tensor representing current input embeddings to the module. + cache_position (`torch.Tensor`): Tensor representing current input position in the cache. + + Returns: + torch.Tensor: Logits output from the model. + """ + return self.model.forward(input_embeds, cache_position) + + def export( + self, + inputs_embeds: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, + dynamic_shapes: Optional[dict] = None, + strict: Optional[bool] = None, + ) -> torch.export.ExportedProgram: + """ + Export the wrapped module using `torch.export`. 
+ + Args: + input_embeds (`Optional[torch.Tensor]`): + Tensor representing current input embeddings to the module. If not provided, a default tensor will be used. + cache_position (`Optional[torch.Tensor]`): + Tensor representing current input position in the cache. If not provided, a default tensor will be used. + dynamic_shapes (`Optional[dict]`): + Dynamic shapes to use for export if specified. + strict(`Optional[bool]`): + Flag to instruct `torch.export` to use `torchdynamo`. + """ + seq_length = 3 + + if dynamic_shapes is None: + seq_len_dim = torch.export.Dim("seq_length_dim", max=seq_length) + dynamic_shapes = { + "inputs_embeds": {1: seq_len_dim}, + "cache_position": {0: seq_len_dim}, + } + + exported_program = torch.export.export( + self.model, + args=(), + kwargs={"cache_position": cache_position, "inputs_embeds": inputs_embeds}, + dynamic_shapes=dynamic_shapes, + strict=strict if strict is not None else True, + ) + return exported_program + + +class ImageEncoderExportableModule(torch.nn.Module): + """ + A wrapper module designed to make a vision encoder-only model exportable with `torch.export`. + This module ensures that the exported model is compatible with ExecuTorch. + """ + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, pixel_values): + """ + Projects the last hidden state from the vision model into language model space. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + vision_outputs = self.model.vision_tower( + pixel_values=pixel_values + ).last_hidden_state + image_features = self.model.multi_modal_projector(vision_outputs) + return image_features + + +class ImageTextToTextExportableModule(torch.nn.Module): + """ + A wrapper module designed to make an image-text-to-text model exportable with `torch.export`. + This module ensures that the exported model is compatible with ExecuTorch. + """ + + def __init__(self, model, use_custom_kv_cache=False, use_custom_sdpa=False): + super().__init__() + self.model = model + self.config = model.config + self.use_custom_kv_cache = use_custom_kv_cache + self.use_custom_sdpa = use_custom_sdpa + self.metadata = save_config_to_constant_methods( + model.config.text_config, model.generation_config + ) + logging.info(f"Metadata to be recorded in PTE: {self.metadata}") + + def _prepare_vision_embedding_export_inputs(self): + """ + Prepare example inputs and configurations for export. + + Returns: + pixel_values (torch.Tensor): Example pixel values tensor. + dynamic_shapes (dict or None): Dynamic shape specifications for export. + strict (bool): Whether to use strict export mode. + """ + image_size = self.config.vision_config.image_size + pixel_values = torch.rand((1, 3, image_size, image_size)) + dynamic_shapes = None + strict = False + + return pixel_values, dynamic_shapes, strict + + def _prepare_text_embedding_export_inputs(self): + """ + Prepare example inputs and configurations for export. + + Returns: + input_ids (torch.Tensor): Example input IDs tensor. + dynamic_shapes (dict or None): Dynamic shape specifications for export. + strict (bool): Whether to use strict export mode. 
+ """ + # Prepare inputs with dynamic shapes + seq_length = 3 # Sequence length > 1 to avoid specialization issues + example_input_ids = torch.zeros((1, seq_length), dtype=torch.long) + max_seq_len = self.metadata.get("get_max_seq_len") + sliding_window = self.metadata.get("sliding_window", float("inf")) + max_dim = min(max_seq_len, sliding_window) - 1 + seq_len_dim = torch.export.Dim("seq_length_dim", max=max_dim) + dynamic_shapes = { + "input_ids": {1: seq_len_dim}, + } + strict = parse(torch.__version__) != parse( + "2.7.0" + ) # Workaround for PyTorch bug #150994 + return example_input_ids, dynamic_shapes, strict + + def _prepare_decoder_only_export_inputs(self): + """ + Prepare example inputs and configurations for export. + + Returns: + inputs_embeds (torch.Tensor): Example input embeddings tensor. + cache_position (torch.Tensor): Example cache position tensor. + dynamic_shapes (dict or None): Dynamic shape specifications for export. + strict (bool): Whether to use strict export mode. + """ + + # Prepare inputs with dynamic shapes + seq_length = 3 + example_inputs_embeds = torch.zeros( + (1, seq_length, self.config.text_config.hidden_size), dtype=torch.float + ) + example_cache_position = torch.arange(seq_length, dtype=torch.long) + max_seq_len = self.metadata.get("get_max_seq_len") + sliding_window = self.metadata.get("sliding_window", float("inf")) + max_dim = min(max_seq_len, sliding_window) - 1 + seq_len_dim = torch.export.Dim("seq_length_dim", max=max_dim) + dynamic_shapes = { + "inputs_embeds": {1: seq_len_dim}, + "cache_position": {0: seq_len_dim}, + } + strict = parse(torch.__version__) != parse( + "2.7.0" + ) # Workaround for PyTorch bug #150994 + return example_inputs_embeds, example_cache_position, dynamic_shapes, strict + + def _register_attention_mask_for_4_53(self, exportable_module: torch.nn.Module): + from transformers.integrations.executorch import sdpa_mask_without_vmap + from transformers.masking_utils import AttentionMaskInterface + from transformers.modeling_utils import AttentionInterface + + _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache( + exportable_module + ) + if self.use_custom_sdpa: + if self.use_custom_kv_cache: + AttentionInterface.register( + "custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache + ) + AttentionMaskInterface.register( + "custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap + ) + # Manually set the attention implementation to custom_sdpa_ring_kv_cache + # This handles both regular sdpa and one for sliding window/local attention + exportable_module.model.model.config._attn_implementation = ( + "custom_sdpa_ring_kv_cache" + ) + else: + # Manually set the attention implementation to custom_sdpa_ring_kv_cache + # This handles both regular sdpa and one for sliding window/local attention + exportable_module.model.model.config._attn_implementation = ( + "custom_sdpa" + ) + + def export( + self, + ) -> Dict[str, ExportedProgram]: + + exportable_module = TorchExportableModuleForImageTextLM( + self.model, + max_batch_size=1, + max_cache_len=self.metadata.get("get_max_seq_len"), + ) + self._register_attention_mask_for_4_53(exportable_module) + + if self.use_custom_kv_cache: + from executorch.extension.llm.optimum.custom_kv_cache import ( + replace_with_et_custom_kv_cache, + ) + + replace_with_et_custom_kv_cache( + exportable_module.model, + self.model.config.text_config, + self.model.generation_config, + self.model.dtype, + ) + + with torch.no_grad(): + inputs_embeds, cache_position, dynamic_shapes, strict = ( + 
self._prepare_decoder_only_export_inputs() + ) + logging.info( + f"Exporting decoder using inputs_embeds({inputs_embeds.shape}), cache_position({cache_position.shape})={cache_position}, dynamic_shapes={dynamic_shapes}, strict={strict}" + ) + exported_program = exportable_module.export( + inputs_embeds, cache_position, dynamic_shapes, strict + ) + # Apply RemoveTransposes pass to remove + # any back-to-back transpose ops that are not needed + # e.g. output of update_cache is transposed and + # input to custom_sdpa is transposed. + from executorch.extension.llm.export.export_passes import ( + RemoveRedundantTransposes, + ) + + mutated_gm = RemoveRedundantTransposes()(exported_program.module())[0] + exported_program = torch.export.export( + mutated_gm, + args=(), + kwargs={ + "cache_position": cache_position, + "inputs_embeds": inputs_embeds, + }, + dynamic_shapes=dynamic_shapes, + strict=strict, + ) + + # Export token embeddings + input_ids, dynamic_shapes, strict = ( + self._prepare_text_embedding_export_inputs() + ) + logging.info( + f"Exporting token embeddings using input_ids({input_ids.shape}), dynamic_shapes={dynamic_shapes}, strict={strict}" + ) + + token_embeddings_exported_program = torch.export.export( + exportable_module.model.model.language_model.get_input_embeddings(), + args=(input_ids,), + kwargs={}, + dynamic_shapes=dynamic_shapes, + strict=strict, + ) + + # Export vision embeddings + pixel_values, dynamic_shapes, strict = ( + self._prepare_vision_embedding_export_inputs() + ) + logging.info( + f"Exporting vision embeddings using pixel_values({pixel_values.shape}), dynamic_shapes={dynamic_shapes}, strict={strict}" + ) + # Setting the _attn_implementation to "sdpa_without_vmap" for vision encoder + exportable_module.model.model.vision_tower.config._attn_implementation = ( + "sdpa_without_vmap" + ) + vision_encoder = ImageEncoderExportableModule(exportable_module.model.model) + vision_embeddings_exported_program = torch.export.export( + vision_encoder, + args=(pixel_values,), + kwargs={}, + dynamic_shapes=dynamic_shapes, + strict=strict, + ) + return { + "text_model": exported_program, + "token_embedding": token_embeddings_exported_program, + "image_encoder": vision_embeddings_exported_program, + } diff --git a/extension/llm/optimum/modeling.py b/extension/llm/optimum/modeling.py new file mode 100644 index 00000000000..f6196bf7932 --- /dev/null +++ b/extension/llm/optimum/modeling.py @@ -0,0 +1,181 @@ +import logging +from typing import Optional + +import torch +from executorch.extension.pybindings.portable_lib import ExecuTorchModule +from transformers import AutoModelForCausalLM, PretrainedConfig, PretrainedTokenizer + +logger = logging.getLogger(__name__) + + +class ExecuTorchModelForImageTextToTextCausalLM: + """ + ExecuTorch model with an image-text-to-text causal language modeling head for inference using the ExecuTorch Runtime. + + Although the auto_model_class is `AutoModelForCausalLM` same as `ExecuTorchModelForCausalLM`, this model is specifically designed for + image-text-to-text tasks. This class provides an interface for loading, running, and generating outputs from a vision-language model + optimized for ExecuTorch Runtime. It includes utilities for exporting and loading pre-trained models + compatible with ExecuTorch runtime. + + Attributes: + auto_model_class (`Type`): + Associated Transformers class, `AutoModelForCausalLM`. + model (`ExecuTorchModule`): + The loaded ExecuTorch model. 
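+        config (`PretrainedConfig`):
+            The model configuration; it must expose `vision_config` and `text_config`.
+
+    Example (illustrative sketch mirroring the README usage; the model id, flags,
+    and the `tokenizer`/`inputs` objects are placeholders):
+
+        module = load_image_text_to_text_model(
+            "google/gemma-3-4b-it",
+            use_custom_sdpa=True,
+            use_custom_kv_cache=True,
+            qlinear=True,
+            qembedding=True,
+        )
+        prog = export_to_executorch_with_xnnpack(module)
+        et_model = ExecuTorchModelForImageTextToTextCausalLM(
+            prog, PretrainedConfig.from_pretrained("google/gemma-3-4b-it")
+        )
+        text = et_model.generate(
+            tokenizer,
+            input_ids=inputs["input_ids"],
+            pixel_values=inputs["pixel_values"],
+            max_new_tokens=50,
+        )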
+ """ + + auto_model_class = AutoModelForCausalLM + + task = "image-text-to-text" + + def __init__(self, model: "ExecuTorchModule", config: "PretrainedConfig"): + if self.__class__.auto_model_class is None: + raise ValueError( + f"Class {self.__class__.__name__} must set auto_model_class. " + f"This attribute is used to identify the corresponding AutoModel class." + ) + + self.model = model + self.config = config + + # Make sure config contains vision_config and text_config, otherwise raise an error + if not hasattr(config, "vision_config") or not hasattr(config, "text_config"): + raise ValueError( + "The configuration must contain 'vision_config' and 'text_config' attributes for image-text-to-text task." + ) + metadata = self.model.method_names() + logging.debug(f"Load all static methods: {metadata}") + if "use_kv_cache" in metadata: + self.use_kv_cache = self.model.run_method("use_kv_cache")[0] + if "get_max_seq_len" in metadata: + self.max_cache_size = self.model.run_method("get_max_seq_len")[0] + if "get_max_batch_size" in metadata: + self.max_batch_size = self.model.run_method("get_max_batch_size")[0] + if "get_dtype" in metadata: + self.dtype = self.model.run_method("get_dtype")[0] + if "get_bos_id" in metadata: + self.bos_token_id = self.model.run_method("get_bos_id")[0] + for key in ("get_eos_id", "get_eos_ids"): + if key in metadata: + self.eos_token_ids = self.model.run_method(key) + break + if "get_vocab_size" in metadata: + self.vocab_size = self.model.run_method("get_vocab_size")[0] + if "use_sdpa_with_kv_cache" in metadata: + self.use_sdpa_with_kv_cache = self.model.run_method( + "use_sdpa_with_kv_cache" + )[0] + + def forward( + self, + cache_position: torch.LongTensor, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + """ + Forward pass of the model, which is compatible with the ExecuTorch runtime for LLM. Here we are assuming pixel_values only represent 1 image. + + Args: + input_ids (`torch.Tensor`): Tensor representing current input token id to the model. + pixel_values (`torch.Tensor`): Tensor representing image input to the model. + cache_position (`torch.Tensor`): Tensor representing current input position in the cache. + + Returns: + torch.Tensor: Logits output from the model. 
+ """ + if (input_ids is None) and (pixel_values is None): + raise ValueError( + "You must specify at least one of input_ids or pixel_values" + ) + + inputs_embeds = self.model.run_method("token_embeddings", (input_ids,))[0] + + if pixel_values is not None: + image_features = self.model.run_method( + "vision_embeddings", (pixel_values,) + )[0] + + if input_ids is None: + special_image_mask = ( + inputs_embeds + == self.model.run_method( + "token_embeddings", + ( + torch.tensor( + self.config.image_token_id, + dtype=torch.long, + device=inputs_embeds.device, + ), + ), + )[0] + ) + else: + special_image_mask = ( + input_ids == self.config.image_token_id + ).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to( + inputs_embeds.device + ) + image_features = image_features.to( + inputs_embeds.device, inputs_embeds.dtype + ) + inputs_embeds = inputs_embeds.masked_scatter( + special_image_mask, image_features + ) + + logits = self.model.run_method("decoder", (cache_position, inputs_embeds))[0] + return logits + + def generate( + self, + tokenizer: "PretrainedTokenizer", + input_ids: torch.LongTensor, + pixel_values: Optional[torch.FloatTensor] = None, + max_new_tokens: int = 100, + ): + # Sanity check + + if max_new_tokens <= 0: + raise ValueError( + f"max_new_tokens must be greater than 0, got {max_new_tokens}." + ) + elif max_new_tokens > self.max_cache_size: + logging.warning( + f"max_new_tokens={max_new_tokens} is larger than max_cache_size={self.max_cache_size}. Generating tokens will be truncated to max_cache_size." + ) + max_new_tokens = self.max_cache_size + + # Prefill + logits = self.forward( + input_ids=input_ids, + pixel_values=pixel_values, + cache_position=torch.arange( + input_ids.size(1), dtype=torch.long, device=input_ids.device + ), + ) + + tokens = [] + + token = torch.argmax(logits[:, -1, :], dim=-1).item() + tokens.append(token) + i = 1 + while i < max_new_tokens: + # Generate next token + logits = self.forward( + input_ids=torch.tensor( + [token], dtype=torch.long, device=input_ids.device + ).unsqueeze(0), + cache_position=torch.tensor( + [input_ids.size(1) + i - 1], + dtype=torch.long, + device=input_ids.device, + ), + ) + token = torch.argmax(logits[:, -1, :], dim=-1).item() + tokens.append(token) + + if token in self.eos_token_ids: + break + i += 1 + + return tokenizer.decode(tokens, skip_special_tokens=True) diff --git a/extension/llm/optimum/test/test_modeling_gemma3.py b/extension/llm/optimum/test/test_modeling_gemma3.py new file mode 100644 index 00000000000..7943a0194f0 --- /dev/null +++ b/extension/llm/optimum/test/test_modeling_gemma3.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +from executorch.extension.llm.optimum.image_text_to_text import ( + load_image_text_to_text_model, +) +from executorch.extension.llm.optimum.modeling import ( + ExecuTorchModelForImageTextToTextCausalLM, +) +from executorch.extension.llm.optimum.xnnpack import export_to_executorch_with_xnnpack +from transformers import AutoProcessor, AutoTokenizer, PretrainedConfig + + +class ExecuTorchModelIntegrationTest(unittest.TestCase): + def test_gemma3_image_text_to_text_generation_with_custom_sdpa_kv_cache_8da4w_8we( + self, + ): + + model_id = "google/gemma-3-4b-it" + + module = load_image_text_to_text_model( + model_id, + use_custom_sdpa=True, + use_custom_kv_cache=True, + qlinear=True, + qembedding=True, + ) + model = export_to_executorch_with_xnnpack(module) + et_model = ExecuTorchModelForImageTextToTextCausalLM( + model, PretrainedConfig.from_pretrained(model_id) + ) + # Generate + image_url = "https://llava-vl.github.io/static/images/view.jpg" + conversation = [ + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant."}], + }, + { + "role": "user", + "content": [ + {"type": "image", "url": image_url}, + { + "type": "text", + "text": "What are the things I should be cautious about when I visit here?", + }, + ], + }, + ] + processor = AutoProcessor.from_pretrained(model_id) + inputs = processor.apply_chat_template( + conversation, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + output = et_model.generate( + AutoTokenizer.from_pretrained(model_id), + input_ids=inputs["input_ids"], + pixel_values=inputs["pixel_values"], + max_new_tokens=50, + ) + self.assertEqual( + output, + """Okay, let's analyze the image and discuss potential cautions for visiting this location. + +Based on the picture, we're looking at a serene lake scene with mountains in the background, a wooden pier, and a generally calm appearance. However""", + ) diff --git a/extension/llm/optimum/utils.py b/extension/llm/optimum/utils.py new file mode 100644 index 00000000000..88dc08adb2c --- /dev/null +++ b/extension/llm/optimum/utils.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Optional + +import torch +from transformers import GenerationConfig, PretrainedConfig + + +def save_config_to_constant_methods( + config: PretrainedConfig, + generation_config: Optional[GenerationConfig] = None, + **kwargs, +): + # Initialize metadata with values from model config + head_dim = None + if ( + hasattr(config, "hidden_size") + and hasattr(config, "num_attention_heads") + and isinstance(config.num_attention_heads, int) + ): + head_dim = config.hidden_size / config.num_attention_heads + + metadata = { + "get_dtype": 5 if config.torch_dtype == torch.float16 else 6, + "get_bos_id": getattr(config, "bos_token_id", None), + "get_eos_id": getattr(config, "eos_token_id", None), + "get_head_dim": head_dim, + "get_n_kv_heads": getattr(config, "num_key_value_heads", None), + "get_n_layers": getattr(config, "num_hidden_layers", None), + "get_vocab_size": getattr(config, "vocab_size", None), + "get_max_batch_size": 1, + "get_max_seq_len": getattr(config, "max_position_embeddings", None), + "use_kv_cache": getattr(generation_config, "use_cache", None), + "sliding_window": getattr(config, "sliding_window", None), + "decoder_start_token_id": getattr(config, "decoder_start_token_id", None), + "use_sdpa_with_kv_cache": "custom_sdpa" in config._attn_implementation, + } + + # Safely access fields from generation_config if it exists + if generation_config is not None: + # Check for cache_config and its attributes + cache_config = getattr(generation_config, "cache_config", None) + if cache_config is not None: + max_batch_size = getattr(cache_config, "batch_size", None) + max_seq_len = getattr(cache_config, "max_cache_len", None) + + if max_batch_size is not None: + metadata["get_max_batch_size"] = max_batch_size + if max_seq_len is not None: + metadata["get_max_seq_len"] = max_seq_len + + # Combine with any additional kwargs and filter out None values + return {k: v for k, v in {**metadata, **kwargs}.items() if v is not None} diff --git a/extension/llm/optimum/xnnpack.py b/extension/llm/optimum/xnnpack.py new file mode 100644 index 00000000000..7d7577395ed --- /dev/null +++ b/extension/llm/optimum/xnnpack.py @@ -0,0 +1,134 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from typing import Dict + +import torch + +from executorch import version as executorch_version +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner +from executorch.devtools.backend_debug import get_delegation_info +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + ExecutorchProgram, + to_edge_transform_and_lower, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import MemoryPlanningPass + +from executorch.extension.llm.optimum.integrations import ( + ImageTextToTextExportableModule, +) + +from packaging.version import parse +from tabulate import tabulate +from torch.export import ExportedProgram + + +class RemovePaddingIdxEmbeddingPass(ExportPass): + """ + An ExportPass that removes the `padding_idx` keyword argument + from all aten.embedding.default operator calls. + """ + + def __init__(self) -> None: + super().__init__() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + for node in graph_module.graph.nodes: + if ( + node.op == "call_function" + and node.target == exir_ops.edge.aten.embedding.default + ): + # node.args[2] is the padding_idx + if len(node.args) == 3: + node.args = (node.args[0], node.args[1]) + graph_module.recompile() + return PassResult(graph_module, True) + + +def export_to_executorch_with_xnnpack( + model: ImageTextToTextExportableModule, + **kwargs, +): + """ + Export a PyTorch model to ExecuTorch w/ delegation to XNNPACK backend. + + This function also write metadata required by the ExecuTorch runtime to the model. + + Args: + model (Union[CausalLMExportableModule, MaskedLMExportableModule, Seq2SeqLMExportableModule, ImageTextToTextExportableModule]): + The PyTorch model to be exported to ExecuTorch. + **kwargs: + Additional keyword arguments for recipe-specific configurations, e.g. export using different example inputs, or different compile/bechend configs. + + Returns: + Dict[str, ExecutorchProgram]: + A map of exported and optimized program for ExecuTorch. + For encoder-decoder models or multimodal models, it may generate multiple programs. 
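+
+    Example (illustrative sketch following the test in
+    extension/llm/optimum/test/test_modeling_gemma3.py; the model id and flags are
+    just one possible configuration):
+
+        module = load_image_text_to_text_model(
+            "google/gemma-3-4b-it",
+            use_custom_sdpa=True,
+            use_custom_kv_cache=True,
+            qlinear=True,
+            qembedding=True,
+        )
+        prog = export_to_executorch_with_xnnpack(module)
+        # The lowered program is keyed by the config's `model_type`, e.g. "gemma3".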
+ """ + + def _lower_to_executorch( + exported_programs: Dict[str, ExportedProgram], + metadata=None, + ) -> Dict[str, ExecutorchProgram]: + backend_config_dict = { + "extract_delegate_segments": True, + "memory_planning_pass": MemoryPlanningPass(alloc_graph_input=False), + } + if parse(executorch_version.__version__).base_version > "0.6.0": + backend_config_dict["do_quant_fusion_and_const_prop"] = True + pte_name = model.model.config.model_type + logging.debug(f"\nExported program for {pte_name}.pte: {exported_programs}") + et_prog = to_edge_transform_and_lower( + exported_programs, + partitioner=[XnnpackPartitioner()], + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + constant_methods=metadata, + transform_passes=[RemovePaddingIdxEmbeddingPass()], + ).to_executorch( + config=ExecutorchBackendConfig(**backend_config_dict), + ) + delegation_info = get_delegation_info( + et_prog.exported_program(list(exported_programs.keys())[0]).graph_module + ) + logging.debug( + f"\nDelegation info Summary for {pte_name}.pte: {delegation_info.get_summary()}" + ) + logging.debug( + f"\nDelegation info for {pte_name}.pte: {tabulate(delegation_info.get_operator_delegation_dataframe(), headers='keys', tablefmt='fancy_grid')}" + ) + return {pte_name: et_prog} + + exported_progs = model.export() + + if ( + model.config._attn_implementation == "custom_sdpa" + or model.config._attn_implementation == "custom_sdpa_ring_kv_cache" + ): + # Sanity check to make sure the exported program contains the custom sdpa operator. + if not any( + node.op == "call_function" and "custom_sdpa" in str(node.target) + for exported_program in exported_progs.values() + for node in exported_program.graph_module.graph.nodes + ): + raise ValueError("'custom_sdpa' not found in the graph.") + + return _lower_to_executorch(exported_progs, model.metadata) diff --git a/pyproject.toml b/pyproject.toml index 40ff4eb0465..99cac4ea37f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,7 @@ dependencies=[ "ruamel.yaml", "sympy", "tabulate", + "transformers==4.53.1", # See also third-party/TARGETS for buck's typing-extensions version. "typing-extensions>=4.10.0", # Keep this version in sync with: ./backends/apple/coreml/scripts/install_requirements.sh diff --git a/requirements-examples.txt b/requirements-examples.txt index 0923cf8fefc..8640a024217 100644 --- a/requirements-examples.txt +++ b/requirements-examples.txt @@ -4,4 +4,3 @@ datasets == 3.6.0 # 4.0.0 deprecates trust_remote_code and load scripts. For now timm == 1.0.7 torchsr == 1.0.4 torchtune >= 0.6.1 -transformers == 4.53.1