diff --git a/optimum/executorch/__init__.py b/optimum/executorch/__init__.py index 07b72a6b..5ac795d5 100644 --- a/optimum/executorch/__init__.py +++ b/optimum/executorch/__init__.py @@ -24,6 +24,7 @@ "ExecuTorchModelForMaskedLM", "ExecuTorchModelForSeq2SeqLM", "ExecuTorchModelForSpeechSeq2Seq", + "ExecuTorchModelForMultiModalToText", ], } @@ -34,6 +35,8 @@ ExecuTorchModelForMaskedLM, ExecuTorchModelForSeq2SeqLM, ExecuTorchModelForSpeechSeq2Seq, + ExecuTorchModelForImageTextToTextCausalLM, + ExecuTorchModelForMultiModalToText, ) else: import sys diff --git a/optimum/executorch/attentions/custom_sdpa.py b/optimum/executorch/attentions/custom_sdpa.py index 59476c74..7b2e81f7 100644 --- a/optimum/executorch/attentions/custom_sdpa.py +++ b/optimum/executorch/attentions/custom_sdpa.py @@ -45,7 +45,7 @@ def custom_sdpa_with_start_pos_forward( # Ignore the causal flag from kwargs but use the one in module kwargs.pop("is_causal", None) - assert module.is_causal, "Current variant supports only causal attention" + # assert module.is_causal, "Current variant supports only causal attention" is_causal = module.is_causal if kwargs.get("is_sliding", False): @@ -56,13 +56,16 @@ def custom_sdpa_with_start_pos_forward( start_pos = 0 else: attn_mask = None - # Calculate the input pos from attention mask. - # Branch out for float vs bool mask - # assert attention_mask.dim() == 2, f"attention_mask must be a 2D matrix." - attention_mask = attention_mask.reshape(-1, max_seq_len) - first_row_mask = attention_mask[0, :] - # [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3 - start_pos = torch.argmin(first_row_mask.to(torch.long)).item() - 1 + if is_causal: + # Calculate the input pos from attention mask. + # Branch out for float vs bool mask + # assert attention_mask.dim() == 2, f"attention_mask must be a 2D matrix." 
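+            # Reviewer sketch (explanatory, not functional): the argmin trick below assumes a float
+            # mask row where attended positions are 0.0 and not-yet-attended positions are -inf.
+            # For a hypothetical row
+            #     row = torch.tensor([0.0, 0.0, 0.0, 0.0, float("-inf"), float("-inf")])
+            #     torch.argmin(row.to(torch.long)).item() - 1  # -> 3
+            # the recovered start_pos (3 here) is the position of the current query token in the KV cache.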
+            attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1])
+            first_row_mask = attention_mask[0, :]
+            # [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3
+            start_pos = torch.argmin(first_row_mask.to(torch.long)).item() - 1
+        else:
+            start_pos = 0
 
     output = torch.ops.llama.custom_sdpa(
         query,
@@ -81,14 +84,17 @@ def get_custom_sdpa_for_ring_kv_cache(
     exportable_module: torch.nn.Module,
 ) -> Callable:
     # lazy importing to avoid version dependent class definition
-    from executorch import version
+    try:
+        from executorch import __version__ as executorch_version
+    except ImportError:
+        executorch_version = "unknown"
 
     try:
         from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
            CustomRingKVCache,
         )
     except ImportError:
-        raise ImportError(f"CustomRingKVCache not available in version {version.__version__} of ExecuTorch.")
+        raise ImportError(f"CustomRingKVCache not available in version {executorch_version} of ExecuTorch.")
 
     def _custom_sdpa_for_ring_kv_cache(
         module: torch.nn.Module,
diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py
index 28ffc2fd..a84f626b 100644
--- a/optimum/executorch/modeling.py
+++ b/optimum/executorch/modeling.py
@@ -30,10 +30,12 @@
     AutoModelForMaskedLM,
     AutoModelForSeq2SeqLM,
     AutoModelForSpeechSeq2Seq,
-    PretrainedConfig,
+    AutoModelForMultimodalTextToText,
     PreTrainedTokenizer,
     add_start_docstrings,
 )
+from transformers.configuration_utils import PretrainedConfig
+from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.utils import is_offline_mode
 
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch
@@ -41,7 +43,7 @@
 
 from ..exporters import TasksManager
 from ..exporters.executorch import main_export
-from ..exporters.executorch.utils import verify_eos_tokens_in_tokenizer
+from ..exporters.executorch.utils import verify_eos_tokens_in_pretrained_tokenizer
 from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel
 from ..utils.file_utils import find_files_matching_pattern
 from .stats import Stats
@@ -237,9 +239,7 @@ def _export(
         **kwargs,
     ) -> Dict[str, "ExecuTorchModule"]:
         task = kwargs.pop("task", None)
-        if task is not None:
-            logger.warning(f"task was provided and set to {task} but not used, will be ignored")
-        inferred_task = TasksManager.infer_task_from_model(cls.auto_model_class)
+        inferred_task = task if task is not None else TasksManager.infer_task_from_model(cls.auto_model_class)
         logging.info(f"Inferred task from model class: {inferred_task}")
 
         save_dir = TemporaryDirectory()
@@ -524,7 +526,7 @@ def generate(
 
     def text_generation(
         self,
-        tokenizer: "PreTrainedTokenizer",
+        tokenizer: PreTrainedTokenizer,
         prompt: str,
         echo: bool = True,
         max_seq_len: Optional[int] = None,
@@ -744,7 +746,7 @@ def generate(
 
     def text_generation(
         self,
-        tokenizer: "PreTrainedTokenizer",
+        tokenizer: PreTrainedTokenizer,
         prompt: str,
         echo: bool = True,
         max_seq_len: Optional[int] = None,
@@ -771,7 +773,7 @@ def text_generation(
             raise ValueError(
                 f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
) - if not verify_eos_tokens_in_tokenizer(self.eos_token_ids, self.tokenizer): + if not verify_eos_tokens_in_pretrained_tokenizer(self.eos_token_ids, self.tokenizer): raise ValueError( f"The tokenizer's eos_token_id does not match with the model's eos_token_ids={self.eos_token_ids}." ) @@ -1065,7 +1067,7 @@ def generate( def transcribe( self, - tokenizer: "PreTrainedTokenizer", + tokenizer: PreTrainedTokenizer, input_features: torch.Tensor, echo: bool = True, max_seq_len: Optional[int] = None, @@ -1098,3 +1100,305 @@ def transcribe( self.stats.on_inference_end() self.stats.print_report() return self.tokenizer.decode(generated_tokens, skip_special_tokens=True) + + +class ExecuTorchModelForImageTextToTextCausalLM(ExecuTorchModelBase): + """ + ExecuTorch model with an image-text-to-text causal language modeling head for inference using the ExecuTorch Runtime. + + Although the auto_model_class is `AutoModelForCausalLM` same as `ExecuTorchModelForCausalLM`, this model is specifically designed for + image-text-to-text tasks. This class provides an interface for loading, running, and generating outputs from a vision-language model + optimized for ExecuTorch Runtime. It includes utilities for exporting and loading pre-trained models + compatible with ExecuTorch runtime. + + Attributes: + auto_model_class (`Type`): + Associated Transformers class, `AutoModelForCausalLM`. + model (`ExecuTorchModule`): + The loaded ExecuTorch model. + """ + + auto_model_class = AutoModelForCausalLM + + def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"): + super().__init__(models, config) + if not hasattr(self, "model"): + raise AttributeError("Expected attribute 'model' not found in the instance.") + + # Make sure config contains vision_config and text_config, otherwise raise an error + if not hasattr(config, "vision_config") or not hasattr(config, "text_config"): + raise ValueError( + "The configuration must contain 'vision_config' and 'text_config' attributes for image-text-to-text task." + ) + metadata = self.model.method_names() + logging.debug(f"Load all static methods: {metadata}") + if "use_kv_cache" in metadata: + self.use_kv_cache = self.model.run_method("use_kv_cache")[0] + if "get_max_seq_len" in metadata: + self.max_cache_size = self.model.run_method("get_max_seq_len")[0] + if "get_max_batch_size" in metadata: + self.max_batch_size = self.model.run_method("get_max_batch_size")[0] + if "get_dtype" in metadata: + self.dtype = self.model.run_method("get_dtype")[0] + if "get_bos_id" in metadata: + self.bos_token_id = self.model.run_method("get_bos_id")[0] + for key in ("get_eos_id", "get_eos_ids"): + if key in metadata: + self.eos_token_ids = self.model.run_method(key) + break + if "get_vocab_size" in metadata: + self.vocab_size = self.model.run_method("get_vocab_size")[0] + if "use_sdpa_with_kv_cache" in metadata: + self.use_sdpa_with_kv_cache = self.model.run_method("use_sdpa_with_kv_cache")[0] + + def forward( + self, + input_ids: Optional[torch.LongTensor], + pixel_values: Optional[torch.FloatTensor], + inputs_embeds: Optional[torch.FloatTensor], + cache_position: torch.LongTensor, + ) -> torch.Tensor: + """ + Forward pass of the model, which is compatible with the ExecuTorch runtime for LLM. Here we are assuming pixel_values only represent 1 image. + + Args: + input_ids (`torch.Tensor`): Tensor representing current input token id to the model. + pixel_values (`torch.Tensor`): Tensor representing image input to the model. 
+            inputs_embeds (`torch.Tensor`): Tensor representing input embeddings to the model.
+            cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
+
+        Returns:
+            torch.Tensor: Logits output from the model.
+        """
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+        self.stats.on_model_execution_start()
+
+        if inputs_embeds is None:
+            inputs_embeds = self.model.run_method("text_embeddings")(input_ids)
+
+        if pixel_values is not None:
+            image_features = self.model.run_method("vision_embeddings")(pixel_values)
+
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.model.run_method("text_embeddings")(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+            else:
+                special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+        logits = self.model.run_method("decoder")(
+            (inputs_embeds, cache_position)
+        )[0]
+        self.stats.on_model_execution_end()
+        return logits
+
+    def generate(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        input_ids: torch.LongTensor,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        max_new_tokens: int = 100,
+    ):
+        # TODO: implement prefill of the (image + text) prompt followed by token-by-token decoding
+        # on top of `forward`, mirroring `ExecuTorchModelForMultiModalToText.generate` below.
+        raise NotImplementedError("generate() is not implemented yet for ExecuTorchModelForImageTextToTextCausalLM.")
+
+
+class ExecuTorchModelForMultiModalToText(ExecuTorchModelBase):
+    """
+    An ExecuTorch model for multimodal-input-to-text inference using the ExecuTorch Runtime.
+
+    Attributes:
+        auto_model_class (`Type`):
+            Associated Transformers class, `AutoModelForMultimodalTextToText`.
+        model (`ExecuTorchModule`):
+            The loaded ExecuTorch model.
+        use_kv_cache (`bool`):
+            Whether key-value caching is enabled. For performance reasons, the exported model is
+            optimized to use a static cache.
+        max_cache_size (`int`):
+            Maximum sequence length supported by the cache.
+        max_batch_size (`int`):
+            Maximum supported batch size.
+        dtype (`str`):
+            Data type of the model parameters.
+        bos_token_id (`int`):
+            Beginning-of-sequence token ID.
+        eos_token_id (`int`):
+            End-of-sequence token ID.
+        vocab_size (`int`):
+            Size of the model vocabulary.
+ """ + + auto_model_class = AutoModelForMultimodalTextToText + # auto_model_class = AutoModel + + def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"): + super().__init__(models=models, config=config) + # if not hasattr(self, "decoder"): + # raise AttributeError("Expected attribute 'decoder' not found in the instance.") + # if not hasattr(self, "token_embeddings"): + # raise AttributeError("Expected attribute 'token_embeddings' not found in the instance.") + # if not hasattr(self, "audio_encoder"): + # raise AttributeError("Expected attribute 'audio_encoder' not found in the instance.") + + # required_methods = ["decoder", "token_embeddings", "audio_encoder"] + # for required_method in required_methods: + # if required_method not in self.model.method_names(): + # raise ValueError("Exported .pte file needs to containt 'decoder', 'token_embeddings', and 'audio_encoder' methods.") + + metadata = self.model.method_names() + if "use_kv_cache" in metadata: + self.use_kv_cache = self.model.run_method("use_kv_cache")[0] + if "get_max_seq_len" in metadata: + self.max_cache_size = self.model.run_method("get_max_seq_len")[0] + if "get_max_batch_size" in metadata: + self.max_batch_size = self.model.run_method("get_max_batch_size")[0] + if "get_dtype" in metadata: + self.dtype = self.model.run_method("get_dtype")[0] + if "get_bos_id" in metadata: + self.bos_token_id = self.model.run_method("get_bos_id")[0] + if "get_eos_id" in metadata: + self.eos_token_id = self.model.run_method("get_eos_id")[0] + if "get_vocab_size" in metadata: + self.vocab_size = self.model.run_method("get_vocab_size")[0] + if "max_hidden_seq_length" in metadata: + self.max_hidden_seq_length = self.model.run_method("max_hidden_seq_length")[0] + if "decoder_start_token_id" in metadata: + self.decoder_start_token_id = self.model.run_method("decoder_start_token_id")[0] + + def forward( + self, + input_ids: torch.Tensor, + cache_position: torch.Tensor, + input_features: Optional[torch.Tensor] = None, + ): + token_embeddings = self.token_embeddings.forward(input_ids) + if input_features: + token_embeddings = self.audio_encoder.forward( + input_features, + token_embeddings, + input_ids, + ) + output = self.decoder.forward( + token_embeddings, + cache_position, + ) + return output + + def generate( + self, + prompt_tokens: torch.Tensor, + echo: bool = False, + pos_base: int = 0, + max_seq_len: Optional[int] = None, + input_features: Optional[torch.Tensor] = None, + ) -> List[int]: + self.device = torch.device("cpu") + if max_seq_len is None: + # Default to max_cache_size if max_seq_len is not specified + max_seq_len = self.max_cache_size + elif max_seq_len > self.max_cache_size: + logging.warning( + f"max_seq_len={max_seq_len} is larger than max_cache_size={self.max_cache_size}. Generating tokens will be truncated to max_cache_size." + ) + max_seq_len = self.max_cache_size + + # Prefill. + self.stats.on_sampling_begin() + logits = self.forward( + input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device), + cache_position=torch.arange(len(prompt_tokens[0]), dtype=torch.long, device=self.device), + input_features=input_features, + ) + self.stats.on_sampling_end() + self.stats.on_prompt_eval_end() + + next_token = torch.argmax(logits[:, -1, :], dim=-1).item() + generated_tokens = [next_token] + print(self.tokenizer.decode([next_token]), end="") + + # Token-by-token generation. 
+        num_prompt_tokens = len(prompt_tokens[0])
+        first_token_generated = False
+        while num_prompt_tokens + len(generated_tokens) < max_seq_len:
+            self.stats.on_sampling_begin()
+            logits = self.forward(
+                input_ids=torch.tensor([next_token], dtype=torch.long, device=self.device).unsqueeze(0),
+                cache_position=torch.tensor(
+                    [pos_base + num_prompt_tokens + len(generated_tokens) - 1],
+                    dtype=torch.long,
+                    device=self.device,
+                ),
+            )
+            self.stats.on_sampling_end()
+            if not first_token_generated:
+                self.stats.on_first_token()
+                first_token_generated = True
+
+            next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
+            generated_tokens.append(next_token)
+            print(self.tokenizer.decode([next_token]), end="")
+
+            if next_token == self.eos_token_id:
+                break
+
+        self.stats.set_num_generated_tokens(len(generated_tokens))
+        # `generated_tokens` only holds newly generated ids; prepend the prompt when echoing.
+        if echo:
+            return [int(t) for t in prompt_tokens[0]] + generated_tokens
+        return generated_tokens
+
+    def text_generation(
+        self,
+        processor: "ProcessorMixin",
+        tokenizer: PreTrainedTokenizer,
+        input_conversation: List[Dict],
+        echo: bool = True,
+        max_seq_len: Optional[int] = None,
+    ):
+        """
+        Perform a text generation task for a given multimodal conversation using the ExecuTorch model.
+
+        Args:
+            processor (`ProcessorMixin`):
+                The processor used to turn the input conversation into model inputs.
+            tokenizer (`PreTrainedTokenizer`):
+                The tokenizer used to encode and decode the prompt and output.
+            input_conversation (`List[Dict]`):
+                The chat-style conversation (list of message dicts) to complete.
+            echo (`bool`, *optional*):
+                Whether to include prompt tokens in the generated output. Defaults to `True`.
+            max_seq_len (`int`, *optional*):
+                Maximum sequence length for the generated output.
+                Defaults to None and uses the model's `max_cache_size` attribute.
+                Will be truncated to maximal cache size if larger than `max_cache_size`.
+        """
+        self.tokenizer = tokenizer
+
+        # Sanity check
+        if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.bos_token_id:
+            raise ValueError(
+                f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
+            )
+        if isinstance(self.tokenizer, PreTrainedTokenizer) and not verify_eos_tokens_in_pretrained_tokenizer(
+            [self.eos_token_id], self.tokenizer
+        ):
+            raise ValueError(
+                f"The tokenizer's eos_token_id does not match with the model's eos_token_id={self.eos_token_id}."
+            )
+
+        # Reset stats for a new generation
+        self.stats.reset()
+        self.stats.on_inference_start()
+
+        inputs = processor.apply_chat_template(
+            input_conversation,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        )
+        self.stats.on_token_encode_end()
+        self.stats.set_num_prompt_tokens(len(inputs["input_ids"][0]))
+
+        generated_tokens = self.generate(
+            prompt_tokens=inputs["input_ids"],
+            input_features=inputs.get("input_features"),
+            echo=echo,
+            max_seq_len=max_seq_len,
+        )
+
+        self.stats.on_inference_end()
+        self.stats.print_report()
+
+        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py
index 23e6819a..f7adb422 100644
--- a/optimum/exporters/executorch/integrations.py
+++ b/optimum/exporters/executorch/integrations.py
@@ -13,7 +13,7 @@
 # limitations under the License.
import logging -from typing import Dict +from typing import Dict, Optional import torch from packaging.version import parse @@ -25,14 +25,703 @@ StaticCache, T5ForConditionalGeneration, WhisperForConditionalGeneration, + VoxtralForConditionalGeneration, + Gemma3ForConditionalGeneration, ) +from transformers.configuration_utils import PretrainedConfig from transformers.generation.configuration_utils import GenerationConfig +from transformers.cache_utils import HybridCache +from transformers.integrations.executorch import sdpa_mask_without_vmap +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache from optimum.utils.import_utils import is_transformers_version from .utils import save_config_to_constant_methods +# TODO(JZ): upstream changes here to transformers. +class TorchExportableModuleWithStaticCache(torch.nn.Module): + """ + A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`, + specifically for decoder-only LM to `StaticCache`. This module ensures that the + exported model is compatible with further lowering and execution in `ExecuTorch`. + + Note: + This class is specifically designed to support export process using `torch.export` + in a way that ensures the model can be further lowered and run efficiently in `ExecuTorch`. + """ + + def __init__( + self, + model: PreTrainedModel, + config: PretrainedConfig, + generation_config: GenerationConfig, + ): + """ + Initializes the wrapper module with the pretrained model. + + Args: + model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching + enabled and use a 'static' caching implementation. + + Raises: + AssertionError: If the pretrained model does not have caching enabled or if it does + not use a 'static' caching implementation in `model.generation_config`. + """ + super().__init__() + + # Sanity checks + if generation_config is None: + raise AssertionError( + "The model must have a generation config to be exported with static caching. " + "Please set `generation_config`." + ) + + if not generation_config.use_cache: + raise AssertionError( + "The model must have caching enabled to be exported with static caching. " + "Please set `generation_config.use_cache=True`." + ) + + if generation_config.cache_implementation != "static": + raise AssertionError( + "The model must use a 'static' caching implementation to be exported with static caching. " + "Please set `generation_config.cache_implementation='static'`." + ) + + self.model = model + self.config = config + self.generation_config = generation_config + self.static_cache = StaticCache( + config=config, + max_batch_size=self.generation_config.cache_config.batch_size, + max_cache_len=self.generation_config.cache_config.max_cache_len, + device=self.generation_config.cache_config.device, + dtype=self.model.dtype, + ) + # TODO(JZ): figure out why len(self.static_cache) doesn't work like it does in upstream. 
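+        # Note: each entry of `static_cache.key_cache` / `value_cache` is a pre-allocated tensor of
+        # shape (max_batch_size, num_key_value_heads, max_cache_len, head_dim). Registering them as
+        # non-persistent buffers below exposes the cache state to torch.export without adding it to
+        # the state_dict.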
+ for i in range(len(self.static_cache.key_cache)): + self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False) + self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False) + + def forward( + self, + *, + cache_position: torch.Tensor, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """ + Forward pass of the module, which is compatible with the ExecuTorch runtime. + + Args: + input_ids (`torch.Tensor`): Tensor representing current input token id to the module. + cache_position (`torch.Tensor`): Tensor representing current input position in the cache. + + Returns: + torch.Tensor: Logits output from the model. + + This forward adapter serves two primary purposes: + + 1. **Making the Model `torch.export`-Compatible**: + The adapter hides unsupported objects, such as the `Cache`, from the graph inputs and outputs, + enabling the model to be exportable using `torch.export` without encountering issues. + + 2. **Ensuring Compatibility with `ExecuTorch` runtime**: + The adapter matches the model's forward signature with that in `executorch/extension/llm/runner`, + ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box. + """ + if input_ids is not None: + _, seqlen = input_ids.shape + else: + _, seqlen, _ = inputs_embeds.shape + position_ids = cache_position.unsqueeze(0) + past_key_values = self.static_cache + + outs = self.model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + attention_mask=None, + position_ids=position_ids, + cache_position=cache_position, + past_key_values=past_key_values, + use_cache=True, + ) + return outs.logits + + @staticmethod + def generate( + exported_program: torch.export.ExportedProgram, + prompt_token_ids: torch.Tensor, + max_new_tokens: int, + ) -> torch.Tensor: + """ + Generate a sequence of tokens using an exported program. + + This util function is designed to test exported models by simulating the generation process. + It processes the input prompt tokens sequentially (no parallel prefill). + This generate function is not intended to replace the original `generate` method, and the support + for leveraging the original `generate` is potentially planned! + + Args: + exported_program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`. + prompt_token_ids (`torch.Tensor`): Tensor representing the input prompt token IDs. + max_new_tokens (`int`): Maximum number of new tokens to generate. Note that the total generation + length is limited by both `max_new_tokens` and the model's cache size. + + Returns: + torch.Tensor: A tensor containing the generated sequence of token IDs, including the original prompt tokens. 
+ """ + device = prompt_token_ids.device + prompt_token_len = prompt_token_ids.shape[-1] + max_generation_length = prompt_token_len + max_new_tokens + for buffer_name, buffer in exported_program.named_buffers(): + if buffer_name.startswith("key_cache"): + max_cache_len = buffer.shape[2] + max_generation_length = min(max_generation_length, max_cache_len) + break + + response_tokens = [] + for input_pos in range(min(max_generation_length, prompt_token_len)): + result = exported_program.module().forward( + input_ids=prompt_token_ids[:, input_pos : input_pos + 1], + cache_position=torch.tensor([input_pos], dtype=torch.long, device=device), + ) + response_tokens.append(prompt_token_ids[0][input_pos].item()) + + current_token = torch.argmax(result[:, -1, :], dim=-1).item() + response_tokens.append(current_token) + + while len(response_tokens) < max_generation_length: + result = exported_program.module().forward( + input_ids=torch.tensor([[current_token]], dtype=torch.long, device=device), + cache_position=torch.tensor([len(response_tokens)], dtype=torch.long, device=device), + ) + current_token = torch.argmax(result[:, -1, :], dim=-1).item() + response_tokens.append(current_token) + + return torch.tensor([response_tokens], dtype=torch.long, device=device) + + +class TorchExportableModuleWithHybridCache(torch.nn.Module): + """ + A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`, + specifically for decoder-only LM to `HybridCache`. This module ensures that the + exported model is compatible with further lowering and execution in `ExecuTorch`. + """ + + def __init__( + self, + model: PreTrainedModel, + max_batch_size: int = 1, + max_cache_len: int = 4096, + ): + """ + Initializes the exportable module with `HybridCache`. + + Args: + model (`PreTrainedModel`): The pretrained model to wrap. + max_batch_size (int): Maximum batch size for the cache. + max_cache_len (int): Maximum sequence length for the cache. + + Raises: + AssertionError: If the model doesn't have the expected configuration for HybridCache. + """ + super().__init__() + self.model = model + + # Verify the model is configured for HybridCache + if not self.model.config.text_config.use_cache: + raise AssertionError("Model must have caching enabled") + + # Initialize the HybridCache + self.cache = HybridCache( + config=self.model.config.text_config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=self.model.device, + dtype=self.model.dtype, + ) + + # Register all key and value cache tensors as buffers + for i in range(len(self.cache.key_cache)): + self.register_buffer(f"key_cache_{i}", self.cache.key_cache[i], persistent=False) + self.register_buffer(f"value_cache_{i}", self.cache.value_cache[i], persistent=False) + + def forward( + self, + *, + cache_position: torch.Tensor, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + """ + Forward pass of the module, which is compatible with the ExecuTorch llm runner. + + Args: + cache_position (`torch.Tensor`): Tensor representing current input position in the cache. + input_ids (`torch.Tensor`): Tensor representing current input token id to the module. + inputs_embeds (`torch.Tensor`): Optional tensor representing input embeddings. + + Returns: + torch.Tensor: Logits output from the model. 
+ """ + batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + + # Generate position_ids from cache_position + position_ids = cache_position.unsqueeze(0).expand(batch_size, -1) + + # Forward pass with the model + outputs = self.model( + input_ids=input_ids, + attention_mask=None, + position_ids=position_ids, + past_key_values=self.cache, + use_cache=True, + cache_position=cache_position, + inputs_embeds=inputs_embeds, + ) + + # Return only the logits to simplify the export + return outputs.logits + + +class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module): + """ + A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`, + specifically for image-text LM with cache. This module ensures that the + exported model is compatible with further lowering and execution in `ExecuTorch`. + """ + + def __init__( + self, + model: PreTrainedModel, + config: PretrainedConfig, + generation_config: GenerationConfig, + max_batch_size: int = 1, + max_cache_len: int = 4096, + ): + """ + Initializes the exportable module with `HybridCache`. + + Args: + model (`PreTrainedModel`): The pretrained model to wrap. + max_batch_size (int): Maximum batch size for the cache. + max_cache_len (int): Maximum sequence length for the cache. + + Raises: + ValueError: If the model is configured with a unsupported cache implementation. + """ + super().__init__() + + if not hasattr(config, "use_cache") or config.use_cache is False: + raise ValueError("The model must have caching enabled to be performant.") + + if hasattr(config, "layer_types") and getattr(config, "sliding_window", None) is not None: + self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len) + else: + # If `layer_types` is not specified explicitly in the config or `sliding_window` is null, + # there is only 1 type of layers, so export will use `StaticCache` by default. + logging.info( + "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config." + ) + self.model = TorchExportableModuleWithStaticCache(model, config, generation_config) + # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable + ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) + ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) + self.model.model.config._attn_implementation = "sdpa_without_vmap" + + def forward( + self, + *, + cache_position: torch.Tensor, + input_ids: Optional[torch.Tensor] = None, + input_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass of the module, which is compatible with the ExecuTorch llm runner. + + Args: + input_embeds (`torch.Tensor`): Tensor representing current input embeddings to the module. + cache_position (`torch.Tensor`): Tensor representing current input position in the cache. + + Returns: + torch.Tensor: Logits output from the model. + """ + return self.model.forward( + cache_position=cache_position, + input_ids=input_ids, + inputs_embeds=input_embeds, + ) + + def export( + self, + *, + input_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, + dynamic_shapes: Optional[dict] = None, + strict: Optional[bool] = None, + ) -> torch.export.ExportedProgram: + """ + Export the wrapped module using `torch.export`. 
+ + Args: + input_ids (`Optional[torch.Tensor]`): + Tensor representing current input token id to the module. If not provided, a default tensor will be used. + input_embeds (`Optional[torch.Tensor]`): + Tensor representing current input embeddings to the module. If not provided, a default tensor will be used. + cache_position (`Optional[torch.Tensor]`): + Tensor representing current input position in the cache. If not provided, a default tensor will be used. + dynamic_shapes (`Optional[dict]`): + Dynamic shapes to use for export if specified. + strict(`Optional[bool]`): + Flag to instruct `torch.export` to use `torchdynamo`. + """ + if hasattr(self.model, "base_model_prefix"): + base = getattr(self.model, self.model.base_model_prefix, self.model) + model_device = base.device + elif hasattr(self.model, "model"): + model_device = self.model.model.device + else: + model_device = "cpu" + logging.warning( + "TorchExportableModuleForImageTextLM.export Can't infer device from the model. Set to CPU by default." + ) + + if not ((input_ids is None) ^ (inputs_embeds is None)): + raise ValueError("Must specify either input_ids or inputs_embeds") + + if input_ids: + example_input_ids = ( + input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long, device=model_device) + ) + example_cache_position = ( + cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=model_device) + ) + exported_program = torch.export.export( + self.model, + args=(example_input_ids, example_cache_position), + kwargs={}, + dynamic_shapes=dynamic_shapes, + strict=strict if strict is not None else True, + ) + else: + seq_length = 3 + # TODO(JZ): remove this and pass this in instead? + # if dynamic_shapes is None: + # seq_len_dim = torch.export.Dim("seq_length_dim", max=seq_length) + # dynamic_shapes = { + # "inputs_embeds": {1: seq_len_dim}, + # "cache_position": {0: seq_len_dim}, + # } + exported_program = torch.export.export( + self.model, + args=(), + kwargs={"cache_position": cache_position, "inputs_embeds": inputs_embeds}, + dynamic_shapes=dynamic_shapes, + strict=strict if strict is not None else True, + ) + + return exported_program + + +class ImageEncoderExportableModule(torch.nn.Module): + """ + A wrapper module designed to make a vision encoder-only model exportable with `torch.export`. + This module ensures that the exported model is compatible with ExecuTorch. + """ + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, pixel_values): + """ + Projects the last hidden state from the vision model into language model space. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + vision_outputs = self.model.vision_tower(pixel_values=pixel_values).last_hidden_state + image_features = self.model.multi_modal_projector(vision_outputs) + return image_features + + +class VoxtralEncoderExportableModule(torch.nn.Module): + """ + Subgraph which handles all of the audio-related work: encoder, multimodal projection, combinining with text tokens. + The result of this subgraph should stream directly into the decoder subgraph. 
+ """ + def __init__(self, model: torch.nn.Module): + super().__init__() + self.audio_encoder = model.audio_tower + self.mm_projector = model.multi_modal_projector + self.intermediate_size = model.config.audio_config.intermediate_size + self.audio_token_id = model.config.audio_token_id + + def forward( + self, + input_features: torch.FloatTensor, + inputs_embeds: torch.FloatTensor, + input_ids: torch.LongTensor, + ): + audio_outputs = self.audio_encoder(input_features) + audio_hidden_states = audio_outputs.last_hidden_state + audio_hidden_states = audio_hidden_states.reshape(-1, self.intermediate_size) + audio_embeds = self.mm_projector(audio_hidden_states) + + audio_token_mask = input_ids == self.audio_token_id + inputs_embeds[audio_token_mask] = audio_embeds + + return inputs_embeds + + +class MultiModalTextToTextExportableModule(torch.nn.Module): + """ + A wrapper module designed to make an multimodal model, e.g. image-text-to-text, exportable with `torch.export`. + This module ensures that the exported model is compatible with ExecuTorch. + """ + + def __init__(self, model, use_custom_kv_cache=False, use_custom_sdpa=False): + super().__init__() + self.model = model + self.config = model.config + self.use_custom_kv_cache = use_custom_kv_cache + self.use_custom_sdpa = use_custom_sdpa + self.metadata = save_config_to_constant_methods(model.config.text_config, model.generation_config) + logging.info(f"Metadata to be recorded in PTE: {self.metadata}") + + def _prepare_vision_embedding_export_inputs(self): + """ + Prepare example inputs and configurations for export. + + Returns: + pixel_values (torch.Tensor): Example pixel values tensor. + dynamic_shapes (dict or None): Dynamic shape specifications for export. + strict (bool): Whether to use strict export mode. + """ + image_size = self.config.vision_config.image_size + pixel_values = torch.rand((1, 3, image_size, image_size)) + dynamic_shapes = None + strict = False + + return pixel_values, dynamic_shapes, strict + + def _prepare_audio_embedding_export_inputs(self): + # TODO(JZ): specific to Voxtral, should generalize. + batch_size = 3 + chunk_length = self.model.audio_tower.config.max_source_positions * self.model.audio_tower.conv1.stride[0] * self.model.audio_tower.conv2.stride[0] + spectrogram_features = 128 + audio_input = torch.rand(batch_size, spectrogram_features, chunk_length) + + max_audio_len = 150 # In s, should be a multiple of 30. + dynamic_shapes = { + "input_features": { + 0: torch.export.Dim("batch_size", min=1, max=max_audio_len/30), + 1: torch.export.Dim.STATIC, + 2: torch.export.Dim.STATIC, + }, + } + + strict = True + + return audio_input, dynamic_shapes, strict + + def _prepare_text_embedding_export_inputs(self): + """ + Prepare example inputs and configurations for export. + + Returns: + input_ids (torch.Tensor): Example input IDs tensor. + dynamic_shapes (dict or None): Dynamic shape specifications for export. + strict (bool): Whether to use strict export mode. 
+ """ + # Prepare inputs with dynamic shapes + seq_length = 3 # Sequence length > 1 to avoid specialization issues + example_input_ids = torch.zeros((1, seq_length), dtype=torch.long) + max_seq_len = self.metadata.get("get_max_seq_len") + sliding_window = self.metadata.get("sliding_window", float("inf")) + max_dim = min(max_seq_len, sliding_window) - 1 + seq_len_dim = torch.export.Dim("seq_length_dim", max=max_dim) + dynamic_shapes = { + "input": {1: seq_len_dim}, + } # nn.embedding forward() here - https://github.com/pytorch/pytorch/blob/febf3c475e6fe369b41ef009f3598659a6df0911/torch/nn/modules/sparse.py#L15. + strict = parse(torch.__version__) != parse("2.7.0") # Workaround for PyTorch bug #150994 + return example_input_ids, dynamic_shapes, strict + + def _prepare_decoder_only_export_inputs(self): + """ + Prepare example inputs and configurations for export. + + Returns: + inputs_embeds (torch.Tensor): Example input embeddings tensor. + cache_position (torch.Tensor): Example cache position tensor. + dynamic_shapes (dict or None): Dynamic shape specifications for export. + strict (bool): Whether to use strict export mode. + """ + + # Prepare inputs with dynamic shapes + seq_length = 3 + example_inputs_embeds = torch.zeros((1, seq_length, self.config.text_config.hidden_size), dtype=torch.float) + example_cache_position = torch.arange(seq_length, dtype=torch.long) + max_seq_len = self.metadata.get("get_max_seq_len") + sliding_window = self.metadata.get("sliding_window", float("inf")) + max_dim = min(max_seq_len, sliding_window) - 1 + seq_len_dim = torch.export.Dim("seq_length_dim", max=max_dim) + dynamic_shapes = { + "inputs_embeds": {1: seq_len_dim}, + "cache_position": {0: seq_len_dim}, + } + strict = parse(torch.__version__) != parse("2.7.0") # Workaround for PyTorch bug #150994 + return example_inputs_embeds, example_cache_position, dynamic_shapes, strict + + + def _register_attention_mask_for_4_53(self, exportable_module: torch.nn.Module): + if is_transformers_version(">=", "4.53.0.dev0"): + from transformers.integrations.executorch import sdpa_mask_without_vmap + from transformers.masking_utils import AttentionMaskInterface + from transformers.modeling_utils import AttentionInterface + + _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module) + if self.use_custom_sdpa: + if self.use_custom_kv_cache: + AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache) + AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap) + # Manually set the attention implementation to custom_sdpa_ring_kv_cache + # This handles both regular sdpa and one for sliding window/local attention + exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache" + else: + # Manually set the attention implementation to custom_sdpa_ring_kv_cache + # This handles both regular sdpa and one for sliding window/local attention + exportable_module.model.model.config._attn_implementation = "custom_sdpa" + + def export( + self, + ) -> Dict[str, ExportedProgram]: + with torch.no_grad(): + # 1. Export text decoder. + exportable_module = TorchExportableModuleForDecoderOnlyLM( + self.model.language_model, + self.config.text_config, + self.model.generation_config, + max_batch_size=1, + max_cache_len=self.metadata.get("get_max_seq_len"), + ) + exported_programs = {} + + # Custom SDPA for text decoder. 
+ self._register_attention_mask_for_4_53(exportable_module) + + if self.use_custom_kv_cache: + from optimum.executorch.attentions.custom_kv_cache import ( + replace_with_et_custom_kv_cache, + ) + + replace_with_et_custom_kv_cache( + exportable_module.model, + self.model.config.text_config, + self.model.generation_config, + self.model.dtype, + ) + + inputs_embeds, cache_position, dynamic_shapes, strict = self._prepare_decoder_only_export_inputs() + logging.info( + f"Exporting decoder using inputs_embeds({inputs_embeds.shape}), cache_position({cache_position.shape})={cache_position}, dynamic_shapes={dynamic_shapes}, strict={strict}" + ) + exported_program = exportable_module.export( + inputs_embeds=inputs_embeds, + cache_position=cache_position, + dynamic_shapes=dynamic_shapes, + strict=strict + ) + # Apply RemoveTransposes pass to remove + # any back-to-back transpose ops that are not needed + # e.g. output of update_cache is transposed and + # input to custom_sdpa is transposed. + from executorch.extension.llm.export.export_passes import ( + RemoveRedundantTransposes, + ) + + mutated_gm = RemoveRedundantTransposes()(exported_program.module())[0] + exported_program = torch.export.export( + mutated_gm, + args=(), + kwargs={"cache_position": cache_position, "inputs_embeds": inputs_embeds}, + dynamic_shapes=dynamic_shapes, + strict=strict, + ) + exported_programs["decoder"] = exported_program + + # 2. Export token embeddings + input_ids, dynamic_shapes, strict = self._prepare_text_embedding_export_inputs() + logging.info(f"Exporting token embeddings using input_ids({input_ids.shape}), dynamic_shapes={dynamic_shapes}, strict={strict}") + + token_embeddings_exported_program = torch.export.export( + self.model.language_model.get_input_embeddings(), + args=(input_ids,), + kwargs={}, + dynamic_shapes=dynamic_shapes, + strict=strict, + ) + exported_programs["token_embeddings"] = token_embeddings_exported_program + + # 3. Export encoder. + input_ids = torch.zeros_like(inputs_embeds[:, :, 0], dtype=torch.long) + input_ids[0, 1] = self.config.audio_token_id # Make sure we don't have an all-false mask for the imput_embeds. + if isinstance(self.model, VoxtralForConditionalGeneration): + # TODO(JZ): specific to Voxtral, should generalize. + chunk_length = self.model.audio_tower.config.max_source_positions * self.model.audio_tower.conv1.stride[0] * self.model.audio_tower.conv2.stride[0] + encoder_input_kwargs = { + "input_features": torch.rand(3, 128, chunk_length), # (bsz, features, seq_len) + "inputs_embeds": inputs_embeds, + "input_ids": input_ids, + } + + max_audio_len = 150 # In s, should be a multiple of 30. TODO(JZ): make this configurable top-level. 
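+                # Rough arithmetic for the dynamic batch dim (assuming a Whisper-style audio tower with
+                # max_source_positions=1500 and conv strides of 1 and 2): chunk_length = 1500 * 1 * 2 = 3000
+                # mel frames, i.e. roughly 30 s of audio, so max_audio_len // 30 bounds how many 30 s
+                # chunks can be batched into the encoder.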
+ max_seq_len = self.metadata.get("get_max_seq_len") + dynamic_shapes = { + "input_features": { + 0: torch.export.Dim("enc_batch_size_dim", min=1, max=max_audio_len//30), + # 1: torch.export.Dim.STATIC, + # 2: torch.export.Dim.STATIC, + }, + "inputs_embeds": {1: torch.export.Dim("input_embeds_seq_length_dim", max=max_seq_len)}, + "input_ids": {1: torch.export.Dim("input_ids_seq_length_dim", max=max_seq_len)}, + } + + # self.model.audio_tower.config._attn_implementation = "sdpa_without_vmap" + self.model.audio_tower.config._attn_implementation = "custom_sdpa" + audio_encoder = VoxtralEncoderExportableModule(self.model) + audio_encoder_exported_program = torch.export.export( + audio_encoder, + args=(), + kwargs=encoder_input_kwargs, + dynamic_shapes=dynamic_shapes, + strict=False, + ) + exported_programs["audio_encoder"] = audio_encoder_exported_program + elif isinstance(self.model, Gemma3ForConditionalGeneration): + pixel_values, dynamic_shapes, strict = self._prepare_vision_embedding_export_inputs() + logging.info(f"Exporting vision embeddings using pixel_values({pixel_values.shape}), dynamic_shapes={dynamic_shapes}, strict={strict}") + # Setting the _attn_implementation to "sdpa_without_vmap" for vision encoder + exportable_module.model.model.vision_tower.config._attn_implementation = "sdpa_without_vmap" + vision_encoder = ImageEncoderExportableModule(exportable_module.model.model) + vision_embeddings_exported_program = torch.export.export( + vision_encoder, + args=(pixel_values,), + kwargs={}, + dynamic_shapes=dynamic_shapes, + strict=strict, + ) + exported_programs["vision_encoder"] = vision_embeddings_exported_program + + return exported_programs + class CausalLMExportableModule(torch.nn.Module): """ diff --git a/optimum/exporters/executorch/recipes/xnnpack.py b/optimum/exporters/executorch/recipes/xnnpack.py index 0bd7b374..29f7015e 100644 --- a/optimum/exporters/executorch/recipes/xnnpack.py +++ b/optimum/exporters/executorch/recipes/xnnpack.py @@ -19,7 +19,8 @@ from tabulate import tabulate from torch.export import ExportedProgram -from executorch import version as executorch_version +# from executorch import version as executorch_version +from executorch import __version__ as executorch_version from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner from executorch.devtools.backend_debug import get_delegation_info from executorch.exir import ( @@ -28,19 +29,22 @@ ExecutorchProgram, to_edge_transform_and_lower, ) +from executorch.exir.passes import MemoryPlanningPass + from optimum.executorch.passes.remove_padding_idx_embedding_pass import RemovePaddingIdxEmbeddingPass from ..integrations import ( CausalLMExportableModule, MaskedLMExportableModule, Seq2SeqLMExportableModule, + MultiModalTextToTextExportableModule ) from ..recipe_registry import register_recipe @register_recipe("xnnpack") def export_to_executorch_with_xnnpack( - model: Union[CausalLMExportableModule, MaskedLMExportableModule, Seq2SeqLMExportableModule], + model: Union[CausalLMExportableModule, MaskedLMExportableModule, Seq2SeqLMExportableModule, MultiModalTextToTextExportableModule], **kwargs, ): """ @@ -49,7 +53,7 @@ def export_to_executorch_with_xnnpack( This function also write metadata required by the ExecuTorch runtime to the model. 
Args: - model (Union[CausalLMExportableModule, MaskedLMExportableModule, Seq2SeqLMExportableModule]): + model (Union[CausalLMExportableModule, MaskedLMExportableModule, Seq2SeqLMExportableModule, MultiModalTextToTextExportableModule]): The PyTorch model to be exported to ExecuTorch. **kwargs: Additional keyword arguments for recipe-specific configurations, e.g. export using different example inputs, or different compile/bechend configs. @@ -64,36 +68,38 @@ def _lower_to_executorch( exported_programs: Dict[str, ExportedProgram], metadata=None, ) -> Dict[str, ExecutorchProgram]: - et_progs = {} backend_config_dict = { "extract_delegate_segments": True, + "memory_planning_pass": MemoryPlanningPass(alloc_graph_input=False), } if parse(executorch_version.__version__).base_version > "0.6.0": backend_config_dict["do_quant_fusion_and_const_prop"] = True - - for pte_name, exported_program in exported_programs.items(): - logging.debug(f"\nExported program for {pte_name}.pte: {exported_program}") - et_progs[pte_name] = to_edge_transform_and_lower( - exported_program, - partitioner=[XnnpackPartitioner()], - compile_config=EdgeCompileConfig( - _check_ir_validity=False, - _skip_dim_order=True, - ), - constant_methods=metadata, - transform_passes=[RemovePaddingIdxEmbeddingPass()], - ).to_executorch( - config=ExecutorchBackendConfig(**backend_config_dict), - ) + pte_name = model.model.config.model_type + logging.debug(f"\nExported program for {pte_name}.pte: {exported_programs}") + et_prog = to_edge_transform_and_lower( + exported_programs, + partitioner=[XnnpackPartitioner()], + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + constant_methods=metadata, + transform_passes=[RemovePaddingIdxEmbeddingPass()], + ) + et_prog = et_prog.to_executorch( + config=ExecutorchBackendConfig(**backend_config_dict), + ) + for method in et_prog.methods: + logging.debug(f"---------------------- Method: {method} ----------------------") logging.debug( - f"\nExecuTorch program for {pte_name}.pte: {et_progs[pte_name].exported_program().graph_module}" + f"\nExecuTorch program for {pte_name}.pte: {et_prog.exported_program(method).graph_module}" ) - delegation_info = get_delegation_info(et_progs[pte_name].exported_program().graph_module) + delegation_info = get_delegation_info(et_prog.exported_program(method).graph_module) logging.debug(f"\nDelegation info Summary for {pte_name}.pte: {delegation_info.get_summary()}") logging.debug( f"\nDelegation info for {pte_name}.pte: {tabulate(delegation_info.get_operator_delegation_dataframe(), headers='keys', tablefmt='fancy_grid')}" ) - return et_progs + return {pte_name: et_prog} exported_progs = model.export() diff --git a/optimum/exporters/executorch/tasks/__init__.py b/optimum/exporters/executorch/tasks/__init__.py index 0f7c3be3..aa065b7c 100644 --- a/optimum/exporters/executorch/tasks/__init__.py +++ b/optimum/exporters/executorch/tasks/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import causal_lm, image_classification, masked_lm, seq2seq_lm +from . import causal_lm, image_classification, masked_lm, seq2seq_lm, multimodal_text_to_text diff --git a/optimum/exporters/executorch/tasks/image_text_to_text.py b/optimum/exporters/executorch/tasks/image_text_to_text.py new file mode 100644 index 00000000..6ae6da0e --- /dev/null +++ b/optimum/exporters/executorch/tasks/image_text_to_text.py @@ -0,0 +1,154 @@ +# Copyright 2025 The HuggingFace Team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import torch +import torchao +from packaging.version import parse +from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig + +from ..integrations import MultiModalTextToTextExportableModule +from ..task_registry import register_task + + +# NOTE: It’s important to map the registered task name to the pipeline name in https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py. +# This will streamline using inferred task names and make exporting models to Hugging Face pipelines easier. +@register_task("image-text-to-text") +@register_task("audio-text-to-text") +def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs): + """ + Loads a causal language model for multimodal generation (e.g. image-to-text) generation and registers it under the appropriate task + (e.g. 'image-text-to-text') using Hugging Face's AutoModelForCausalLM. + + Args: + model_name_or_path (str): + Model ID on huggingface.co or path on disk to the model repository to export. For example: + `model_name_or_path="google/gemma-3-4b-it"` or `model_name_or_path="/path/to/model_folder` + **kwargs: + Additional configuration options for the model: + - dtype (str, optional): + Data type for model weights (default: "float32"). + Options include "float16" and "bfloat16". + - attn_implementation (str, optional): + Attention mechanism implementation (default: "sdpa"). + - cache_implementation (str, optional): + Cache management strategy (default: "static"). + - max_length (int, optional): + Maximum sequence length for generation (default: 2048). + + Returns: + MultiModalTextToTextExportableModule: + An instance of `MultiModalTextToTextExportableModule` for exporting and lowering to ExecuTorch. + """ + device = "cpu" + batch_size = 1 + dtype = kwargs.get("dtype", "float32") + use_custom_sdpa = kwargs.get("use_custom_sdpa", False) + use_custom_kv_cache = kwargs.get("use_custom_kv_cache", False) + attn_implementation = kwargs.get("attn_implementation", "custom_sdpa" if use_custom_sdpa else "sdpa") + cache_implementation = kwargs.get("cache_implementation", "static") + use_custom_sdpa = use_custom_sdpa or attn_implementation == "custom_sdpa" + max_length = kwargs.get("max_length", 2048) + config = kwargs.get("config") or AutoConfig.from_pretrained(model_name_or_path) + + # # Make sure config has text_config and vision_config: + # if not hasattr(config, "text_config") or not hasattr(config, "vision_config"): + # raise ValueError( + # f"The model {model_name_or_path} does not have a `text_config` or `vision_config` attribute in its config. " + # "This is required for image-text-to-text models." + # ) + + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + # NOTE: To make the model exportable we need to set the rope scaling to default to avoid hitting + # the data-dependent control flow in _longrope_frequency_update. 
Alternatively, users should rewrite + # that function to avoid the data-dependent control flow. + config.rope_scaling["type"] = "default" + + if hasattr(config, "use_cache") and config.use_cache is False: + config.use_cache = True + + eager_model = AutoModelForCausalLM.from_pretrained( + model_name_or_path, + device_map=device, + torch_dtype=dtype, + config=config, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_length, + }, + ), + ) + + # # Make sure model has language_model as well as vision_tower: + # if not hasattr(eager_model, "language_model") or not hasattr(eager_model, "vision_tower"): + # raise ValueError( + # f"The model {model_name_or_path} does not have a `language_model` or `vision_tower` attribute. " + # "This is required for image-text-to-text models." + # ) + + for param in eager_model.parameters(): + # Must disable gradient for quantized checkpoint + if isinstance(param, torchao.utils.TorchAOBaseTensor): + param.requires_grad = False + + # TODO: Move quantization recipe out for better composability. + # TODO: Should switch to `TorchAoConfig` once the quant issue on final lm_head layer is fixed. + qlinear_config = kwargs.get("qlinear", None) + qembedding_config = kwargs.get("qembedding", None) + if qlinear_config or qembedding_config: + # TODO: Update torchao to use 0.11.0 once released + if parse(torchao.__version__) < parse("0.11.0.dev0"): + raise RuntimeError("Quantization 8da4w requires torchao >= 0.11.0. Please upgrade torchao.") + + from torchao.quantization.granularity import PerAxis, PerGroup + from torchao.quantization.quant_api import ( + Int8DynamicActivationIntxWeightConfig, + IntxWeightOnlyConfig, + quantize_, + ) + from torchao.utils import unwrap_tensor_subclass + + if qembedding_config: + logging.info("Quantizing embedding layers.") + # TODO: Should switch to `AOPerModuleConfig` once fix for tied weights is available. + embedding_config = IntxWeightOnlyConfig( + weight_dtype=torch.int8, + granularity=PerAxis(0), + ) + quantize_( + eager_model, + embedding_config, + lambda m, fqn: isinstance(m, torch.nn.Embedding), + ) + + if qlinear_config: + logging.info("Quantizing linear layers.") + linear_config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, + weight_granularity=PerGroup(32), + ) + quantize_( + eager_model.language_model, + linear_config, + ) + + unwrap_tensor_subclass(eager_model) + + return MultiModalTextToTextExportableModule(eager_model, use_custom_kv_cache, use_custom_sdpa) diff --git a/optimum/exporters/executorch/tasks/multimodal_text_to_text.py b/optimum/exporters/executorch/tasks/multimodal_text_to_text.py new file mode 100644 index 00000000..4e81e463 --- /dev/null +++ b/optimum/exporters/executorch/tasks/multimodal_text_to_text.py @@ -0,0 +1,155 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import torch +import torchao +from packaging.version import parse +from transformers import AutoConfig, AutoModelForMultimodalTextToText, GenerationConfig + +from ..integrations import MultiModalTextToTextExportableModule +from ..task_registry import register_task + + +# NOTE: It’s important to map the registered task name to the pipeline name in https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py. +# This will streamline using inferred task names and make exporting models to Hugging Face pipelines easier. +@register_task("image-text-to-text") +@register_task("audio-text-to-text") +@register_task("multimodal-text-to-text") +def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs): + """ + Loads a causal language model for multimodal generation (e.g. image-to-text) generation and registers it under the appropriate task + (e.g. 'image-text-to-text') using Hugging Face's AutoModelForCausalLM. + + Args: + model_name_or_path (str): + Model ID on huggingface.co or path on disk to the model repository to export. For example: + `model_name_or_path="google/gemma-3-4b-it"` or `model_name_or_path="/path/to/model_folder` + **kwargs: + Additional configuration options for the model: + - dtype (str, optional): + Data type for model weights (default: "float32"). + Options include "float16" and "bfloat16". + - attn_implementation (str, optional): + Attention mechanism implementation (default: "sdpa"). + - cache_implementation (str, optional): + Cache management strategy (default: "static"). + - max_length (int, optional): + Maximum sequence length for generation (default: 2048). + + Returns: + MultiModalTextToTextExportableModule: + An instance of `MultiModalTextToTextExportableModule` for exporting and lowering to ExecuTorch. + """ + device = "cpu" + batch_size = 1 + dtype = kwargs.get("dtype", "float32") + use_custom_sdpa = kwargs.get("use_custom_sdpa", False) + use_custom_kv_cache = kwargs.get("use_custom_kv_cache", False) + attn_implementation = kwargs.get("attn_implementation", "custom_sdpa" if use_custom_sdpa else "sdpa") + cache_implementation = kwargs.get("cache_implementation", "static") + use_custom_sdpa = use_custom_sdpa or attn_implementation == "custom_sdpa" + qlinear_config = kwargs.get("qlinear", None) + qembedding_config = kwargs.get("qembedding", None) + max_length = kwargs.get("max_length", 2048) + config = kwargs.get("config") or AutoConfig.from_pretrained(model_name_or_path) + + # # Make sure config has text_config and vision_config: + # if not hasattr(config, "text_config") or not hasattr(config, "vision_config"): + # raise ValueError( + # f"The model {model_name_or_path} does not have a `text_config` or `vision_config` attribute in its config. " + # "This is required for image-text-to-text models." + # ) + + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + # NOTE: To make the model exportable we need to set the rope scaling to default to avoid hitting + # the data-dependent control flow in _longrope_frequency_update. Alternatively, users should rewrite + # that function to avoid the data-dependent control flow. 
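+        # e.g. a checkpoint shipping rope_scaling={"type": "longrope", ...} is forced back onto the
+        # default RoPE path so that export does not trace _longrope_frequency_update.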
+        config.rope_scaling["type"] = "default"
+
+    if hasattr(config, "use_cache") and config.use_cache is False:
+        config.use_cache = True
+
+    eager_model = AutoModelForMultimodalTextToText.from_pretrained(
+        model_name_or_path,
+        device_map=device,
+        torch_dtype=dtype,
+        config=config,
+        attn_implementation=attn_implementation,
+        generation_config=GenerationConfig(
+            use_cache=True,
+            cache_implementation=cache_implementation,
+            max_length=max_length,
+            cache_config={
+                "batch_size": batch_size,
+                "max_cache_len": max_length,
+            },
+        ),
+    )
+
+    # # Make sure model has language_model as well as vision_tower:
+    # if not hasattr(eager_model, "language_model") or not hasattr(eager_model, "vision_tower"):
+    #     raise ValueError(
+    #         f"The model {model_name_or_path} does not have a `language_model` or `vision_tower` attribute. "
+    #         "This is required for image-text-to-text models."
+    #     )
+
+    for param in eager_model.parameters():
+        # Must disable gradient for quantized checkpoint
+        if isinstance(param, torchao.utils.TorchAOBaseTensor):
+            param.requires_grad = False
+
+    # TODO: Move quantization recipe out for better composability.
+    # TODO: Should switch to `TorchAoConfig` once the quant issue on final lm_head layer is fixed.
+    if qlinear_config or qembedding_config:
+        # TODO: Update torchao to use 0.11.0 once released
+        if parse(torchao.__version__) < parse("0.11.0.dev0"):
+            raise RuntimeError("Quantization 8da4w requires torchao >= 0.11.0. Please upgrade torchao.")
+
+        from torchao.quantization.granularity import PerAxis, PerGroup
+        from torchao.quantization.quant_api import (
+            Int8DynamicActivationIntxWeightConfig,
+            IntxWeightOnlyConfig,
+            quantize_,
+        )
+        from torchao.utils import unwrap_tensor_subclass
+
+        if qembedding_config:
+            logging.info("Quantizing embedding layers.")
+            # TODO: Should switch to `AOPerModuleConfig` once fix for tied weights is available.
+            embedding_config = IntxWeightOnlyConfig(
+                weight_dtype=torch.int8,
+                granularity=PerAxis(0),
+            )
+            quantize_(
+                eager_model,
+                embedding_config,
+                lambda m, fqn: isinstance(m, torch.nn.Embedding),
+            )
+
+        if qlinear_config:
+            logging.info("Quantizing linear layers.")
+            linear_config = Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                weight_granularity=PerGroup(32),
+            )
+            quantize_(
+                eager_model.language_model,
+                linear_config,
+            )
+
+        unwrap_tensor_subclass(eager_model)
+
+    return MultiModalTextToTextExportableModule(eager_model, use_custom_kv_cache, use_custom_sdpa)
diff --git a/optimum/exporters/executorch/utils.py b/optimum/exporters/executorch/utils.py
index 70447957..0d3b51c0 100644
--- a/optimum/exporters/executorch/utils.py
+++ b/optimum/exporters/executorch/utils.py
@@ -16,6 +16,7 @@
 
 import torch
 from transformers import GenerationConfig, PretrainedConfig
+from transformers.tokenization_utils import PreTrainedTokenizer
 
 
 def save_config_to_constant_methods(
@@ -65,7 +66,7 @@ def save_config_to_constant_methods(
     return {k: v for k, v in {**metadata, **kwargs}.items() if v is not None}
 
 
-def verify_eos_tokens_in_tokenizer(model_eos_ids: List[int], tokenizer) -> bool:
+def verify_eos_tokens_in_pretrained_tokenizer(model_eos_ids: List[int], tokenizer: PreTrainedTokenizer) -> bool:
     """
     Verifies that the model's EOS token IDs are present in the tokenizer's
     set of potential end-of-sequence tokens.
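For orientation, the flow the new task file enables looks roughly like the sketch below (Python, mirroring the Voxtral test later in this diff; the export-dictionary keys `token_embeddings`, `audio_encoder`/`vision_embeddings`, and `decoder` are taken from those tests rather than from a documented API, so treat them as assumptions):

from optimum.exporters.executorch.tasks.multimodal_text_to_text import load_multimodal_text_to_text_model

# Load the eager model wrapped for export; the quantization flags are optional.
module = load_multimodal_text_to_text_model(
    "mistralai/Voxtral-Mini-3B-2507",
    use_custom_sdpa=True,
    use_custom_kv_cache=True,
    qlinear=True,
    qembedding=True,
)

# Export returns a dict of exported programs, e.g. token_embeddings, audio_encoder and decoder.
programs = module.export()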
diff --git a/setup.py b/setup.py
index c7fa93ed..e9bb5375 100644
--- a/setup.py
+++ b/setup.py
@@ -14,7 +14,7 @@
 INSTALL_REQUIRE = [
     "optimum~=1.24",
     "executorch>=0.6.0",
-    "transformers==4.51.3",
+    "transformers==4.53.2",
 ]
 
 TESTS_REQUIRE = [
diff --git a/tests/models/test_modeling_gemma3.py b/tests/models/test_modeling_gemma3.py
index 77947c8a..d730200e 100644
--- a/tests/models/test_modeling_gemma3.py
+++ b/tests/models/test_modeling_gemma3.py
@@ -22,15 +22,17 @@
 import unittest
 
 import pytest
+import torch
 import torchao
 import transformers
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
 from packaging.version import parse
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoProcessor
 from transformers.testing_utils import slow
 
 from optimum.executorch import ExecuTorchModelForCausalLM
 from optimum.utils.import_utils import is_transformers_version
+from optimum.exporters.executorch.tasks.image_text_to_text import load_image_text_to_text_model
 
 from ..utils import check_causal_lm_output_quality
@@ -267,3 +269,112 @@ def test_gemma3_text_generation_with_custom_sdpa_kv_cache_8da4w_8we(self):
         gc.collect()
 
         self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skipif(
+        parse(transformers.__version__) < parse("4.53.0.dev0") or parse(torchao.__version__) < parse("0.11.0"),
+        reason="Only available on transformers >= 4.53.0.dev0 and torchao >= 0.11.0",
+    )
+    @pytest.mark.skipif(is_linux_ci, reason="OOM on linux runner")
+    def test_gemma3_image_text_to_text_generation_with_custom_sdpa_kv_cache_8da4w_8we(self):
+        model_id = "google/gemma-3-4b-it"
+
+        module = load_image_text_to_text_model(
+            model_id,
+            use_custom_sdpa=True,
+            use_custom_kv_cache=True,
+            qlinear=True,
+            qembedding=True,
+        )
+
+        res = module.export()
+
+        # Generate
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        image_url = "https://llava-vl.github.io/static/images/view.jpg"
+        conversation = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are a helpful assistant."}],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "url": image_url},
+                    {
+                        "type": "text",
+                        "text": "What are the things I should be cautious about when I visit here?",
+                    },
+                ],
+            },
+        ]
+        processor = AutoProcessor.from_pretrained(model_id)
+        inputs = processor.apply_chat_template(
+            conversation,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        )
+
+        # Split the prompt around the image placeholder tokens so the image embeddings can be
+        # spliced in between the text embeddings.
+        image_indices = torch.where(inputs["input_ids"] == module.model.model.config.image_token_id)
+        prompt_before_image = inputs["input_ids"][:, :image_indices[1][0]]
+        prompt_after_image = inputs["input_ids"][:, image_indices[1][-1] + 1:]
+
+        image_features = res["vision_embeddings"].module().forward(pixel_values=inputs["pixel_values"])
+
+        token_embeddings_before_image = res["token_embeddings"].module().forward(input_ids=prompt_before_image)
+        token_embeddings_after_image = res["token_embeddings"].module().forward(input_ids=prompt_after_image)
+
+        embeddings = torch.cat(
+            [
+                token_embeddings_before_image,
+                image_features,
+                token_embeddings_after_image,
+            ],
+            dim=1,
+        )
+
+        # Prefill prompt embeddings
+        logits = res["decoder"].module().forward(
+            inputs_embeds=embeddings,
+            cache_position=torch.arange(embeddings.shape[1], dtype=torch.long),
+        )
+
+        token = torch.argmax(logits[:, -1, :])
+
+        tokens = [token.item()]
+
+        pos = embeddings.shape[1]
+
+        while pos < 350:
+            token_embedding = res["token_embeddings"].module().forward(
+                input_ids=token.unsqueeze(0).unsqueeze(0)
+            )
+            logits = res["decoder"].module().forward(
+                inputs_embeds=token_embedding,
+                cache_position=torch.tensor([pos], dtype=torch.long),
+            )
+            token = torch.argmax(logits[:, -1, :])
+            tokens.append(token.item())
+            pos += 1
+
+        output = tokenizer.decode(tokens, skip_special_tokens=True)
+        self.assertEqual(
+            output,
+            """Okay, let's analyze the image and discuss potential cautions for visiting this location.
+
+Based on the picture, we're looking at a serene lakeside scene with a wooden pier extending into the water. Here's a breakdown of what you should be cautious about, categorized for clarity:
+
+**1""",
+        )
\ No newline at end of file
diff --git a/tests/models/test_modeling_voxtral.py b/tests/models/test_modeling_voxtral.py
new file mode 100644
index 00000000..d371c87e
--- /dev/null
+++ b/tests/models/test_modeling_voxtral.py
@@ -0,0 +1,188 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import logging
+import os
+import subprocess
+import sys
+import tempfile
+import unittest
+
+import pytest
+import torch
+import torchao
+import transformers
+from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from packaging.version import parse
+from transformers import AutoConfig, AutoTokenizer, AutoProcessor
+from transformers.testing_utils import slow
+
+from optimum.utils.import_utils import is_transformers_version
+from optimum.executorch import ExecuTorchModelForMultiModalToText
+from optimum.exporters.executorch.tasks.multimodal_text_to_text import load_multimodal_text_to_text_model
+
+from ..utils import check_causal_lm_output_quality
+
+
+is_linux_ci = sys.platform.startswith("linux") and os.environ.get("GITHUB_ACTIONS") == "true"
+
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+logging.basicConfig(level=logging.DEBUG)
+
+
+@pytest.mark.skipif(
+    is_transformers_version("<", "4.52.0.dev0"),
+    reason="Only available on transformers >= 4.52.0.dev0",
+)
+class ExecuTorchModelIntegrationTest(unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # Register custom SDPA, which is usually registered in the convert script.
+        from transformers.modeling_utils import AttentionInterface
+        from optimum.executorch.attentions.custom_sdpa import custom_sdpa_with_start_pos_forward
+
+        AttentionInterface.register("custom_sdpa", custom_sdpa_with_start_pos_forward)
+        if is_transformers_version(">=", "4.53.0.dev0"):
+            from transformers.integrations.executorch import sdpa_mask_without_vmap
+            from transformers.masking_utils import AttentionMaskInterface
+
+            AttentionMaskInterface.register("custom_sdpa", sdpa_mask_without_vmap)
+
+    # @slow
+    # @pytest.mark.run_slow
+    # @pytest.mark.skipif(
+    #     parse(transformers.__version__) < parse("4.53.0.dev0") or parse(torchao.__version__) < parse("0.11.0"),
+    #     reason="Only available on transformers >= 4.53.0.dev0 and torchao >= 0.11.0",
+    # )
+    # @pytest.mark.skipif(is_linux_ci, reason="OOM on linux runner")
+    # @pytest.mark.skip()
+    def test_voxtral_audio_text_to_text_generation_with_custom_sdpa_kv_cache_8da4w_8we_exported_program(self):
+        model_id = "mistralai/Voxtral-Mini-3B-2507"
+        config = AutoConfig.from_pretrained(model_id)
+        module = load_multimodal_text_to_text_model(
+            model_id,
+            use_custom_sdpa=True,
+            use_custom_kv_cache=True,
+            qlinear=True,
+            qembedding=True,
+        )
+
+        res = module.export()
+
+        # Generate
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "audio",
+                        "url": "https://huggingface.co/datasets/eustlb/audio-samples/resolve/main/dude_where_is_my_car.wav",
+                    },
+                    {"type": "text", "text": "What can you tell me about this audio?"},
+                ],
+            }
+        ]
+        processor = AutoProcessor.from_pretrained(model_id)
+        inputs = processor.apply_chat_template(
+            conversation,
+            # add_generation_prompt=True,
+            # tokenize=True,
+            # return_dict=True,
+            # return_tensors="pt",
+        )
+
+        input_ids = inputs["input_ids"]
+        token_embeddings = res["token_embeddings"].module().forward(input=input_ids)
+
+        if "input_features" in inputs:
+            token_embeddings = res["audio_encoder"].module().forward(
+                input_features=inputs["input_features"],
+                inputs_embeds=token_embeddings,
+                input_ids=inputs["input_ids"],
+            )
+
+        # Prefill prompt embeddings
+        logits = res["decoder"].module().forward(
+            inputs_embeds=token_embeddings,
+            cache_position=torch.arange(token_embeddings.shape[1], dtype=torch.long),
+        )
+
+        token = torch.argmax(logits[:, -1, :])
+
+        tokens = [token.item()]
+        print(tokenizer.decode([token.item()]), end="")
+
+        pos = token_embeddings.shape[1]
+
+        while pos < 2000:
+            token_embedding = res["token_embeddings"].module().forward(
+                input=token.unsqueeze(0).unsqueeze(0)
+            )
+            logits = res["decoder"].module().forward(
+                inputs_embeds=token_embedding,
+                cache_position=torch.tensor([pos], dtype=torch.long),
+            )
+            token = torch.argmax(logits[:, -1, :])
+            print(tokenizer.decode([token.item()]), end="")
+            tokens.append(token.item())
+            pos += 1
+            # TODO(JZ): end early.
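+            # Hedged sketch for the TODO above: one way to end early is to stop as soon as the
+            # model emits an end-of-sequence token (assumes the tokenizer exposes `eos_token_id`):
+            #     if token.item() == tokenizer.eos_token_id:
+            #         break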
+
+        output = tokenizer.decode(tokens, skip_special_tokens=True)
+        self.assertTrue(output.startswith("The audio features a conversation between two individuals, likely friends or acquaintances, who are discussing a series of tattoos."))
+
+    def test_voxtral_audio_text_to_text_generation_with_custom_sdpa_kv_cache_8da4w_8we_pte(self):
+        model_id = "mistralai/Voxtral-Mini-3B-2507"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        processor = AutoProcessor.from_pretrained(model_id)
+        conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "audio",
+                        "url": "https://huggingface.co/datasets/eustlb/audio-samples/resolve/main/dude_where_is_my_car.wav",
+                    },
+                    {"type": "text", "text": "What can you tell me about this audio?"},
+                ],
+            }
+        ]
+
+        model = ExecuTorchModelForMultiModalToText.from_pretrained(
+            # model_id,
+            "/Users/jackzhxng/Documents/voxtral",  # Load already exported model in local file path.
+            recipe="xnnpack",
+            attn_implementation="custom_sdpa",
+            use_custom_kv_cache=True,
+            **{"qlinear": True, "qembedding": True, "task": "multimodal-text-to-text"},
+        )
+        self.assertIsInstance(model, ExecuTorchModelForMultiModalToText)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+
+        generated_text = model.text_generation(
+            processor=processor,
+            tokenizer=tokenizer,
+            input_conversation=conversation,
+            max_seq_len=64,
+        )
+        print(generated_text)
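For reference, the high-level API exercised by the `_pte` test above boils down to the following sketch. The local pre-exported path is replaced by the model ID here, which assumes `from_pretrained` can also export on the fly with these kwargs; treat that, and the kwarg names, as assumptions taken from the test rather than documented behaviour:

from transformers import AutoProcessor, AutoTokenizer
from optimum.executorch import ExecuTorchModelForMultiModalToText

model_id = "mistralai/Voxtral-Mini-3B-2507"

# Load (or export) the multimodal model with the XNNPACK recipe, custom SDPA and custom KV cache.
model = ExecuTorchModelForMultiModalToText.from_pretrained(
    model_id,  # assumption: a hub ID works here in place of a pre-exported local path
    recipe="xnnpack",
    attn_implementation="custom_sdpa",
    use_custom_kv_cache=True,
    qlinear=True,
    qembedding=True,
    task="multimodal-text-to-text",
)

# Run chat-style generation from a conversation, as in the test above.
generated_text = model.text_generation(
    processor=AutoProcessor.from_pretrained(model_id),
    tokenizer=AutoTokenizer.from_pretrained(model_id),
    input_conversation=[{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}],
    max_seq_len=64,
)
print(generated_text)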