From 137e7d1ca5de276dc7e3f7d7fab2a55e44bcebdd Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Thu, 18 Sep 2025 14:47:47 -0700 Subject: [PATCH 1/8] fix versioning transformers error --- optimum/exporters/executorch/integrations.py | 191 ++++++++++++++----- 1 file changed, 141 insertions(+), 50 deletions(-) diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py index 7cd1194..3d55541 100644 --- a/optimum/exporters/executorch/integrations.py +++ b/optimum/exporters/executorch/integrations.py @@ -16,6 +16,8 @@ from typing import Dict import torch + +from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache from packaging.version import parse from torch.export import ExportedProgram from torch.nn.attention import SDPBackend @@ -23,16 +25,19 @@ AutoProcessor, PreTrainedModel, StaticCache, + DynamicCache, + EncoderDecoderCache, T5ForConditionalGeneration, WhisperForConditionalGeneration, ) from transformers.generation.configuration_utils import GenerationConfig -from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM, sdpa_mask_without_vmap +from transformers.integrations.executorch import ( + sdpa_mask_without_vmap, + TorchExportableModuleForDecoderOnlyLM, +) from transformers.masking_utils import AttentionMaskInterface from transformers.modeling_utils import AttentionInterface -from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache - from .utils import apply_chat_template_with_fallback, save_config_to_constant_methods @@ -49,7 +54,10 @@ def prepare_export_inputs(self): { "role": "user", "content": [ - {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"}, + { + "type": "image", + "url": "https://llava-vl.github.io/static/images/view.jpg", + }, ], }, ] @@ -165,7 +173,9 @@ def __init__( super().__init__() if modality not in encoder_name: - raise ValueError(f'encoder_name "{encoder_name}" does not match specified modality "{modality}".') + raise ValueError( + f'encoder_name "{encoder_name}" does not match specified modality "{modality}".' 
+ ) if not hasattr(model, encoder_name): raise ValueError(f'Model does not contain encoder "{encoder_name}".') @@ -178,9 +188,13 @@ def __init__( self.use_custom_sdpa = use_custom_sdpa additional_metadata_kwargs = {"modality": modality} if modality == "audio": - additional_metadata_kwargs[f"{modality}_token_id"] = getattr(self.config, "audio_token_id") + additional_metadata_kwargs[f"{modality}_token_id"] = getattr( + self.config, "audio_token_id" + ) elif modality == "vision": - additional_metadata_kwargs[f"{modality}_token_id"] = getattr(self.config, "image_token_id") + additional_metadata_kwargs[f"{modality}_token_id"] = getattr( + self.config, "image_token_id" + ) self.metadata = save_config_to_constant_methods( config=model.config.text_config, generation_config=model.generation_config, @@ -222,7 +236,9 @@ def _prepare_decoder_only_export_inputs(self, max_seq_len: int): # Prepare inputs with dynamic shapes seq_length = 3 - example_inputs_embeds = torch.zeros((1, seq_length, self.config.text_config.hidden_size), dtype=torch.float) + example_inputs_embeds = torch.zeros( + (1, seq_length, self.config.text_config.hidden_size), dtype=torch.float + ) example_cache_position = torch.arange(seq_length, dtype=torch.long) seq_len_dim = torch.export.Dim("seq_length_dim", max=max_seq_len) @@ -234,18 +250,28 @@ def _prepare_decoder_only_export_inputs(self, max_seq_len: int): return example_inputs_embeds, example_cache_position, dynamic_shapes def _register_custom_attention(self, exportable_module: torch.nn.Module): - _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module) + _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache( + exportable_module + ) if self.use_custom_sdpa: if self.use_custom_kv_cache: - AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache) - AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap) + AttentionInterface.register( + "custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache + ) + AttentionMaskInterface.register( + "custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap + ) # Manually set the attention implementation to custom_sdpa_ring_kv_cache # This handles both regular sdpa and one for sliding window/local attention - exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache" + exportable_module.model.model.config._attn_implementation = ( + "custom_sdpa_ring_kv_cache" + ) else: # Manually set the attention implementation to custom_sdpa_ring_kv_cache # This handles both regular sdpa and one for sliding window/local attention - exportable_module.model.model.config._attn_implementation = "custom_sdpa" + exportable_module.model.model.config._attn_implementation = ( + "custom_sdpa" + ) def export( self, @@ -287,7 +313,9 @@ def export( self.model.dtype, ) - inputs_embeds, cache_position, dynamic_shapes = self._prepare_decoder_only_export_inputs(max_seq_len) + inputs_embeds, cache_position, dynamic_shapes = ( + self._prepare_decoder_only_export_inputs(max_seq_len) + ) logging.info( f"Exporting decoder using inputs_embeds({inputs_embeds.shape}), cache_position({cache_position.shape})={cache_position}, dynamic_shapes={dynamic_shapes}" ) @@ -295,7 +323,7 @@ def export( inputs_embeds=inputs_embeds, cache_position=cache_position, dynamic_shapes=dynamic_shapes, - strict=True, + strict=False, ) # Apply RemoveTransposes pass to remove # any back-to-back transpose ops that are not needed @@ -310,14 +338,19 @@ def export( mutated_gm, args=(), # For 
the ET runner, it's important to have cache position as the 2nd arg. - kwargs={"inputs_embeds": inputs_embeds, "cache_position": cache_position}, + kwargs={ + "inputs_embeds": inputs_embeds, + "cache_position": cache_position, + }, dynamic_shapes=dynamic_shapes, - strict=True, + strict=False, ) exported_programs["text_decoder"] = exported_program # 2. Export token embeddings - input_ids, dynamic_shapes = self._prepare_text_embedding_export_inputs(max_seq_len) + input_ids, dynamic_shapes = self._prepare_text_embedding_export_inputs( + max_seq_len + ) logging.info( f"Exporting token embeddings using input_ids({input_ids.shape}), dynamic_shapes={dynamic_shapes}" ) @@ -327,13 +360,15 @@ def export( args=(input_ids,), kwargs={}, dynamic_shapes=dynamic_shapes, - strict=True, + strict=False, ) exported_programs["token_embedding"] = token_embedding_exported_program # 3. Export encoder. if self.use_custom_sdpa: - getattr(self.model, self.encoder_name).config._attn_implementation = "custom_sdpa" + getattr(self.model, self.encoder_name).config._attn_implementation = ( + "custom_sdpa" + ) if self.modality == "audio": encoder = AudioExportableModule(self.model) @@ -356,7 +391,7 @@ def export( "input_features": input_features, }, dynamic_shapes=dynamic_shapes, - strict=True, + strict=False, ) exported_programs[f"{self.modality}_encoder"] = encoder_exported_program @@ -370,7 +405,12 @@ class CausalLMExportableModule(torch.nn.Module): """ def __init__( - self, model, max_seq_len=2048, use_custom_kv_cache=False, use_custom_sdpa=False, disable_dynamic_shapes=False + self, + model, + max_seq_len=2048, + use_custom_kv_cache=False, + use_custom_sdpa=False, + disable_dynamic_shapes=False, ): super().__init__() self.model = model @@ -405,7 +445,10 @@ def _prepare_export_inputs(self): and not (self.use_custom_kv_cache and self.use_custom_sdpa) ) - if not self.disable_dynamic_shapes and not is_using_hybrid_cache_wo_custom_sdpa_kv_cache: + if ( + not self.disable_dynamic_shapes + and not is_using_hybrid_cache_wo_custom_sdpa_kv_cache + ): # Prepare inputs with dynamic shapes seq_length = 3 # Sequence length > 1 to avoid specialization issues example_input_ids = torch.zeros((1, seq_length), dtype=torch.long) @@ -418,7 +461,9 @@ def _prepare_export_inputs(self): "input_ids": {1: seq_len_dim}, "cache_position": {0: seq_len_dim}, } - strict = parse(torch.__version__) != parse("2.7.0") # Workaround for PyTorch bug #150994 + strict = parse(torch.__version__) != parse( + "2.7.0" + ) # Workaround for PyTorch bug #150994 return example_input_ids, example_cache_position, dynamic_shapes, strict @@ -429,21 +474,33 @@ def _register_custom_attention(self, exportable_module: torch.nn.Module): if self.use_custom_sdpa: if self.use_custom_kv_cache: - _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module) - AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache) - AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap) + _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache( + exportable_module + ) + AttentionInterface.register( + "custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache + ) + AttentionMaskInterface.register( + "custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap + ) # Manually set the attention implementation to custom_sdpa_ring_kv_cache # This handles both regular sdpa and one for sliding window/local attention - exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache" + 
exportable_module.model.model.config._attn_implementation = ( + "custom_sdpa_ring_kv_cache" + ) else: # Manually set the attention implementation to custom_sdpa_ring_kv_cache # This handles both regular sdpa and one for sliding window/local attention - exportable_module.model.model.config._attn_implementation = "custom_sdpa" + exportable_module.model.model.config._attn_implementation = ( + "custom_sdpa" + ) def export( self, ) -> Dict[str, ExportedProgram]: - input_ids, cache_position, dynamic_shapes, strict = self._prepare_export_inputs() + input_ids, cache_position, dynamic_shapes, strict = ( + self._prepare_export_inputs() + ) logging.info( f"Exporting using input_ids({input_ids.shape})={input_ids}, cache_position({cache_position.shape})={cache_position}, dynamic_shapes={dynamic_shapes}, strict={strict}" ) @@ -467,7 +524,10 @@ def export( with torch.no_grad(): exported_program = exportable_module.export( - input_ids=input_ids, cache_position=cache_position, dynamic_shapes=dynamic_shapes, strict=strict + input_ids=input_ids, + cache_position=cache_position, + dynamic_shapes=dynamic_shapes, + strict=strict, ) # Apply RemoveTransposes pass to remove # any back-to-back transpose ops that are not needed @@ -503,7 +563,9 @@ def __init__(self, model): self.model = model self.config = model.config # Metadata to be recorded in the pte model file - self.metadata = save_config_to_constant_methods(model.config, model.generation_config) + self.metadata = save_config_to_constant_methods( + model.config, model.generation_config + ) def forward(self, pixel_values): print(f"DEBUG: pixel_values: {pixel_values.shape}") @@ -540,13 +602,17 @@ def __init__(self, model): self.model = model self.config = model.config # Metadata to be recorded in the pte model file - self.metadata = save_config_to_constant_methods(model.config, model.generation_config) + self.metadata = save_config_to_constant_methods( + model.config, model.generation_config + ) def forward(self, input_ids, attention_mask): return self.model(input_ids, attention_mask) def export(self, input_ids=None, attention_mask=None) -> Dict[str, ExportedProgram]: - max_position_embeddings = getattr(self.model.config, "max_position_embeddings", 64) + max_position_embeddings = getattr( + self.model.config, "max_position_embeddings", 64 + ) max_seq_length = max(max_position_embeddings - 1, 1) # Create dummy inputs with expected shapes batch_size = 1 @@ -560,7 +626,9 @@ def export(self, input_ids=None, attention_mask=None) -> Dict[str, ExportedProgr else input_ids ) dummy_attention_mask = ( - torch.ones((batch_size, seq_length), dtype=torch.long) if attention_mask is None else attention_mask + torch.ones((batch_size, seq_length), dtype=torch.long) + if attention_mask is None + else attention_mask ) # Define dynamic shapes with Dim objects, always use Auto @@ -577,7 +645,7 @@ def export(self, input_ids=None, attention_mask=None) -> Dict[str, ExportedProgr args=(dummy_input_ids,), kwargs={"attention_mask": dummy_attention_mask}, dynamic_shapes=dynamic_shapes, - strict=True, + strict=False, ) } @@ -623,18 +691,23 @@ def __init__(self, model, max_static_cache_length, batch_size): device="cpu", dtype=torch.float32, ) + self.cache = EncoderDecoderCache(self.static_cache, DynamicCache()) # Register cache buffers to make them exportable - for i in range(len(self.static_cache.key_cache)): - self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False) - self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], 
persistent=False) + # for i in range(len(self.static_cache.layers)): + # self.register_buffer( + # f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False + # ) + # self.register_buffer( + # f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False + # ) def forward(self, decoder_input_ids, encoder_hidden_states, cache_position): # Get outputs from decoder outputs = self.decoder( input_ids=decoder_input_ids, encoder_hidden_states=encoder_hidden_states, - past_key_values=self.static_cache, + past_key_values=self.cache, use_cache=True, cache_position=cache_position, ) @@ -689,7 +762,9 @@ def __init__( self.exported_decoder = None def _export_encoder(self, encoder_input_ids): - wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval() + wrapped_encoder = ( + Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval() + ) # Define dynamic sequence length for encoder if isinstance(self.full_model, WhisperForConditionalGeneration): @@ -699,7 +774,9 @@ def _export_encoder(self, encoder_input_ids): For more infromation, please refer to the Whisper preprocessor config.""" dynamic_shapes = None elif isinstance(self.full_model, T5ForConditionalGeneration): - encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length) + encoder_seq_len_dim = torch.export.Dim( + "encoder_hidden_seq_length", max=self.max_hidden_seq_length + ) dynamic_shapes = {"input_ids": {1: encoder_seq_len_dim}} else: raise ValueError( @@ -712,7 +789,7 @@ def _export_encoder(self, encoder_input_ids): wrapped_encoder, (encoder_input_ids,), dynamic_shapes=dynamic_shapes, - strict=True, + strict=False, ) return exported_encoder @@ -720,7 +797,9 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi wrapped_decoder = ( Seq2SeqLMDecoderExportableModuleWithStaticCache( model=self.full_model, - max_static_cache_length=self.generation_config.cache_config.get("max_cache_len"), + max_static_cache_length=self.generation_config.cache_config.get( + "max_cache_len" + ), batch_size=self.generation_config.cache_config.get("batch_size"), ) .to("cpu") @@ -731,7 +810,9 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi dynamic_shapes = None elif isinstance(self.full_model, T5ForConditionalGeneration): # Define dynamic dimension for encoder output sequence length - encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length) + encoder_seq_len_dim = torch.export.Dim( + "encoder_hidden_seq_length", max=self.max_hidden_seq_length + ) dynamic_shapes = { "decoder_input_ids": None, "encoder_hidden_states": {1: encoder_seq_len_dim}, @@ -748,7 +829,7 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi wrapped_decoder, (decoder_input_ids, encoder_hidden_states, cache_position), dynamic_shapes=dynamic_shapes, - strict=True, + strict=False, ) return exported_decoder @@ -762,7 +843,9 @@ def export( ) -> Dict[str, ExportedProgram]: if encoder_input_ids is None: if isinstance(self.full_model, WhisperForConditionalGeneration): - example_encoder_input_ids = torch.rand(self._expected_encoder_input_shape) + example_encoder_input_ids = torch.rand( + self._expected_encoder_input_shape + ) else: example_encoder_input_ids = torch.ones((1, 10), dtype=torch.long) else: @@ -771,14 +854,22 @@ def export( self.exported_encoder = self._export_encoder(example_encoder_input_ids) if not encoder_hidden_states: - example_encoder_hidden_states = 
self.exported_encoder.module()(example_encoder_input_ids) + example_encoder_hidden_states = self.exported_encoder.module()( + example_encoder_input_ids + ) else: example_encoder_hidden_states = encoder_hidden_states example_decoder_input_ids = ( - decoder_input_ids if decoder_input_ids is not None else torch.tensor([[0]], dtype=torch.long) + decoder_input_ids + if decoder_input_ids is not None + else torch.tensor([[0]], dtype=torch.long) + ) + example_cache_position = ( + cache_position + if cache_position is not None + else torch.tensor([0], dtype=torch.long) ) - example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long) self.exported_decoder = self._export_decoder( example_decoder_input_ids, From dfd408a5fd0081a9b4d0a91466244e2248f6f43c Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Thu, 18 Sep 2025 16:47:07 -0700 Subject: [PATCH 2/8] Switch to EncoderDecoder fake cache --- optimum/exporters/executorch/integrations.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py index 3d55541..40d1797 100644 --- a/optimum/exporters/executorch/integrations.py +++ b/optimum/exporters/executorch/integrations.py @@ -323,7 +323,7 @@ def export( inputs_embeds=inputs_embeds, cache_position=cache_position, dynamic_shapes=dynamic_shapes, - strict=False, + strict=True, ) # Apply RemoveTransposes pass to remove # any back-to-back transpose ops that are not needed @@ -343,7 +343,7 @@ def export( "cache_position": cache_position, }, dynamic_shapes=dynamic_shapes, - strict=False, + strict=True, ) exported_programs["text_decoder"] = exported_program @@ -360,7 +360,7 @@ def export( args=(input_ids,), kwargs={}, dynamic_shapes=dynamic_shapes, - strict=False, + strict=True, ) exported_programs["token_embedding"] = token_embedding_exported_program @@ -391,7 +391,7 @@ def export( "input_features": input_features, }, dynamic_shapes=dynamic_shapes, - strict=False, + strict=True, ) exported_programs[f"{self.modality}_encoder"] = encoder_exported_program @@ -586,7 +586,7 @@ def export(self, pixel_values=None) -> Dict[str, ExportedProgram]: self.model, args=(), kwargs={"pixel_values": pixel_values}, - strict=False, + strict=True, ) } @@ -645,7 +645,7 @@ def export(self, input_ids=None, attention_mask=None) -> Dict[str, ExportedProgr args=(dummy_input_ids,), kwargs={"attention_mask": dummy_attention_mask}, dynamic_shapes=dynamic_shapes, - strict=False, + strict=True, ) } @@ -789,7 +789,7 @@ def _export_encoder(self, encoder_input_ids): wrapped_encoder, (encoder_input_ids,), dynamic_shapes=dynamic_shapes, - strict=False, + strict=True, ) return exported_encoder @@ -829,7 +829,7 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi wrapped_decoder, (decoder_input_ids, encoder_hidden_states, cache_position), dynamic_shapes=dynamic_shapes, - strict=False, + strict=True, ) return exported_decoder From 7b81996e407dbcded119670d5788cf1cdcc71026 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Thu, 18 Sep 2025 16:49:10 -0700 Subject: [PATCH 3/8] lint --- optimum/exporters/executorch/integrations.py | 141 +++++-------------- 1 file changed, 37 insertions(+), 104 deletions(-) diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py index 40d1797..c0d5fb9 100644 --- a/optimum/exporters/executorch/integrations.py +++ b/optimum/exporters/executorch/integrations.py @@ -16,28 +16,28 
@@ from typing import Dict import torch - -from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache from packaging.version import parse from torch.export import ExportedProgram from torch.nn.attention import SDPBackend from transformers import ( AutoProcessor, - PreTrainedModel, - StaticCache, DynamicCache, EncoderDecoderCache, + PreTrainedModel, + StaticCache, T5ForConditionalGeneration, WhisperForConditionalGeneration, ) from transformers.generation.configuration_utils import GenerationConfig from transformers.integrations.executorch import ( - sdpa_mask_without_vmap, TorchExportableModuleForDecoderOnlyLM, + sdpa_mask_without_vmap, ) from transformers.masking_utils import AttentionMaskInterface from transformers.modeling_utils import AttentionInterface +from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache + from .utils import apply_chat_template_with_fallback, save_config_to_constant_methods @@ -173,9 +173,7 @@ def __init__( super().__init__() if modality not in encoder_name: - raise ValueError( - f'encoder_name "{encoder_name}" does not match specified modality "{modality}".' - ) + raise ValueError(f'encoder_name "{encoder_name}" does not match specified modality "{modality}".') if not hasattr(model, encoder_name): raise ValueError(f'Model does not contain encoder "{encoder_name}".') @@ -188,13 +186,9 @@ def __init__( self.use_custom_sdpa = use_custom_sdpa additional_metadata_kwargs = {"modality": modality} if modality == "audio": - additional_metadata_kwargs[f"{modality}_token_id"] = getattr( - self.config, "audio_token_id" - ) + additional_metadata_kwargs[f"{modality}_token_id"] = getattr(self.config, "audio_token_id") elif modality == "vision": - additional_metadata_kwargs[f"{modality}_token_id"] = getattr( - self.config, "image_token_id" - ) + additional_metadata_kwargs[f"{modality}_token_id"] = getattr(self.config, "image_token_id") self.metadata = save_config_to_constant_methods( config=model.config.text_config, generation_config=model.generation_config, @@ -236,9 +230,7 @@ def _prepare_decoder_only_export_inputs(self, max_seq_len: int): # Prepare inputs with dynamic shapes seq_length = 3 - example_inputs_embeds = torch.zeros( - (1, seq_length, self.config.text_config.hidden_size), dtype=torch.float - ) + example_inputs_embeds = torch.zeros((1, seq_length, self.config.text_config.hidden_size), dtype=torch.float) example_cache_position = torch.arange(seq_length, dtype=torch.long) seq_len_dim = torch.export.Dim("seq_length_dim", max=max_seq_len) @@ -250,28 +242,18 @@ def _prepare_decoder_only_export_inputs(self, max_seq_len: int): return example_inputs_embeds, example_cache_position, dynamic_shapes def _register_custom_attention(self, exportable_module: torch.nn.Module): - _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache( - exportable_module - ) + _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module) if self.use_custom_sdpa: if self.use_custom_kv_cache: - AttentionInterface.register( - "custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache - ) - AttentionMaskInterface.register( - "custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap - ) + AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache) + AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap) # Manually set the attention implementation to custom_sdpa_ring_kv_cache # This handles both regular sdpa and one for sliding window/local attention - 
exportable_module.model.model.config._attn_implementation = ( - "custom_sdpa_ring_kv_cache" - ) + exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache" else: # Manually set the attention implementation to custom_sdpa_ring_kv_cache # This handles both regular sdpa and one for sliding window/local attention - exportable_module.model.model.config._attn_implementation = ( - "custom_sdpa" - ) + exportable_module.model.model.config._attn_implementation = "custom_sdpa" def export( self, @@ -313,9 +295,7 @@ def export( self.model.dtype, ) - inputs_embeds, cache_position, dynamic_shapes = ( - self._prepare_decoder_only_export_inputs(max_seq_len) - ) + inputs_embeds, cache_position, dynamic_shapes = self._prepare_decoder_only_export_inputs(max_seq_len) logging.info( f"Exporting decoder using inputs_embeds({inputs_embeds.shape}), cache_position({cache_position.shape})={cache_position}, dynamic_shapes={dynamic_shapes}" ) @@ -348,9 +328,7 @@ def export( exported_programs["text_decoder"] = exported_program # 2. Export token embeddings - input_ids, dynamic_shapes = self._prepare_text_embedding_export_inputs( - max_seq_len - ) + input_ids, dynamic_shapes = self._prepare_text_embedding_export_inputs(max_seq_len) logging.info( f"Exporting token embeddings using input_ids({input_ids.shape}), dynamic_shapes={dynamic_shapes}" ) @@ -366,9 +344,7 @@ def export( # 3. Export encoder. if self.use_custom_sdpa: - getattr(self.model, self.encoder_name).config._attn_implementation = ( - "custom_sdpa" - ) + getattr(self.model, self.encoder_name).config._attn_implementation = "custom_sdpa" if self.modality == "audio": encoder = AudioExportableModule(self.model) @@ -445,10 +421,7 @@ def _prepare_export_inputs(self): and not (self.use_custom_kv_cache and self.use_custom_sdpa) ) - if ( - not self.disable_dynamic_shapes - and not is_using_hybrid_cache_wo_custom_sdpa_kv_cache - ): + if not self.disable_dynamic_shapes and not is_using_hybrid_cache_wo_custom_sdpa_kv_cache: # Prepare inputs with dynamic shapes seq_length = 3 # Sequence length > 1 to avoid specialization issues example_input_ids = torch.zeros((1, seq_length), dtype=torch.long) @@ -461,9 +434,7 @@ def _prepare_export_inputs(self): "input_ids": {1: seq_len_dim}, "cache_position": {0: seq_len_dim}, } - strict = parse(torch.__version__) != parse( - "2.7.0" - ) # Workaround for PyTorch bug #150994 + strict = parse(torch.__version__) != parse("2.7.0") # Workaround for PyTorch bug #150994 return example_input_ids, example_cache_position, dynamic_shapes, strict @@ -474,33 +445,21 @@ def _register_custom_attention(self, exportable_module: torch.nn.Module): if self.use_custom_sdpa: if self.use_custom_kv_cache: - _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache( - exportable_module - ) - AttentionInterface.register( - "custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache - ) - AttentionMaskInterface.register( - "custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap - ) + _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module) + AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache) + AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap) # Manually set the attention implementation to custom_sdpa_ring_kv_cache # This handles both regular sdpa and one for sliding window/local attention - exportable_module.model.model.config._attn_implementation = ( - "custom_sdpa_ring_kv_cache" - ) + 
exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache" else: # Manually set the attention implementation to custom_sdpa_ring_kv_cache # This handles both regular sdpa and one for sliding window/local attention - exportable_module.model.model.config._attn_implementation = ( - "custom_sdpa" - ) + exportable_module.model.model.config._attn_implementation = "custom_sdpa" def export( self, ) -> Dict[str, ExportedProgram]: - input_ids, cache_position, dynamic_shapes, strict = ( - self._prepare_export_inputs() - ) + input_ids, cache_position, dynamic_shapes, strict = self._prepare_export_inputs() logging.info( f"Exporting using input_ids({input_ids.shape})={input_ids}, cache_position({cache_position.shape})={cache_position}, dynamic_shapes={dynamic_shapes}, strict={strict}" ) @@ -563,9 +522,7 @@ def __init__(self, model): self.model = model self.config = model.config # Metadata to be recorded in the pte model file - self.metadata = save_config_to_constant_methods( - model.config, model.generation_config - ) + self.metadata = save_config_to_constant_methods(model.config, model.generation_config) def forward(self, pixel_values): print(f"DEBUG: pixel_values: {pixel_values.shape}") @@ -602,17 +559,13 @@ def __init__(self, model): self.model = model self.config = model.config # Metadata to be recorded in the pte model file - self.metadata = save_config_to_constant_methods( - model.config, model.generation_config - ) + self.metadata = save_config_to_constant_methods(model.config, model.generation_config) def forward(self, input_ids, attention_mask): return self.model(input_ids, attention_mask) def export(self, input_ids=None, attention_mask=None) -> Dict[str, ExportedProgram]: - max_position_embeddings = getattr( - self.model.config, "max_position_embeddings", 64 - ) + max_position_embeddings = getattr(self.model.config, "max_position_embeddings", 64) max_seq_length = max(max_position_embeddings - 1, 1) # Create dummy inputs with expected shapes batch_size = 1 @@ -626,9 +579,7 @@ def export(self, input_ids=None, attention_mask=None) -> Dict[str, ExportedProgr else input_ids ) dummy_attention_mask = ( - torch.ones((batch_size, seq_length), dtype=torch.long) - if attention_mask is None - else attention_mask + torch.ones((batch_size, seq_length), dtype=torch.long) if attention_mask is None else attention_mask ) # Define dynamic shapes with Dim objects, always use Auto @@ -762,9 +713,7 @@ def __init__( self.exported_decoder = None def _export_encoder(self, encoder_input_ids): - wrapped_encoder = ( - Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval() - ) + wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval() # Define dynamic sequence length for encoder if isinstance(self.full_model, WhisperForConditionalGeneration): @@ -774,9 +723,7 @@ def _export_encoder(self, encoder_input_ids): For more infromation, please refer to the Whisper preprocessor config.""" dynamic_shapes = None elif isinstance(self.full_model, T5ForConditionalGeneration): - encoder_seq_len_dim = torch.export.Dim( - "encoder_hidden_seq_length", max=self.max_hidden_seq_length - ) + encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length) dynamic_shapes = {"input_ids": {1: encoder_seq_len_dim}} else: raise ValueError( @@ -797,9 +744,7 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi wrapped_decoder = ( Seq2SeqLMDecoderExportableModuleWithStaticCache( model=self.full_model, - 
max_static_cache_length=self.generation_config.cache_config.get( - "max_cache_len" - ), + max_static_cache_length=self.generation_config.cache_config.get("max_cache_len"), batch_size=self.generation_config.cache_config.get("batch_size"), ) .to("cpu") @@ -810,9 +755,7 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi dynamic_shapes = None elif isinstance(self.full_model, T5ForConditionalGeneration): # Define dynamic dimension for encoder output sequence length - encoder_seq_len_dim = torch.export.Dim( - "encoder_hidden_seq_length", max=self.max_hidden_seq_length - ) + encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length) dynamic_shapes = { "decoder_input_ids": None, "encoder_hidden_states": {1: encoder_seq_len_dim}, @@ -843,9 +786,7 @@ def export( ) -> Dict[str, ExportedProgram]: if encoder_input_ids is None: if isinstance(self.full_model, WhisperForConditionalGeneration): - example_encoder_input_ids = torch.rand( - self._expected_encoder_input_shape - ) + example_encoder_input_ids = torch.rand(self._expected_encoder_input_shape) else: example_encoder_input_ids = torch.ones((1, 10), dtype=torch.long) else: @@ -854,22 +795,14 @@ def export( self.exported_encoder = self._export_encoder(example_encoder_input_ids) if not encoder_hidden_states: - example_encoder_hidden_states = self.exported_encoder.module()( - example_encoder_input_ids - ) + example_encoder_hidden_states = self.exported_encoder.module()(example_encoder_input_ids) else: example_encoder_hidden_states = encoder_hidden_states example_decoder_input_ids = ( - decoder_input_ids - if decoder_input_ids is not None - else torch.tensor([[0]], dtype=torch.long) - ) - example_cache_position = ( - cache_position - if cache_position is not None - else torch.tensor([0], dtype=torch.long) + decoder_input_ids if decoder_input_ids is not None else torch.tensor([[0]], dtype=torch.long) ) + example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long) self.exported_decoder = self._export_decoder( example_decoder_input_ids, From c37a47d8e0b8d8be510d5e4e011504df84ea3c64 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Thu, 18 Sep 2025 16:51:10 -0700 Subject: [PATCH 4/8] remove unnecessary changes --- optimum/exporters/executorch/integrations.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py index c0d5fb9..1c0cddb 100644 --- a/optimum/exporters/executorch/integrations.py +++ b/optimum/exporters/executorch/integrations.py @@ -543,7 +543,7 @@ def export(self, pixel_values=None) -> Dict[str, ExportedProgram]: self.model, args=(), kwargs={"pixel_values": pixel_values}, - strict=True, + strict=False, ) } @@ -644,15 +644,6 @@ def __init__(self, model, max_static_cache_length, batch_size): ) self.cache = EncoderDecoderCache(self.static_cache, DynamicCache()) - # Register cache buffers to make them exportable - # for i in range(len(self.static_cache.layers)): - # self.register_buffer( - # f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False - # ) - # self.register_buffer( - # f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False - # ) - def forward(self, decoder_input_ids, encoder_hidden_states, cache_position): # Get outputs from decoder outputs = self.decoder( From c2b7e71279ac5ef3213a8194df232668997569a7 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Wed, 24 Sep 2025 
10:38:14 -0700 Subject: [PATCH 5/8] decoder -> text_decoder --- optimum/executorch/modeling.py | 205 ++++++++++++++++++++++++--------- 1 file changed, 152 insertions(+), 53 deletions(-) diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py index 93a0034..065fbbc 100644 --- a/optimum/executorch/modeling.py +++ b/optimum/executorch/modeling.py @@ -23,10 +23,17 @@ from typing import Dict, List, Optional, Union import torch + +from executorch.extension.pybindings.portable_lib import ( + _load_for_executorch, + ExecuTorchModule, +) +from executorch.kernels import quantized # noqa from huggingface_hub import hf_hub_download from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa from transformers import ( + add_start_docstrings, AutoModel, AutoModelForCausalLM, AutoModelForImageClassification, @@ -34,15 +41,11 @@ AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq, PreTrainedTokenizer, - add_start_docstrings, ) from transformers.configuration_utils import PretrainedConfig from transformers.processing_utils import ProcessorMixin from transformers.utils import is_offline_mode -from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch -from executorch.kernels import quantized # noqa - from ..exporters import TasksManager from ..exporters.executorch import main_export from ..exporters.executorch.utils import ( @@ -126,7 +129,9 @@ def _cleanup_temp_resources(self): try: if hasattr(self._temp_dir, "cleanup"): # It's a TemporaryDirectory object - logging.info(f"Cleaning up temporary directory: {self._temp_dir.name}") + logging.info( + f"Cleaning up temporary directory: {self._temp_dir.name}" + ) self._temp_dir.cleanup() logging.info("Temporary directory cleanup completed") elif isinstance(self._temp_dir, (str, Path)): @@ -176,7 +181,9 @@ def _from_pretrained( _PTE_SUFFIX = ".pte" if file_name and not file_name.endswith(_PTE_SUFFIX): - raise ValueError(f"Invalid file name: {file_name}. Expected a '{_PTE_SUFFIX}' file.") + raise ValueError( + f"Invalid file name: {file_name}. Expected a '{_PTE_SUFFIX}' file." 
+ ) default_file_name = file_name or "model.pte" @@ -190,9 +197,13 @@ def _from_pretrained( ) if len(pte_files) == 0: - raise FileNotFoundError(f"Could not find any ExecuTorch model file in {model_id}") + raise FileNotFoundError( + f"Could not find any ExecuTorch model file in {model_id}" + ) if len(pte_files) == 1 and file_name and file_name != pte_files[0].name: - raise FileNotFoundError(f"Trying to load {file_name} but only found {pte_files[0].name}") + raise FileNotFoundError( + f"Trying to load {file_name} but only found {pte_files[0].name}" + ) file_name = pte_files[0].name subfolder = pte_files[0].parent @@ -276,7 +287,11 @@ def _export( **kwargs, ) -> Dict[str, "ExecuTorchModule"]: task = kwargs.pop("task", None) - inferred_task = TasksManager.infer_task_from_model(cls.auto_model_class) if not task else task + inferred_task = ( + TasksManager.infer_task_from_model(cls.auto_model_class) + if not task + else task + ) logging.info(f"Inferred task from model class: {inferred_task}") save_dir = TemporaryDirectory(prefix="executorch_export_") @@ -301,7 +316,11 @@ def _export( models = {} for name, _ in executorch_progs.items(): - models.update(cls._from_pretrained(save_dir_path, file_name=f"{name}.pte", config=config)) + models.update( + cls._from_pretrained( + save_dir_path, file_name=f"{name}.pte", config=config + ) + ) return models, save_dir @@ -341,7 +360,9 @@ def from_pretrained( if local_files_only and not os.path.isdir(model_id): object_id = model_id.replace("/", "--") cached_model_dir = os.path.join(cache_dir, f"models--{object_id}") - refs_file = os.path.join(os.path.join(cached_model_dir, "refs"), revision or "main") + refs_file = os.path.join( + os.path.join(cached_model_dir, "refs"), revision or "main" + ) with open(refs_file) as f: _revision = f.read() model_dir = os.path.join(cached_model_dir, "snapshots", _revision) @@ -458,28 +479,36 @@ def __init__( ): super().__init__(models=models, config=config) if not hasattr(self, "encoder"): - raise AttributeError("Expected attribute 'encoder' not found in the instance.") + raise AttributeError( + "Expected attribute 'encoder' not found in the instance." + ) if not hasattr(self, "text_decoder"): - raise AttributeError("Expected attribute 'decoder' not found in the instance.") - metadata = self.decoder.method_names() + raise AttributeError( + "Expected attribute 'text_decoder' not found in the instance." 
+ ) + metadata = self.text_decoder.method_names() if "use_kv_cache" in metadata: - self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0] + self.use_kv_cache = self.text_decoder.run_method("use_kv_cache")[0] if "get_max_seq_len" in metadata: - self.max_cache_size = self.decoder.run_method("get_max_seq_len")[0] + self.max_cache_size = self.text_decoder.run_method("get_max_seq_len")[0] if "get_max_batch_size" in metadata: - self.max_batch_size = self.decoder.run_method("get_max_batch_size")[0] + self.max_batch_size = self.text_decoder.run_method("get_max_batch_size")[0] if "get_dtype" in metadata: - self.dtype = self.decoder.run_method("get_dtype")[0] + self.dtype = self.text_decoder.run_method("get_dtype")[0] if "get_bos_id" in metadata: - self.bos_token_id = self.decoder.run_method("get_bos_id")[0] + self.bos_token_id = self.text_decoder.run_method("get_bos_id")[0] if "get_eos_id" in metadata: - self.eos_token_id = self.decoder.run_method("get_eos_id")[0] + self.eos_token_id = self.text_decoder.run_method("get_eos_id")[0] if "get_vocab_size" in metadata: - self.vocab_size = self.decoder.run_method("get_vocab_size")[0] + self.vocab_size = self.text_decoder.run_method("get_vocab_size")[0] if "max_hidden_seq_length" in metadata: - self.max_hidden_seq_length = self.decoder.run_method("max_hidden_seq_length")[0] + self.max_hidden_seq_length = self.text_decoder.run_method( + "max_hidden_seq_length" + )[0] if "decoder_start_token_id" in metadata: - self.decoder_start_token_id = self.decoder.run_method("decoder_start_token_id")[0] + self.decoder_start_token_id = self.text_decoder.run_method( + "decoder_start_token_id" + )[0] def forward( self, @@ -495,7 +524,12 @@ def forward( encoder_outputs = self.encoder.forward((input_ids,))[0] self.stats.on_prompt_eval_end() - result = (self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs) + result = ( + self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[ + 0 + ], + encoder_outputs, + ) self.stats.on_model_execution_end() return result @@ -539,8 +573,12 @@ def generate( max_seq_len = self.max_cache_size if not hasattr(self, "decoder_start_token_id"): - raise AttributeError("'decoder_start_token_id' is missing in the metadata of the PTE.") - decoder_input_ids = torch.tensor([[self.decoder_start_token_id]], dtype=torch.long) + raise AttributeError( + "'decoder_start_token_id' is missing in the metadata of the PTE." + ) + decoder_input_ids = torch.tensor( + [[self.decoder_start_token_id]], dtype=torch.long + ) encoder_input_ids = input_ids encoder_outputs = None generated_ids = [0] @@ -563,7 +601,9 @@ def generate( # Get next token next_token = torch.argmax(logits[:, -1, :], dim=-1).item() generated_ids.append(next_token) - self.stats.set_num_generated_tokens(len(generated_ids) - 1) # Don't count decoder_start_token + self.stats.set_num_generated_tokens( + len(generated_ids) - 1 + ) # Don't count decoder_start_token # Update input for next iteration decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long) @@ -658,7 +698,9 @@ def __init__( ): super().__init__(models, config) if not hasattr(self, "model"): - raise AttributeError("Expected attribute 'model' not found in the instance.") + raise AttributeError( + "Expected attribute 'model' not found in the instance." 
+ ) metadata = self.model.method_names() logging.debug(f"Load all static methods: {metadata}") if "use_kv_cache" in metadata: @@ -678,7 +720,9 @@ def __init__( if "get_vocab_size" in metadata: self.vocab_size = self.model.run_method("get_vocab_size")[0] if "use_sdpa_with_kv_cache" in metadata: - self.use_sdpa_with_kv_cache = self.model.run_method("use_sdpa_with_kv_cache")[0] + self.use_sdpa_with_kv_cache = self.model.run_method( + "use_sdpa_with_kv_cache" + )[0] def forward( self, @@ -700,8 +744,14 @@ def forward( try: logits = self.model.forward((input_ids, cache_position))[0] except Exception as e: - shapes = {name: val.shape for name, val in locals().items() if hasattr(val, "shape")} - print(f"Exception: {e}.\n{self.model.method_meta('forward')}\narg shapes: {shapes}") + shapes = { + name: val.shape + for name, val in locals().items() + if hasattr(val, "shape") + } + print( + f"Exception: {e}.\n{self.model.method_meta('forward')}\narg shapes: {shapes}" + ) raise self.stats.on_model_execution_end() @@ -752,8 +802,12 @@ def generate( # The model is exported with dynamic shapes. Can support parallel prefill. self.stats.on_sampling_begin() logits = self.forward( - input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device).unsqueeze(0), - cache_position=torch.arange(len(prompt_tokens), dtype=torch.long, device=self.device), + input_ids=torch.tensor( + prompt_tokens, dtype=torch.long, device=self.device + ).unsqueeze(0), + cache_position=torch.arange( + len(prompt_tokens), dtype=torch.long, device=self.device + ), ) self.stats.on_sampling_end() next_token = torch.argmax(logits, dim=-1)[0, -1].item() @@ -763,8 +817,12 @@ def generate( for i, prompt_token in enumerate(prompt_tokens): self.stats.on_sampling_begin() logits = self.forward( - input_ids=torch.tensor([prompt_token], dtype=torch.long, device=self.device).unsqueeze(0), - cache_position=torch.tensor([i], dtype=torch.long, device=self.device), + input_ids=torch.tensor( + [prompt_token], dtype=torch.long, device=self.device + ).unsqueeze(0), + cache_position=torch.tensor( + [i], dtype=torch.long, device=self.device + ), ) self.stats.on_sampling_end() next_token = torch.argmax(logits, dim=-1).item() @@ -776,7 +834,9 @@ def generate( while len(generated_tokens) < max_seq_len: self.stats.on_sampling_begin() logits = self.forward( - input_ids=torch.tensor([next_token], dtype=torch.long, device=self.device).unsqueeze(0), + input_ids=torch.tensor( + [next_token], dtype=torch.long, device=self.device + ).unsqueeze(0), cache_position=torch.tensor( [pos_base + len(generated_tokens) - 1], dtype=torch.long, @@ -823,11 +883,16 @@ def text_generation( self.tokenizer = tokenizer # Sanity check - if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.bos_token_id: + if ( + self.tokenizer.bos_token_id is not None + and self.tokenizer.bos_token_id != self.bos_token_id + ): raise ValueError( f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}." ) - if not verify_eos_tokens_in_pretrained_tokenizer(self.eos_token_ids, self.tokenizer): + if not verify_eos_tokens_in_pretrained_tokenizer( + self.eos_token_ids, self.tokenizer + ): raise ValueError( f"The tokenizer's eos_token_id does not match with the model's eos_token_ids={self.eos_token_ids}." 
) @@ -884,7 +949,9 @@ def __init__( ): super().__init__(models, config) if not hasattr(self, "model"): - raise AttributeError("Expected attribute 'model' not found in the instance.") + raise AttributeError( + "Expected attribute 'model' not found in the instance." + ) metadata = self.model.method_names() logging.debug(f"Load all static methods: {metadata}") if "get_max_seq_len" in metadata: @@ -960,7 +1027,9 @@ def __init__( ): super().__init__(models, config) if not hasattr(self, "model"): - raise AttributeError("Expected attribute 'model' not found in the instance.") + raise AttributeError( + "Expected attribute 'model' not found in the instance." + ) metadata = self.model.method_names() logging.debug(f"Load all static methods: {metadata}") @@ -1023,9 +1092,13 @@ def __init__( ): super().__init__(models=models, config=config) if not hasattr(self, "encoder"): - raise AttributeError("Expected attribute 'encoder' not found in the instance.") + raise AttributeError( + "Expected attribute 'encoder' not found in the instance." + ) if not hasattr(self, "text_decoder"): - raise AttributeError("Expected attribute 'decoder' not found in the instance.") + raise AttributeError( + "Expected attribute 'decoder' not found in the instance." + ) metadata = self.decoder.method_names() if "use_kv_cache" in metadata: self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0] @@ -1042,9 +1115,13 @@ def __init__( if "get_vocab_size" in metadata: self.vocab_size = self.decoder.run_method("get_vocab_size")[0] if "max_hidden_seq_length" in metadata: - self.max_hidden_seq_length = self.decoder.run_method("max_hidden_seq_length")[0] + self.max_hidden_seq_length = self.decoder.run_method( + "max_hidden_seq_length" + )[0] if "decoder_start_token_id" in metadata: - self.decoder_start_token_id = self.decoder.run_method("decoder_start_token_id")[0] + self.decoder_start_token_id = self.decoder.run_method( + "decoder_start_token_id" + )[0] def forward( self, @@ -1059,7 +1136,12 @@ def forward( encoder_outputs = self.encoder.forward((input_features,))[0] self.stats.on_prompt_eval_end() - result = (self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs) + result = ( + self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[ + 0 + ], + encoder_outputs, + ) self.stats.on_model_execution_end() return result @@ -1100,8 +1182,12 @@ def generate( max_seq_len = self.max_cache_size if not hasattr(self, "decoder_start_token_id"): - raise AttributeError("'decoder_start_token_id' is missing in the metadata of the PTE.") - decoder_input_ids = torch.tensor([[self.decoder_start_token_id]], dtype=torch.long) + raise AttributeError( + "'decoder_start_token_id' is missing in the metadata of the PTE." 
+ ) + decoder_input_ids = torch.tensor( + [[self.decoder_start_token_id]], dtype=torch.long + ) log_mel = input_features encoder_outputs = None generated_ids = [] @@ -1112,7 +1198,9 @@ def generate( # Run decoder for next token prediction cache_position = torch.tensor([i], dtype=torch.long) self.stats.on_sampling_begin() - logits, encoder_outputs = self.forward(log_mel, decoder_input_ids, cache_position, encoder_outputs) + logits, encoder_outputs = self.forward( + log_mel, decoder_input_ids, cache_position, encoder_outputs + ) self.stats.on_sampling_end() if not first_token_generated: self.stats.on_first_token() @@ -1120,7 +1208,9 @@ def generate( # Get next token next_token = torch.argmax(logits[:, -1, :], dim=-1).item() generated_ids.append(next_token) - self.stats.set_num_generated_tokens(len(generated_ids) - 1) # Don't count decoder_start_token + self.stats.set_num_generated_tokens( + len(generated_ids) - 1 + ) # Don't count decoder_start_token # Update input for next iteration decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long) @@ -1281,7 +1371,9 @@ def generate( self.stats.on_sampling_begin() logits = self.forward( input_ids=prompt_tokens, - cache_position=torch.arange(prompt_tokens.size(1), dtype=torch.long, device=self.device), + cache_position=torch.arange( + prompt_tokens.size(1), dtype=torch.long, device=self.device + ), multimodal_features=multimodal_features, ) self.stats.on_sampling_end() @@ -1296,7 +1388,9 @@ def generate( while len(generated_tokens) + prompt_tokens.size(1) < max_seq_len: self.stats.on_sampling_begin() logits = self.forward( - input_ids=torch.tensor([next_token], dtype=torch.long, device=self.device).unsqueeze(0), + input_ids=torch.tensor( + [next_token], dtype=torch.long, device=self.device + ).unsqueeze(0), cache_position=torch.tensor( [pos_base + len(generated_tokens) + prompt_tokens.size(1) - 1], dtype=torch.long, @@ -1345,11 +1439,16 @@ def text_generation( self.tokenizer = tokenizer # Sanity check - if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.bos_token_id: + if ( + self.tokenizer.bos_token_id is not None + and self.tokenizer.bos_token_id != self.bos_token_id + ): raise ValueError( f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}." 
) - if isinstance(self.tokenizer, PreTrainedTokenizer) and not verify_eos_tokens_in_pretrained_tokenizer( + if isinstance( + self.tokenizer, PreTrainedTokenizer + ) and not verify_eos_tokens_in_pretrained_tokenizer( self.eos_token_id, self.tokenizer ): raise ValueError( From 9918fa2c15cd9dc29b549ffd35314a01c603fd5f Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Wed, 24 Sep 2025 11:33:16 -0700 Subject: [PATCH 6/8] self.encoder/decoder -> self.model --- optimum/executorch/modeling.py | 222 ++++++++------------------ tests/models/test_modeling_whisper.py | 14 +- 2 files changed, 72 insertions(+), 164 deletions(-) diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py index 065fbbc..f7339ec 100644 --- a/optimum/executorch/modeling.py +++ b/optimum/executorch/modeling.py @@ -23,17 +23,10 @@ from typing import Dict, List, Optional, Union import torch - -from executorch.extension.pybindings.portable_lib import ( - _load_for_executorch, - ExecuTorchModule, -) -from executorch.kernels import quantized # noqa from huggingface_hub import hf_hub_download from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa from transformers import ( - add_start_docstrings, AutoModel, AutoModelForCausalLM, AutoModelForImageClassification, @@ -41,11 +34,18 @@ AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq, PreTrainedTokenizer, + add_start_docstrings, ) from transformers.configuration_utils import PretrainedConfig from transformers.processing_utils import ProcessorMixin from transformers.utils import is_offline_mode +from executorch.extension.pybindings.portable_lib import ( + ExecuTorchModule, + _load_for_executorch, +) +from executorch.kernels import quantized # noqa + from ..exporters import TasksManager from ..exporters.executorch import main_export from ..exporters.executorch.utils import ( @@ -129,9 +129,7 @@ def _cleanup_temp_resources(self): try: if hasattr(self._temp_dir, "cleanup"): # It's a TemporaryDirectory object - logging.info( - f"Cleaning up temporary directory: {self._temp_dir.name}" - ) + logging.info(f"Cleaning up temporary directory: {self._temp_dir.name}") self._temp_dir.cleanup() logging.info("Temporary directory cleanup completed") elif isinstance(self._temp_dir, (str, Path)): @@ -181,9 +179,7 @@ def _from_pretrained( _PTE_SUFFIX = ".pte" if file_name and not file_name.endswith(_PTE_SUFFIX): - raise ValueError( - f"Invalid file name: {file_name}. Expected a '{_PTE_SUFFIX}' file." - ) + raise ValueError(f"Invalid file name: {file_name}. 
Expected a '{_PTE_SUFFIX}' file.") default_file_name = file_name or "model.pte" @@ -197,13 +193,9 @@ def _from_pretrained( ) if len(pte_files) == 0: - raise FileNotFoundError( - f"Could not find any ExecuTorch model file in {model_id}" - ) + raise FileNotFoundError(f"Could not find any ExecuTorch model file in {model_id}") if len(pte_files) == 1 and file_name and file_name != pte_files[0].name: - raise FileNotFoundError( - f"Trying to load {file_name} but only found {pte_files[0].name}" - ) + raise FileNotFoundError(f"Trying to load {file_name} but only found {pte_files[0].name}") file_name = pte_files[0].name subfolder = pte_files[0].parent @@ -287,11 +279,7 @@ def _export( **kwargs, ) -> Dict[str, "ExecuTorchModule"]: task = kwargs.pop("task", None) - inferred_task = ( - TasksManager.infer_task_from_model(cls.auto_model_class) - if not task - else task - ) + inferred_task = TasksManager.infer_task_from_model(cls.auto_model_class) if not task else task logging.info(f"Inferred task from model class: {inferred_task}") save_dir = TemporaryDirectory(prefix="executorch_export_") @@ -316,11 +304,7 @@ def _export( models = {} for name, _ in executorch_progs.items(): - models.update( - cls._from_pretrained( - save_dir_path, file_name=f"{name}.pte", config=config - ) - ) + models.update(cls._from_pretrained(save_dir_path, file_name=f"{name}.pte", config=config)) return models, save_dir @@ -360,9 +344,7 @@ def from_pretrained( if local_files_only and not os.path.isdir(model_id): object_id = model_id.replace("/", "--") cached_model_dir = os.path.join(cache_dir, f"models--{object_id}") - refs_file = os.path.join( - os.path.join(cached_model_dir, "refs"), revision or "main" - ) + refs_file = os.path.join(os.path.join(cached_model_dir, "refs"), revision or "main") with open(refs_file) as f: _revision = f.read() model_dir = os.path.join(cached_model_dir, "snapshots", _revision) @@ -479,36 +461,28 @@ def __init__( ): super().__init__(models=models, config=config) if not hasattr(self, "encoder"): - raise AttributeError( - "Expected attribute 'encoder' not found in the instance." - ) + raise AttributeError("Expected attribute 'encoder' not found in the instance.") if not hasattr(self, "text_decoder"): - raise AttributeError( - "Expected attribute 'text_decoder' not found in the instance." 
- ) - metadata = self.text_decoder.method_names() + raise AttributeError("Expected attribute 'text_decoder' not found in the instance.") + metadata = self.decoder.method_names() if "use_kv_cache" in metadata: - self.use_kv_cache = self.text_decoder.run_method("use_kv_cache")[0] + self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0] if "get_max_seq_len" in metadata: - self.max_cache_size = self.text_decoder.run_method("get_max_seq_len")[0] + self.max_cache_size = self.decoder.run_method("get_max_seq_len")[0] if "get_max_batch_size" in metadata: - self.max_batch_size = self.text_decoder.run_method("get_max_batch_size")[0] + self.max_batch_size = self.decoder.run_method("get_max_batch_size")[0] if "get_dtype" in metadata: - self.dtype = self.text_decoder.run_method("get_dtype")[0] + self.dtype = self.decoder.run_method("get_dtype")[0] if "get_bos_id" in metadata: - self.bos_token_id = self.text_decoder.run_method("get_bos_id")[0] + self.bos_token_id = self.decoder.run_method("get_bos_id")[0] if "get_eos_id" in metadata: - self.eos_token_id = self.text_decoder.run_method("get_eos_id")[0] + self.eos_token_id = self.decoder.run_method("get_eos_id")[0] if "get_vocab_size" in metadata: - self.vocab_size = self.text_decoder.run_method("get_vocab_size")[0] + self.vocab_size = self.decoder.run_method("get_vocab_size")[0] if "max_hidden_seq_length" in metadata: - self.max_hidden_seq_length = self.text_decoder.run_method( - "max_hidden_seq_length" - )[0] + self.max_hidden_seq_length = self.decoder.run_method("max_hidden_seq_length")[0] if "decoder_start_token_id" in metadata: - self.decoder_start_token_id = self.text_decoder.run_method( - "decoder_start_token_id" - )[0] + self.decoder_start_token_id = self.decoder.run_method("decoder_start_token_id")[0] def forward( self, @@ -525,9 +499,7 @@ def forward( self.stats.on_prompt_eval_end() result = ( - self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[ - 0 - ], + self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs, ) self.stats.on_model_execution_end() @@ -573,12 +545,8 @@ def generate( max_seq_len = self.max_cache_size if not hasattr(self, "decoder_start_token_id"): - raise AttributeError( - "'decoder_start_token_id' is missing in the metadata of the PTE." - ) - decoder_input_ids = torch.tensor( - [[self.decoder_start_token_id]], dtype=torch.long - ) + raise AttributeError("'decoder_start_token_id' is missing in the metadata of the PTE.") + decoder_input_ids = torch.tensor([[self.decoder_start_token_id]], dtype=torch.long) encoder_input_ids = input_ids encoder_outputs = None generated_ids = [0] @@ -601,9 +569,7 @@ def generate( # Get next token next_token = torch.argmax(logits[:, -1, :], dim=-1).item() generated_ids.append(next_token) - self.stats.set_num_generated_tokens( - len(generated_ids) - 1 - ) # Don't count decoder_start_token + self.stats.set_num_generated_tokens(len(generated_ids) - 1) # Don't count decoder_start_token # Update input for next iteration decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long) @@ -698,9 +664,7 @@ def __init__( ): super().__init__(models, config) if not hasattr(self, "model"): - raise AttributeError( - "Expected attribute 'model' not found in the instance." 
- ) + raise AttributeError("Expected attribute 'model' not found in the instance.") metadata = self.model.method_names() logging.debug(f"Load all static methods: {metadata}") if "use_kv_cache" in metadata: @@ -720,9 +684,7 @@ def __init__( if "get_vocab_size" in metadata: self.vocab_size = self.model.run_method("get_vocab_size")[0] if "use_sdpa_with_kv_cache" in metadata: - self.use_sdpa_with_kv_cache = self.model.run_method( - "use_sdpa_with_kv_cache" - )[0] + self.use_sdpa_with_kv_cache = self.model.run_method("use_sdpa_with_kv_cache")[0] def forward( self, @@ -744,14 +706,8 @@ def forward( try: logits = self.model.forward((input_ids, cache_position))[0] except Exception as e: - shapes = { - name: val.shape - for name, val in locals().items() - if hasattr(val, "shape") - } - print( - f"Exception: {e}.\n{self.model.method_meta('forward')}\narg shapes: {shapes}" - ) + shapes = {name: val.shape for name, val in locals().items() if hasattr(val, "shape")} + print(f"Exception: {e}.\n{self.model.method_meta('forward')}\narg shapes: {shapes}") raise self.stats.on_model_execution_end() @@ -802,12 +758,8 @@ def generate( # The model is exported with dynamic shapes. Can support parallel prefill. self.stats.on_sampling_begin() logits = self.forward( - input_ids=torch.tensor( - prompt_tokens, dtype=torch.long, device=self.device - ).unsqueeze(0), - cache_position=torch.arange( - len(prompt_tokens), dtype=torch.long, device=self.device - ), + input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device).unsqueeze(0), + cache_position=torch.arange(len(prompt_tokens), dtype=torch.long, device=self.device), ) self.stats.on_sampling_end() next_token = torch.argmax(logits, dim=-1)[0, -1].item() @@ -817,12 +769,8 @@ def generate( for i, prompt_token in enumerate(prompt_tokens): self.stats.on_sampling_begin() logits = self.forward( - input_ids=torch.tensor( - [prompt_token], dtype=torch.long, device=self.device - ).unsqueeze(0), - cache_position=torch.tensor( - [i], dtype=torch.long, device=self.device - ), + input_ids=torch.tensor([prompt_token], dtype=torch.long, device=self.device).unsqueeze(0), + cache_position=torch.tensor([i], dtype=torch.long, device=self.device), ) self.stats.on_sampling_end() next_token = torch.argmax(logits, dim=-1).item() @@ -834,9 +782,7 @@ def generate( while len(generated_tokens) < max_seq_len: self.stats.on_sampling_begin() logits = self.forward( - input_ids=torch.tensor( - [next_token], dtype=torch.long, device=self.device - ).unsqueeze(0), + input_ids=torch.tensor([next_token], dtype=torch.long, device=self.device).unsqueeze(0), cache_position=torch.tensor( [pos_base + len(generated_tokens) - 1], dtype=torch.long, @@ -883,16 +829,11 @@ def text_generation( self.tokenizer = tokenizer # Sanity check - if ( - self.tokenizer.bos_token_id is not None - and self.tokenizer.bos_token_id != self.bos_token_id - ): + if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.bos_token_id: raise ValueError( f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}." ) - if not verify_eos_tokens_in_pretrained_tokenizer( - self.eos_token_ids, self.tokenizer - ): + if not verify_eos_tokens_in_pretrained_tokenizer(self.eos_token_ids, self.tokenizer): raise ValueError( f"The tokenizer's eos_token_id does not match with the model's eos_token_ids={self.eos_token_ids}." 
) @@ -949,9 +890,7 @@ def __init__( ): super().__init__(models, config) if not hasattr(self, "model"): - raise AttributeError( - "Expected attribute 'model' not found in the instance." - ) + raise AttributeError("Expected attribute 'model' not found in the instance.") metadata = self.model.method_names() logging.debug(f"Load all static methods: {metadata}") if "get_max_seq_len" in metadata: @@ -1027,9 +966,7 @@ def __init__( ): super().__init__(models, config) if not hasattr(self, "model"): - raise AttributeError( - "Expected attribute 'model' not found in the instance." - ) + raise AttributeError("Expected attribute 'model' not found in the instance.") metadata = self.model.method_names() logging.debug(f"Load all static methods: {metadata}") @@ -1091,37 +1028,27 @@ def __init__( config: "PretrainedConfig", ): super().__init__(models=models, config=config) - if not hasattr(self, "encoder"): - raise AttributeError( - "Expected attribute 'encoder' not found in the instance." - ) - if not hasattr(self, "text_decoder"): - raise AttributeError( - "Expected attribute 'decoder' not found in the instance." - ) - metadata = self.decoder.method_names() + if not hasattr(self, "model"): + raise AttributeError("Expected attribute 'model' not found in the instance.") + metadata = self.model.method_names() if "use_kv_cache" in metadata: - self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0] + self.use_kv_cache = self.model.run_method("use_kv_cache")[0] if "get_max_seq_len" in metadata: - self.max_cache_size = self.decoder.run_method("get_max_seq_len")[0] + self.max_cache_size = self.model.run_method("get_max_seq_len")[0] if "get_max_batch_size" in metadata: - self.max_batch_size = self.decoder.run_method("get_max_batch_size")[0] + self.max_batch_size = self.model.run_method("get_max_batch_size")[0] if "get_dtype" in metadata: - self.dtype = self.decoder.run_method("get_dtype")[0] + self.dtype = self.model.run_method("get_dtype")[0] if "get_bos_id" in metadata: - self.bos_token_id = self.decoder.run_method("get_bos_id")[0] + self.bos_token_id = self.model.run_method("get_bos_id")[0] if "get_eos_id" in metadata: - self.eos_token_id = self.decoder.run_method("get_eos_id")[0] + self.eos_token_id = self.model.run_method("get_eos_id")[0] if "get_vocab_size" in metadata: - self.vocab_size = self.decoder.run_method("get_vocab_size")[0] + self.vocab_size = self.model.run_method("get_vocab_size")[0] if "max_hidden_seq_length" in metadata: - self.max_hidden_seq_length = self.decoder.run_method( - "max_hidden_seq_length" - )[0] + self.max_hidden_seq_length = self.model.run_method("max_hidden_seq_length")[0] if "decoder_start_token_id" in metadata: - self.decoder_start_token_id = self.decoder.run_method( - "decoder_start_token_id" - )[0] + self.decoder_start_token_id = self.model.run_method("decoder_start_token_id")[0] def forward( self, @@ -1133,13 +1060,11 @@ def forward( is_first_prediction = encoder_outputs is None self.stats.on_model_execution_start() if is_first_prediction: - encoder_outputs = self.encoder.forward((input_features,))[0] + encoder_outputs = self.model.run_method("encoder", (input_features,))[0] self.stats.on_prompt_eval_end() result = ( - self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[ - 0 - ], + self.model.run_method("text_decoder", (decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs, ) self.stats.on_model_execution_end() @@ -1182,12 +1107,8 @@ def generate( max_seq_len = self.max_cache_size if not hasattr(self, 
"decoder_start_token_id"): - raise AttributeError( - "'decoder_start_token_id' is missing in the metadata of the PTE." - ) - decoder_input_ids = torch.tensor( - [[self.decoder_start_token_id]], dtype=torch.long - ) + raise AttributeError("'decoder_start_token_id' is missing in the metadata of the PTE.") + decoder_input_ids = torch.tensor([[self.decoder_start_token_id]], dtype=torch.long) log_mel = input_features encoder_outputs = None generated_ids = [] @@ -1198,9 +1119,7 @@ def generate( # Run decoder for next token prediction cache_position = torch.tensor([i], dtype=torch.long) self.stats.on_sampling_begin() - logits, encoder_outputs = self.forward( - log_mel, decoder_input_ids, cache_position, encoder_outputs - ) + logits, encoder_outputs = self.forward(log_mel, decoder_input_ids, cache_position, encoder_outputs) self.stats.on_sampling_end() if not first_token_generated: self.stats.on_first_token() @@ -1208,9 +1127,7 @@ def generate( # Get next token next_token = torch.argmax(logits[:, -1, :], dim=-1).item() generated_ids.append(next_token) - self.stats.set_num_generated_tokens( - len(generated_ids) - 1 - ) # Don't count decoder_start_token + self.stats.set_num_generated_tokens(len(generated_ids) - 1) # Don't count decoder_start_token # Update input for next iteration decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long) @@ -1371,9 +1288,7 @@ def generate( self.stats.on_sampling_begin() logits = self.forward( input_ids=prompt_tokens, - cache_position=torch.arange( - prompt_tokens.size(1), dtype=torch.long, device=self.device - ), + cache_position=torch.arange(prompt_tokens.size(1), dtype=torch.long, device=self.device), multimodal_features=multimodal_features, ) self.stats.on_sampling_end() @@ -1388,9 +1303,7 @@ def generate( while len(generated_tokens) + prompt_tokens.size(1) < max_seq_len: self.stats.on_sampling_begin() logits = self.forward( - input_ids=torch.tensor( - [next_token], dtype=torch.long, device=self.device - ).unsqueeze(0), + input_ids=torch.tensor([next_token], dtype=torch.long, device=self.device).unsqueeze(0), cache_position=torch.tensor( [pos_base + len(generated_tokens) + prompt_tokens.size(1) - 1], dtype=torch.long, @@ -1439,16 +1352,11 @@ def text_generation( self.tokenizer = tokenizer # Sanity check - if ( - self.tokenizer.bos_token_id is not None - and self.tokenizer.bos_token_id != self.bos_token_id - ): + if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.bos_token_id: raise ValueError( f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}." 
) - if isinstance( - self.tokenizer, PreTrainedTokenizer - ) and not verify_eos_tokens_in_pretrained_tokenizer( + if isinstance(self.tokenizer, PreTrainedTokenizer) and not verify_eos_tokens_in_pretrained_tokenizer( self.eos_token_id, self.tokenizer ): raise ValueError( diff --git a/tests/models/test_modeling_whisper.py b/tests/models/test_modeling_whisper.py index f139815..a784a2c 100644 --- a/tests/models/test_modeling_whisper.py +++ b/tests/models/test_modeling_whisper.py @@ -49,8 +49,7 @@ def test_whisper_export_to_executorch(self): shell=True, check=True, ) - self.assertTrue(os.path.exists(f"{tempdir}/executorch/encoder.pte")) - self.assertTrue(os.path.exists(f"{tempdir}/executorch/decoder.pte")) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte")) model = ExecuTorchModelForSpeechSeq2Seq.from_pretrained(f"{tempdir}/executorch") self._test_whisper_transcription(model_id, model) @@ -59,16 +58,17 @@ def _test_whisper_transcription(self, model_id: str, model: ExecuTorchModelForSp processor = AutoProcessor.from_pretrained(model_id) self.assertIsInstance(model, ExecuTorchModelForSpeechSeq2Seq) - self.assertTrue(hasattr(model, "encoder")) - self.assertIsInstance(model.encoder, ExecuTorchModule) - self.assertTrue(hasattr(model, "text_decoder")) - self.assertIsInstance(model.decoder, ExecuTorchModule) + self.assertTrue(hasattr(model, "model")) + self.assertIsInstance(model.model, ExecuTorchModule) dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation") sample = dataset[0]["audio"] input_features = processor( - sample["array"], return_tensors="pt", truncation=False, sampling_rate=sample["sampling_rate"] + sample["array"], + return_tensors="pt", + truncation=False, + sampling_rate=sample["sampling_rate"], ).input_features # Current implementation of the transcibe method accepts up to 30 seconds of audio, therefore I trim the audio here. 
input_features_trimmed = input_features[:, :, :3000].contiguous() From 33ac6d62e6b612bcd00b4f3e5552c9b794eb62dc Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Fri, 17 Oct 2025 10:26:35 -0700 Subject: [PATCH 7/8] Fix numerical discrepancy --- optimum/commands/export/executorch.py | 4 +- optimum/exporters/executorch/integrations.py | 79 ++++++++++---------- optimum/exporters/executorch/tasks/asr.py | 8 +- 3 files changed, 45 insertions(+), 46 deletions(-) diff --git a/optimum/commands/export/executorch.py b/optimum/commands/export/executorch.py index 12398da..e877114 100644 --- a/optimum/commands/export/executorch.py +++ b/optimum/commands/export/executorch.py @@ -185,9 +185,7 @@ def run(self): "--qlinear_packing_format can only be used when --device is set to CUDA (e.g., 'cuda', 'cuda:0', etc.)" ) if not self.args.qlinear or self.args.qlinear != "4w": - raise ValueError( - "--qlinear_packing_format can only be used when --qlinear is set to '4w'" - ) + raise ValueError("--qlinear_packing_format can only be used when --qlinear is set to '4w'") qlinear_encoder_packing_format = getattr(self.args, "qlinear_encoder_packing_format", None) if qlinear_encoder_packing_format: if not device or not device.startswith("cuda"): diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py index 269512a..b54ea66 100644 --- a/optimum/exporters/executorch/integrations.py +++ b/optimum/exporters/executorch/integrations.py @@ -29,7 +29,6 @@ T5ForConditionalGeneration, WhisperForConditionalGeneration, ) -from transformers.generation.configuration_utils import GenerationConfig from transformers.integrations.executorch import ( TorchExportableModuleForDecoderOnlyLM, sdpa_mask_without_vmap, @@ -664,15 +663,31 @@ def __init__(self, model, max_static_cache_length, batch_size): self.proj_out = model.lm_head self.config = model.config - # Initialize static cache - self.static_cache = StaticCache( + # Initialize self attention cache + self.self_attention_cache = StaticCache( config=self.config, max_batch_size=batch_size, max_cache_len=max_static_cache_length, - device="cpu", + device=model.device, dtype=torch.float32, ) - self.cache = EncoderDecoderCache(self.static_cache, DynamicCache()) + head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads) + num_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads) + self.self_attention_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, model.device) + + # Initialize cross attention cache + self.dynamic_cache = DynamicCache(config=self.config) + self.cache = EncoderDecoderCache(self.self_attention_cache, self.dynamic_cache) + + # Register cache buffers to make them exportable. + # Cross attention cache buffer is not registered since it's not actually being used atm. 
+ for i in range(len(self.self_attention_cache)): + self.register_buffer( + f"self_attention_key_cache_{i}", self.self_attention_cache.layers[i].keys, persistent=False + ) + self.register_buffer( + f"self_attention_value_cache_{i}", self.self_attention_cache.layers[i].values, persistent=False + ) def forward(self, decoder_input_ids, encoder_hidden_states, cache_position): # Get outputs from decoder @@ -694,26 +709,18 @@ def __init__( self, model: PreTrainedModel, batch_size=1, - max_hidden_seq_length=4096, - cache_implementation="static", - max_cache_length=1024, + max_seq_len=1024, + max_hidden_seq_len=4096, ): super().__init__() - self.full_model = model + self.model = model self.encoder = model.get_encoder() self.config = model.config - self.max_hidden_seq_length = max_hidden_seq_length - self.generation_config = GenerationConfig( - use_cache=True, - max_length=max_cache_length, - cache_implementation=cache_implementation, - cache_config={ - "batch_size": batch_size, - "max_cache_len": max_cache_length, - }, - ) - if isinstance(self.full_model, WhisperForConditionalGeneration): + self.max_hidden_seq_len = max_hidden_seq_len + self.batch_size = batch_size + self.max_seq_len = max_seq_len + if isinstance(self.model, WhisperForConditionalGeneration): self._processor = AutoProcessor.from_pretrained(model.config._name_or_path) self._expected_encoder_input_shape = torch.Size( ( @@ -722,14 +729,8 @@ def __init__( self._processor.feature_extractor.nb_max_frames, ) ) - additional_configs = {} - additional_configs["max_hidden_seq_length"] = max_hidden_seq_length # Metadata to be recorded in the pte model file - self.metadata = save_config_to_constant_methods( - self.config, - self.generation_config, - **additional_configs, - ) + self.metadata = save_config_to_constant_methods(self.config, get_max_seq_len=max_seq_len) self.exported_encoder = None self.exported_decoder = None @@ -737,18 +738,18 @@ def _export_encoder(self, encoder_input_ids): wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval() # Define dynamic sequence length for encoder - if isinstance(self.full_model, WhisperForConditionalGeneration): + if isinstance(self.model, WhisperForConditionalGeneration): assert ( encoder_input_ids.shape == self._expected_encoder_input_shape ), f"""This version of Whisper only accepts encoder input of shape {self._expected_encoder_input_shape}, passed shape: {encoder_input_ids.shape}. For more infromation, please refer to the Whisper preprocessor config.""" dynamic_shapes = None - elif isinstance(self.full_model, T5ForConditionalGeneration): - encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length) + elif isinstance(self.model, T5ForConditionalGeneration): + encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_len) dynamic_shapes = {"input_ids": {1: encoder_seq_len_dim}} else: raise ValueError( - f"Unsupported model type {type(self.full_model)} for Seq2SeqLMExportableModule encoder export." + f"Unsupported model type {type(self.model)} for Seq2SeqLMExportableModule encoder export." 
) # Export the encoder @@ -764,19 +765,19 @@ def _export_encoder(self, encoder_input_ids): def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position): wrapped_decoder = ( Seq2SeqLMDecoderExportableModuleWithStaticCache( - model=self.full_model, - max_static_cache_length=self.generation_config.cache_config.get("max_cache_len"), - batch_size=self.generation_config.cache_config.get("batch_size"), + model=self.model, + max_static_cache_length=self.max_seq_len, + batch_size=self.batch_size, ) .to("cpu") .eval() ) - if isinstance(self.full_model, WhisperForConditionalGeneration): + if isinstance(self.model, WhisperForConditionalGeneration): dynamic_shapes = None - elif isinstance(self.full_model, T5ForConditionalGeneration): + elif isinstance(self.model, T5ForConditionalGeneration): # Define dynamic dimension for encoder output sequence length - encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length) + encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_len) dynamic_shapes = { "decoder_input_ids": None, "encoder_hidden_states": {1: encoder_seq_len_dim}, @@ -784,7 +785,7 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi } else: raise ValueError( - f"Unsupported model type {type(self.full_model)} for Seq2SeqLMExportableModule decoder export." + f"Unsupported model type {type(self.model)} for Seq2SeqLMExportableModule decoder export." ) # Export the decoder @@ -806,7 +807,7 @@ def export( cache_position=None, ) -> Dict[str, ExportedProgram]: if encoder_input_ids is None: - if isinstance(self.full_model, WhisperForConditionalGeneration): + if isinstance(self.model, WhisperForConditionalGeneration): example_encoder_input_ids = torch.rand(self._expected_encoder_input_shape) else: example_encoder_input_ids = torch.ones((1, 10), dtype=torch.long) diff --git a/optimum/exporters/executorch/tasks/asr.py b/optimum/exporters/executorch/tasks/asr.py index bc20bdc..ccf1a7a 100644 --- a/optimum/exporters/executorch/tasks/asr.py +++ b/optimum/exporters/executorch/tasks/asr.py @@ -46,13 +46,13 @@ def load_seq2seq_speech_model(model_name_or_path: str, **kwargs) -> Seq2SeqLMExp """ device = "cpu" batch_size = 1 - max_hidden_seq_length = kwargs.get("max_hidden_seq_length", 4096) - max_cache_length = kwargs.get("max_cache_length", 1024) + max_hidden_seq_len = kwargs.get("max_hidden_seq_len", 4096) + max_seq_len = kwargs.get("max_seq_len", 1024) full_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name_or_path).to(device).eval() return Seq2SeqLMExportableModule( full_model, batch_size=batch_size, - max_hidden_seq_length=max_hidden_seq_length, - max_cache_length=max_cache_length, + max_seq_len=max_seq_len, + max_hidden_seq_len=max_hidden_seq_len, ) From 87b83b12f54733e0eef6d5c482dc59e7e1841b69 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Fri, 17 Oct 2025 10:35:45 -0700 Subject: [PATCH 8/8] Make portable recipe export into one pte with multiple methods --- .../exporters/executorch/recipes/portable.py | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/optimum/exporters/executorch/recipes/portable.py b/optimum/exporters/executorch/recipes/portable.py index f6faebb..4d4fa97 100644 --- a/optimum/exporters/executorch/recipes/portable.py +++ b/optimum/exporters/executorch/recipes/portable.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations 
under the License. -import logging from typing import Dict, Union from torch.export import ExportedProgram @@ -58,24 +57,22 @@ def _lower_to_executorch( exported_programs: Dict[str, ExportedProgram], metadata=None, ) -> Dict[str, ExecutorchProgram]: - et_progs = {} + # If just one exported program, the method name in the .pte for it should be "forward". + if len(exported_programs) == 1: + exported_programs = {"forward": next(iter(exported_programs.values()))} - for pte_name, exported_program in exported_programs.items(): - logging.debug(f"\nExported program for {pte_name}.pte: {exported_program}") - et_progs[pte_name] = to_edge_transform_and_lower( - exported_program, - partitioner=[], - compile_config=EdgeCompileConfig( - _check_ir_validity=False, - _skip_dim_order=True, - ), - constant_methods=metadata, - transform_passes=[RemovePaddingIdxEmbeddingPass()], - ).to_executorch() - logging.debug( - f"\nExecuTorch program for {pte_name}.pte: {et_progs[pte_name].exported_program().graph_module}" - ) - return et_progs + et_prog = to_edge_transform_and_lower( + exported_programs, + partitioner=[], + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + constant_methods=metadata, + transform_passes=[RemovePaddingIdxEmbeddingPass()], + ).to_executorch() + pte_name = "model" + return {pte_name: et_prog} exported_progs = model.export()
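
Note (not part of the patches above): for readers following the final recipe change, the sketch below shows one way to exercise the resulting multi-method model.pte directly from Python. It is illustrative only — the output path, the 80x3000 log-mel input shape, and the availability of the "decoder_start_token_id", "get_eos_id", and "get_max_seq_len" constant methods are assumptions, and it loads the file through the ExecuTorch pybindings rather than the ExecuTorchModelForSpeechSeq2Seq wrapper that the patches themselves update. The greedy loop mirrors the forward/generate code shown earlier in this series.

# Minimal sketch, under the assumptions stated above.
import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

pte = _load_for_executorch("executorch/model.pte")  # hypothetical output path
print(pte.method_names())  # expected to include "encoder" and "text_decoder"

# Placeholder log-mel features; an 80-mel-bin, 3000-frame Whisper-style
# feature extractor is assumed here.
log_mel = torch.rand(1, 80, 3000)

# Run the encoder once, then feed its output to every decoder step.
encoder_outputs = pte.run_method("encoder", (log_mel,))[0]

decoder_start = int(pte.run_method("decoder_start_token_id")[0])
eos_id = int(pte.run_method("get_eos_id")[0])
max_seq_len = int(pte.run_method("get_max_seq_len")[0])

token, generated = decoder_start, []
for pos in range(max_seq_len):
    decoder_input_ids = torch.tensor([[token]], dtype=torch.long)
    cache_position = torch.tensor([pos], dtype=torch.long)
    logits = pte.run_method(
        "text_decoder", (decoder_input_ids, encoder_outputs, cache_position)
    )[0]
    token = int(torch.argmax(logits[:, -1, :], dim=-1))
    if token == eos_id:
        break
    generated.append(token)

print(generated)  # token ids; decode with the matching Whisper processor/tokenizer

In practice this flow is wrapped by ExecuTorchModelForSpeechSeq2Seq.generate(), which additionally records inference stats and excludes the decoder start token from the generated-token count.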