diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 7cfd0ac5fc6..cbdbaa7ee3b 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -797,6 +797,8 @@ jobs: --etdump_path ${OUTPUT_DIR}/etdump.etdp \ --tsv_path ${TSV_PATH} + echo "::group::Run Multimodal Tests" + python3 -m unittest extension/llm/optimum/test/test_modeling_gemma3.py echo "::endgroup::" diff --git a/extension/llm/optimum/README.md b/extension/llm/optimum/README.md new file mode 100644 index 00000000000..d9ee1cc3e32 --- /dev/null +++ b/extension/llm/optimum/README.md @@ -0,0 +1,99 @@ +# ExecuTorch Optimum Module + +This module provides ExecuTorch-specific optimizations and integrations for transformer models. It focuses on runtime-specific features that are not available in the upstream transformers or optimum-executorch libraries. + +## Overview + +This streamlined module contains only ExecuTorch-specific components: + +- Custom cache implementations optimized for ExecuTorch runtime +- Custom SDPA implementations for ExecuTorch operators +- XNNPACK backend integration and optimization passes +- ExecuTorch-specific utilities + +For general model export functionality, use `optimum-executorch` which provides a comprehensive recipe system and CLI interface. + +## Key Components + +### Custom Cache Implementations + +#### `ETCustomStaticCache` and `ETCustomHybridCache` +Custom KV cache implementations that inherit from Hugging Face's caches but use ExecuTorch's `CustomKVCache` and `CustomRingKVCache` for optimal runtime performance. + +### Custom SDPA + +#### `get_custom_sdpa_for_ring_kv_cache` +Custom Scaled Dot-Product Attention implementation optimized for ExecuTorch's ring buffer caches and sliding window attention. + +### XNNPACK Integration + +#### `export_to_executorch_with_xnnpack` +ExecuTorch-specific XNNPACK backend integration with custom optimization passes: +- `RemovePaddingIdxEmbeddingPass`: Removes padding_idx from embedding operations +- Memory planning and quantization optimizations +- Backend delegation analysis and debugging + +### Utilities + +- `save_config_to_constant_methods`: ExecuTorch-specific configuration utilities +- Model metadata extraction for runtime optimization + +## Usage + +For multimodal model export, use optimum-executorch: + +```bash +# Export with optimum-executorch CLI +optimum-cli export executorch \ + --model google/gemma-3-4b-it \ + --task image-text-to-text \ + --recipe xnnpack \ + --use_custom_sdpa \ + --use_custom_kv_cache +``` + +```python +# Or via Python API +from optimum.executorch import ExecuTorchModelForCausalLM + +model = ExecuTorchModelForCausalLM.from_pretrained( + "google/gemma-3-4b-it", + task="image-text-to-text", + recipe="xnnpack", + use_custom_sdpa=True, + use_custom_kv_cache=True +) +``` + +For ExecuTorch-specific XNNPACK optimizations: + +```python +from optimum.exporters.executorch.integrations import ImageTextToTextExportableModule +from executorch.extension.llm.optimum.xnnpack import export_to_executorch_with_xnnpack + +# Load model using optimum-executorch +module = ImageTextToTextExportableModule(model, use_custom_kv_cache=True, use_custom_sdpa=True) + +# Apply ExecuTorch-specific XNNPACK optimizations +executorch_program = export_to_executorch_with_xnnpack(module) +``` + +## Architecture + +This module follows the recommended approach: +1. **General export functionality**: Use `optimum-executorch` +2. **Multimodal support**: Enhanced `transformers.integrations.executorch` +3. 
**ExecuTorch-specific optimizations**: This module + +This separation ensures: +- No code duplication between repositories +- Leverages mature optimum-executorch infrastructure +- Focuses ExecuTorch module on runtime-specific optimizations +- Maintains unified user experience through optimum-executorch CLI/API + +## Testing + +Run tests with: +```bash +python -m pytest extension/llm/optimum/test/ +``` diff --git a/extension/llm/optimum/__init__.py b/extension/llm/optimum/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/extension/llm/optimum/custom_kv_cache.py b/extension/llm/optimum/custom_kv_cache.py new file mode 100644 index 00000000000..538772df43d --- /dev/null +++ b/extension/llm/optimum/custom_kv_cache.py @@ -0,0 +1,421 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, Optional, Tuple, Union + +import torch + + +# If transformers is not installed, raise an ImportError +try: + from transformers.cache_utils import HybridCache, StaticCache +except ImportError: + raise ImportError( + "transformers is not installed. Please install it to use Static/HybridCache." + ) + +try: + from executorch.examples.models.llama.source_transformation.custom_kv_cache import ( + CustomKVCache, + CustomRingKVCache, + ) +except ImportError: + raise ImportError( + "ExecutorTorch is not installed. Please install it to use Custom Cache." + ) + + +class ETCustomStaticCache(StaticCache): + """ + Custom KV Cache implementation for ExecutorTorch that inherits from Hugging Face's StaticCache + but uses custom operations for cache updates similar to ExecutorTorch's CustomStaticCache. + """ + + def __init__( + self, + config, + max_batch_size: int, + max_cache_len: Optional[int] = None, + device: Union[torch.device, str, None] = None, + dtype: torch.dtype = torch.float32, + layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, + ): + super().__init__( + config=config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=device, + dtype=dtype, + layer_device_map=layer_device_map, + ) + + # make sure layer_device_map is none + assert layer_device_map is None + assert device is None or device == "cpu", "Device must be None or 'cpu'" + + # Create a list of CustomKVCache instances, one per layer + self.kv_cache = torch.nn.ModuleList() + for _ in range(config.num_hidden_layers): + layer_cache = CustomKVCache( + max_batch_size=self.max_batch_size, + max_context_length=self.max_cache_len, + n_heads=self.num_key_value_heads, + head_dim=self.head_dim, + dtype=dtype, + ) + self.kv_cache.append(layer_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx` + using ExecutorTorch's CustomKVCache. + + Args: + key_states (`torch.Tensor`): + The new key states to cache. Shape: [batch_size, n_heads, seq_len, head_dim] + value_states (`torch.Tensor`): + The new value states to cache. Shape: [batch_size, n_heads, seq_len, head_dim] + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache update. 
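+                Must include `cache_position` (a tensor of the positions being written),
+                as provided by StaticCache; the update below asserts that it is present.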
+ + Returns: + A tuple containing the updated key and value states. + """ + assert cache_kwargs is not None + + # Get cache position from cache_kwargs (used by StaticCache) + cache_position = cache_kwargs.get("cache_position") + assert cache_position is not None + assert isinstance(cache_position, torch.Tensor) + + # Get the CustomKVCache instance for this layer + layer_cache = self.kv_cache[layer_idx] + + # Use the CustomKVCache's update method + # CustomKVCache expects input_pos, k_val, v_val and handles the transpose internally + k_out, v_out = layer_cache.update( + input_pos=cache_position, + k_val=key_states, + v_val=value_states, + ) + + return k_out, v_out + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # Occupied cache == any slot in the 2nd dim (sequence length) holds a non-zero value + # This is different from StaticCache which checks the 3rd dim + if layer_idx is None: + layer_idx = 0 + return (self.kv_cache[layer_idx].k_cache[0, :, 0].any(dim=-1)).sum() + + @classmethod + def from_legacy_cache( + cls, + config, + legacy_cache, + max_cache_len=None, + device=None, + dtype=None, + ): + """ + Create an ETCustomStaticCache from a legacy cache implementation. + + Args: + config: The model configuration + legacy_cache: The legacy cache implementation + max_cache_len: The maximum cache length + device: The device for the new cache + dtype: The data type for the new cache + + Returns: + A new ETCustomStaticCache instance + """ + assert hasattr(legacy_cache, "k_cache") and hasattr(legacy_cache, "v_cache") + # Extract dimensions from the legacy cache + assert len(legacy_cache.k_cache.shape) == 4 + if legacy_cache.k_cache.shape[1] == legacy_cache.n_heads: + # Shape is [batch_size, n_heads, seq_len, head_dim] + max_batch_size = legacy_cache.k_cache.shape[0] + else: + # Shape is [batch_size, seq_len, n_heads, head_dim] + max_batch_size = legacy_cache.k_cache.shape[0] + + # Use the legacy cache's device and dtype if not specified + if device is None and hasattr(legacy_cache, "device"): + device = legacy_cache.device + elif device is None and hasattr(legacy_cache.k_cache, "device"): + device = legacy_cache.k_cache.device + + if dtype is None and hasattr(legacy_cache, "dtype"): + dtype = legacy_cache.dtype + elif dtype is None and hasattr(legacy_cache.k_cache, "dtype"): + dtype = legacy_cache.k_cache.dtype + + assert device is None or device == "cpu" + assert dtype is None or dtype == torch.float32 + + # Use the legacy cache's max_seq_len if max_cache_len is not specified + if max_cache_len is None and hasattr(legacy_cache, "max_seq_len"): + max_cache_len = legacy_cache.max_seq_len + elif max_cache_len is None and hasattr(legacy_cache, "max_cache_len"): + max_cache_len = legacy_cache.max_cache_len + + return cls( + config=config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=device, + dtype=dtype, + ) + + +# Need to figure out if I have to inherit from HybridCache or StaticCache +class ETCustomHybridCache(HybridCache): + """ + Custom Hybrid KV Cache implementation for ExecutorTorch that inherits from Hugging Face's HybridCache + but uses ExecutorTorch's CustomKVCache for global layers and CustomRingKVCache for sliding window layers. 
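+    Layer types are taken from the cache's `is_sliding` mask: sliding-window layers get a
+    CustomRingKVCache sized to the sliding window, while global layers get a CustomKVCache
+    sized to `max_cache_len`.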
+ """ + + def __init__( + self, + config, + max_batch_size: int, + max_cache_len: Optional[int] = None, + device: Union[torch.device, str, None] = None, + dtype: torch.dtype = torch.float32, + layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, + ): + super().__init__( + config=config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=device, + dtype=dtype, + layer_device_map=layer_device_map, + ) + + # make sure layer_device_map is none + assert layer_device_map is None + assert device is None or device == "cpu", "Device must be None or 'cpu'" + + self.cache_position = None + # Create a list of cache instances, one per layer + # Use CustomKVCache for global layers and CustomRingKVCache for sliding window layers + self.kv_cache = torch.nn.ModuleList() + for layer_idx in range(config.num_hidden_layers): + # newer version of transfomer has is_sliding defined + # for HybridCache + if self.is_sliding[layer_idx]: + # This is a sliding window layer + layer_cache = CustomRingKVCache( + max_batch_size=self.max_batch_size, + max_context_length=self.sliding_window_len, + n_heads=self.num_key_value_heads, + head_dim=self.head_dim, + dtype=dtype, + ) + else: + layer_cache = CustomKVCache( + max_batch_size=self.max_batch_size, + max_context_length=self.max_cache_len, + n_heads=self.num_key_value_heads, + head_dim=self.head_dim, + dtype=dtype, + ) + self.kv_cache.append(layer_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx` + using ExecutorTorch's CustomKVCache or CustomRingKVCache depending on the layer type. + + Args: + key_states (`torch.Tensor`): + The new key states to cache. Shape: [batch_size, n_heads, seq_len, head_dim] + value_states (`torch.Tensor`): + The new value states to cache. Shape: [batch_size, n_heads, seq_len, head_dim] + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache update. + + Returns: + A tuple containing the updated key and value states. + """ + assert cache_kwargs is not None + + # Get cache position from cache_kwargs (used by HybridCache) + cache_position = cache_kwargs.get("cache_position") + assert cache_position is not None + assert isinstance(cache_position, torch.Tensor) + self.cache_position = cache_position + + # Get the cache instance for this layer (either CustomKVCache or CustomRingKVCache) + layer_cache = self.kv_cache[layer_idx] + + # Use the cache's update method + # Both CustomKVCache and CustomRingKVCache have the same update interface + k_out, v_out = layer_cache.update( + input_pos=cache_position, + k_val=key_states, + v_val=value_states, + ) + + return k_out, v_out + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + if layer_idx is None: + layer_idx = 0 + + # For CustomRingKVCache, we need to handle the sequence length differently + layer_cache = self.kv_cache[layer_idx] + if self.is_sliding[layer_idx]: + # CustomRingKVCache cache_position_manager which + # maintains cache position for each slot in the kv cache + # we return the max position + 1 to indicate max position + # seen so far. 
It is unclear whether this is the correct
+            # interpretation of sequence length.
+            return layer_cache.cache_positions_manager.cache_positions.max().item() + 1
+        return (layer_cache.k_cache[0, :, 0].any(dim=-1)).sum()
+
+    def get_layer_cache(self, layer_idx: int):
+        """
+        Get the cache for a specific layer. This method is dynamo-traceable.
+
+        Args:
+            layer_idx (int): The layer index
+
+        Returns:
+            The cache instance for the specified layer (CustomKVCache or CustomRingKVCache)
+        """
+        return self.kv_cache[layer_idx]
+
+
+def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
+    """
+    Replace the KV cache in the module with ETCustomStaticCache or ETCustomHybridCache.
+    This modifies the model in place.
+
+    Args:
+        module: The module to modify
+        config: The model configuration
+        generation_config: The generation config carrying the cache config
+        cache_dtype: The dtype to use for the cache tensors
+
+    Returns:
+        The modified module
+    """
+    # Recursively replace KV caches
+    return _replace_with_et_custom_kv_cache(
+        module, config, generation_config, cache_dtype
+    )
+
+
+def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
+    """
+    Helper function to recursively replace KV caches in the module.
+
+    Args:
+        module: The module to modify
+        config: The model configuration
+        generation_config: The generation config carrying the cache config
+        cache_dtype: The dtype to use for the cache tensors
+
+    Returns:
+        The modified module
+    """
+    # Check if module has static_cache (TorchExportableModuleWithStaticCache)
+    if hasattr(module, "static_cache"):
+        assert isinstance(
+            module.static_cache, StaticCache
+        ), f"Expected StaticCache, got {type(module.static_cache)}"
+
+        # TODO: Add replace_cache to exported module
+        # in transformers' executorch.py
+        if getattr(module, "replace_cache", None) is not None:
+            static_cache = ETCustomStaticCache(
+                config=config,
+                max_batch_size=generation_config.cache_config.batch_size,
+                max_cache_len=generation_config.cache_config.max_cache_len,
+                device=generation_config.cache_config.device,
+                dtype=cache_dtype,
+            )
+            module.replace_cache(static_cache)
+        else:
+            module.static_cache = ETCustomStaticCache(
+                config=config,
+                max_batch_size=generation_config.cache_config.batch_size,
+                max_cache_len=generation_config.cache_config.max_cache_len,
+                device=generation_config.cache_config.device,
+                dtype=cache_dtype,
+            )
+            # TODO: Figure out why this is still needed even though
+            # CustomKVCache registers the attributes.
+            for i in range(len(module.static_cache.kv_cache)):
+                setattr(
+                    module, f"key_cache_{i}", module.static_cache.kv_cache[i].k_cache
+                )
+                setattr(
+                    module, f"value_cache_{i}", module.static_cache.kv_cache[i].v_cache
+                )
+
+    # Check if module has cache (TorchExportableModuleWithHybridCache)
+    elif hasattr(module, "cache"):
+        assert isinstance(
+            module.cache, HybridCache
+        ), f"Expected HybridCache, got {type(module.cache)}"
+
+        # Replace with ETCustomHybridCache
+        if getattr(module, "replace_cache", None) is not None:
+            hybrid_cache = ETCustomHybridCache(
+                config=config,
+                max_batch_size=generation_config.cache_config.batch_size,
+                max_cache_len=generation_config.cache_config.max_cache_len,
+                device=generation_config.cache_config.device,
+                dtype=cache_dtype,
+            )
+            module.replace_cache(hybrid_cache)
+        else:
+            module.cache = ETCustomHybridCache(
+                config=config,
+                max_batch_size=generation_config.cache_config.batch_size,
+                max_cache_len=generation_config.cache_config.max_cache_len,
+                device=generation_config.cache_config.device,
+                dtype=cache_dtype,
+            )
+            # Register cache attributes for each layer
+            for i in range(len(module.cache.kv_cache)):
+                setattr(module, f"key_cache_{i}", module.cache.kv_cache[i].k_cache)
+                setattr(module, f"value_cache_{i}", module.cache.kv_cache[i].v_cache)
+                if 
module.cache.is_sliding[i]: + # Register cache_positions as buffer for sliding window layers + # This prevents it from being traced as a constant + module.register_buffer( + f"cache_positions_{i}", + module.cache.kv_cache[ + i + ].cache_positions_manager.cache_positions, + persistent=False, + ) + else: + raise ValueError( + "Module must have either 'static_cache' (TorchExportableModuleWithStaticCache) " + "or 'cache' (TorchExportableModuleWithHybridCache) attribute" + ) + + return module diff --git a/extension/llm/optimum/custom_sdpa.py b/extension/llm/optimum/custom_sdpa.py new file mode 100644 index 00000000000..eb28069d3e4 --- /dev/null +++ b/extension/llm/optimum/custom_sdpa.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Callable, Optional, Tuple, Union + +import torch +from executorch.extension.llm.custom_ops.custom_ops import custom_sdpa # noqa + + +def custom_sdpa_with_start_pos_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], # noqa + scaling: Optional[float] = None, + softcap: Optional[float] = None, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, None]: + # FA2 uses non-transposed inputs + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Convert the hell out of the inputs to fp32 and back + input_dtype = query.dtype + query = query.to(torch.float32) + key = key.to(torch.float32) + value = value.to(torch.float32) + + # Ignore the causal flag from kwargs but use the one in module + kwargs.pop("is_causal", None) + assert module.is_causal, "Current variant supports only causal attention" + + is_causal = module.is_causal + if kwargs.get("is_sliding", False): + is_causal = False + attn_mask = attention_mask + # start_pos is not important when using mask + # instead of doing causal attention + start_pos = 0 + else: + attn_mask = None + # Calculate the input pos from attention mask. + # Branch out for float vs bool mask + # assert attention_mask.dim() == 2, f"attention_mask must be a 2D matrix." + attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) + first_row_mask = attention_mask[0, :] + # [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3 + start_pos = torch.argmin(first_row_mask.to(torch.long)).item() - 1 + + output = torch.ops.llama.custom_sdpa( + query, + key, + value, + start_pos=start_pos, + attn_mask=attn_mask, + drpout_p=0.0, + is_causal=is_causal, + scale=scaling, + ) + return output.to(input_dtype), None + + +def get_custom_sdpa_for_ring_kv_cache( + exportable_module: torch.nn.Module, +) -> Callable: + # lazy importing to avoid version dependent class definition + from executorch import version + + try: + from executorch.examples.models.llama.source_transformation.custom_kv_cache import ( + CustomRingKVCache, + ) + except ImportError: + raise ImportError( + f"CustomRingKVCache not available in version {version.__version__} of ExecuTorch." 
+ ) + + def _custom_sdpa_for_ring_kv_cache( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], # noqa + scaling: Optional[float] = None, + softcap: Optional[float] = None, + head_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, None]: + is_sliding = getattr(module, "is_sliding", False) + if is_sliding: + # lazy import to avoid being in the optimum import path + # for et <= 0.6.0 version + from optimum.executorch.attentions.custom_kv_cache import ( + ETCustomHybridCache, + ) + + layer_idx = module.layer_idx + assert ( + layer_idx is not None + ), "layer_idx is not set for sliding window attention." + hybrid_cache = exportable_module.model.cache + assert isinstance( + hybrid_cache, ETCustomHybridCache + ), f"Expected HybridCache, got {type(hybrid_cache)}" + ring_cache = hybrid_cache.get_layer_cache(layer_idx) + assert isinstance( + ring_cache, CustomRingKVCache + ), f"Expected CustomRingKVCache, got {type(ring_cache)}" + input_pos = hybrid_cache.cache_position[0].item() + seqlen = query.shape[2] + attention_mask = ring_cache.create_causal_mask_for_ring_buffer( + input_pos, seqlen + ) + kwargs.update({"is_sliding": True}) + return custom_sdpa_with_start_pos_forward( + module, + query, + key, + value, + attention_mask, + scaling, + softcap, + head_mask, + **kwargs, + ) + else: + return custom_sdpa_with_start_pos_forward( + module, + query, + key, + value, + attention_mask, + scaling, + softcap, + head_mask, + **kwargs, + ) + + return _custom_sdpa_for_ring_kv_cache diff --git a/extension/llm/optimum/test/test_modeling_gemma3.py b/extension/llm/optimum/test/test_modeling_gemma3.py new file mode 100644 index 00000000000..31f0a33a8fd --- /dev/null +++ b/extension/llm/optimum/test/test_modeling_gemma3.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from optimum.exporters.executorch.tasks.image_text_to_text import ( + load_image_text_to_text_model, +) +from optimum.executorch import ExecuTorchModelForCausalLM +from executorch.extension.llm.optimum.xnnpack import export_to_executorch_with_xnnpack +from transformers import AutoProcessor, AutoTokenizer, PretrainedConfig + + +class ExecuTorchModelIntegrationTest(unittest.TestCase): + def test_gemma3_image_text_to_text_generation_with_custom_sdpa_kv_cache_8da4w_8we( + self, + ): + + model_id = "google/gemma-3-4b-it" + + module = load_image_text_to_text_model( + model_id, + use_custom_sdpa=True, + use_custom_kv_cache=True, + qlinear=True, + qembedding=True, + ) + executorch_program = export_to_executorch_with_xnnpack(module) + + # Verify the program was created successfully + self.assertIsNotNone(executorch_program) + + # Note: For actual usage, use optimum-executorch API: + # model = ExecuTorchModelForCausalLM.from_pretrained( + # model_id, task="image-text-to-text", recipe="xnnpack", + # use_custom_sdpa=True, use_custom_kv_cache=True + # ) + # This test demonstrates ExecuTorch-specific XNNPACK optimizations diff --git a/extension/llm/optimum/utils.py b/extension/llm/optimum/utils.py new file mode 100644 index 00000000000..88dc08adb2c --- /dev/null +++ b/extension/llm/optimum/utils.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from transformers import GenerationConfig, PretrainedConfig + + +def save_config_to_constant_methods( + config: PretrainedConfig, + generation_config: Optional[GenerationConfig] = None, + **kwargs, +): + # Initialize metadata with values from model config + head_dim = None + if ( + hasattr(config, "hidden_size") + and hasattr(config, "num_attention_heads") + and isinstance(config.num_attention_heads, int) + ): + head_dim = config.hidden_size / config.num_attention_heads + + metadata = { + "get_dtype": 5 if config.torch_dtype == torch.float16 else 6, + "get_bos_id": getattr(config, "bos_token_id", None), + "get_eos_id": getattr(config, "eos_token_id", None), + "get_head_dim": head_dim, + "get_n_kv_heads": getattr(config, "num_key_value_heads", None), + "get_n_layers": getattr(config, "num_hidden_layers", None), + "get_vocab_size": getattr(config, "vocab_size", None), + "get_max_batch_size": 1, + "get_max_seq_len": getattr(config, "max_position_embeddings", None), + "use_kv_cache": getattr(generation_config, "use_cache", None), + "sliding_window": getattr(config, "sliding_window", None), + "decoder_start_token_id": getattr(config, "decoder_start_token_id", None), + "use_sdpa_with_kv_cache": "custom_sdpa" in config._attn_implementation, + } + + # Safely access fields from generation_config if it exists + if generation_config is not None: + # Check for cache_config and its attributes + cache_config = getattr(generation_config, "cache_config", None) + if cache_config is not None: + max_batch_size = getattr(cache_config, "batch_size", None) + max_seq_len = getattr(cache_config, "max_cache_len", None) + + if max_batch_size is not None: + metadata["get_max_batch_size"] = max_batch_size + if max_seq_len is not None: + metadata["get_max_seq_len"] = max_seq_len + + # Combine with any additional kwargs and filter out None values + return {k: v for k, v in {**metadata, **kwargs}.items() if v is not None} diff --git a/extension/llm/optimum/xnnpack.py b/extension/llm/optimum/xnnpack.py new file mode 100644 index 00000000000..abd0afa036d --- /dev/null +++ b/extension/llm/optimum/xnnpack.py @@ -0,0 +1,134 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
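+
+# ExecuTorch-specific XNNPACK lowering for exportable modules produced by
+# optimum-executorch: applies RemovePaddingIdxEmbeddingPass, partitions the graph to
+# the XNNPACK backend, and attaches runtime metadata as constant methods.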
+
+import logging
+from typing import Dict
+
+import torch
+
+from executorch import version as executorch_version
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.devtools.backend_debug import get_delegation_info
+from executorch.exir import (
+    EdgeCompileConfig,
+    ExecutorchBackendConfig,
+    ExecutorchProgram,
+    to_edge_transform_and_lower,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.passes import MemoryPlanningPass
+
+from optimum.exporters.executorch.integrations import (
+    ImageTextToTextExportableModule,
+)
+
+from packaging.version import parse
+from tabulate import tabulate
+from torch.export import ExportedProgram
+
+
+class RemovePaddingIdxEmbeddingPass(ExportPass):
+    """
+    An ExportPass that removes the `padding_idx` keyword argument
+    from all aten.embedding.default operator calls.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target == exir_ops.edge.aten.embedding.default
+            ):
+                # node.args[2] is the padding_idx
+                if len(node.args) == 3:
+                    node.args = (node.args[0], node.args[1])
+        graph_module.recompile()
+        return PassResult(graph_module, True)
+
+
+def export_to_executorch_with_xnnpack(
+    model: ImageTextToTextExportableModule,
+    **kwargs,
+):
+    """
+    Export a PyTorch model to ExecuTorch with delegation to the XNNPACK backend.
+
+    This function also writes the metadata required by the ExecuTorch runtime into the exported program.
+
+    Args:
+        model (ImageTextToTextExportableModule):
+            The exportable module wrapping the PyTorch model to be exported to ExecuTorch.
+        **kwargs:
+            Additional keyword arguments for recipe-specific configurations, e.g. exporting with different example inputs or different compile/backend configs.
+
+    Returns:
+        Dict[str, ExecutorchProgram]:
+            A map of exported and optimized programs for ExecuTorch.
+            For encoder-decoder models or multimodal models, it may generate multiple programs.
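+
+    Example (a minimal sketch; assumes `module` is an ImageTextToTextExportableModule
+    that has already been loaded, as in test/test_modeling_gemma3.py):
+
+        programs = export_to_executorch_with_xnnpack(module)
+        # Keys are the wrapped model's Hugging Face `model_type`, e.g. "gemma3".
+        for name, prog in programs.items():
+            with open(f"{name}.pte", "wb") as f:
+                f.write(prog.buffer)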
+ """ + + def _lower_to_executorch( + exported_programs: Dict[str, ExportedProgram], + metadata=None, + ) -> Dict[str, ExecutorchProgram]: + backend_config_dict = { + "extract_delegate_segments": True, + "memory_planning_pass": MemoryPlanningPass(alloc_graph_input=False), + } + if parse(executorch_version.__version__).base_version > "0.6.0": + backend_config_dict["do_quant_fusion_and_const_prop"] = True + pte_name = model.model.config.model_type + logging.debug(f"\nExported program for {pte_name}.pte: {exported_programs}") + et_prog = to_edge_transform_and_lower( + exported_programs, + partitioner=[XnnpackPartitioner()], + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + constant_methods=metadata, + transform_passes=[RemovePaddingIdxEmbeddingPass()], + ).to_executorch( + config=ExecutorchBackendConfig(**backend_config_dict), + ) + delegation_info = get_delegation_info( + et_prog.exported_program(list(exported_programs.keys())[0]).graph_module + ) + logging.debug( + f"\nDelegation info Summary for {pte_name}.pte: {delegation_info.get_summary()}" + ) + logging.debug( + f"\nDelegation info for {pte_name}.pte: {tabulate(delegation_info.get_operator_delegation_dataframe(), headers='keys', tablefmt='fancy_grid')}" + ) + return {pte_name: et_prog} + + exported_progs = model.export() + + if ( + model.config._attn_implementation == "custom_sdpa" + or model.config._attn_implementation == "custom_sdpa_ring_kv_cache" + ): + # Sanity check to make sure the exported program contains the custom sdpa operator. + if not any( + node.op == "call_function" and "custom_sdpa" in str(node.target) + for exported_program in exported_progs.values() + for node in exported_program.graph_module.graph.nodes + ): + raise ValueError("'custom_sdpa' not found in the graph.") + + return _lower_to_executorch(exported_progs, model.metadata) diff --git a/pyproject.toml b/pyproject.toml index 40ff4eb0465..99cac4ea37f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,7 @@ dependencies=[ "ruamel.yaml", "sympy", "tabulate", + "transformers==4.53.1", # See also third-party/TARGETS for buck's typing-extensions version. "typing-extensions>=4.10.0", # Keep this version in sync with: ./backends/apple/coreml/scripts/install_requirements.sh diff --git a/requirements-examples.txt b/requirements-examples.txt index 0923cf8fefc..8640a024217 100644 --- a/requirements-examples.txt +++ b/requirements-examples.txt @@ -4,4 +4,3 @@ datasets == 3.6.0 # 4.0.0 deprecates trust_remote_code and load scripts. For now timm == 1.0.7 torchsr == 1.0.4 torchtune >= 0.6.1 -transformers == 4.53.1