diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py
index 93a0034..f7339ec 100644
--- a/optimum/executorch/modeling.py
+++ b/optimum/executorch/modeling.py
@@ -40,7 +40,10 @@ from transformers.processing_utils import ProcessorMixin
 from transformers.utils import is_offline_mode

-from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch
+from executorch.extension.pybindings.portable_lib import (
+    ExecuTorchModule,
+    _load_for_executorch,
+)
 from executorch.kernels import quantized  # noqa

 from ..exporters import TasksManager
@@ -460,7 +463,7 @@ def __init__(
         if not hasattr(self, "encoder"):
             raise AttributeError("Expected attribute 'encoder' not found in the instance.")
         if not hasattr(self, "text_decoder"):
-            raise AttributeError("Expected attribute 'decoder' not found in the instance.")
+            raise AttributeError("Expected attribute 'text_decoder' not found in the instance.")
         metadata = self.decoder.method_names()
         if "use_kv_cache" in metadata:
             self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0]
@@ -495,7 +498,10 @@ def forward(
             encoder_outputs = self.encoder.forward((input_ids,))[0]
             self.stats.on_prompt_eval_end()

-        result = (self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs)
+        result = (
+            self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0],
+            encoder_outputs,
+        )
         self.stats.on_model_execution_end()
         return result
@@ -1022,29 +1028,27 @@ def __init__(
         config: "PretrainedConfig",
     ):
         super().__init__(models=models, config=config)
-        if not hasattr(self, "encoder"):
-            raise AttributeError("Expected attribute 'encoder' not found in the instance.")
-        if not hasattr(self, "text_decoder"):
-            raise AttributeError("Expected attribute 'decoder' not found in the instance.")
-        metadata = self.decoder.method_names()
+        if not hasattr(self, "model"):
+            raise AttributeError("Expected attribute 'model' not found in the instance.")
+        metadata = self.model.method_names()
         if "use_kv_cache" in metadata:
-            self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0]
+            self.use_kv_cache = self.model.run_method("use_kv_cache")[0]
         if "get_max_seq_len" in metadata:
-            self.max_cache_size = self.decoder.run_method("get_max_seq_len")[0]
+            self.max_cache_size = self.model.run_method("get_max_seq_len")[0]
         if "get_max_batch_size" in metadata:
-            self.max_batch_size = self.decoder.run_method("get_max_batch_size")[0]
+            self.max_batch_size = self.model.run_method("get_max_batch_size")[0]
         if "get_dtype" in metadata:
-            self.dtype = self.decoder.run_method("get_dtype")[0]
+            self.dtype = self.model.run_method("get_dtype")[0]
         if "get_bos_id" in metadata:
-            self.bos_token_id = self.decoder.run_method("get_bos_id")[0]
+            self.bos_token_id = self.model.run_method("get_bos_id")[0]
         if "get_eos_id" in metadata:
-            self.eos_token_id = self.decoder.run_method("get_eos_id")[0]
+            self.eos_token_id = self.model.run_method("get_eos_id")[0]
         if "get_vocab_size" in metadata:
-            self.vocab_size = self.decoder.run_method("get_vocab_size")[0]
+            self.vocab_size = self.model.run_method("get_vocab_size")[0]
         if "max_hidden_seq_length" in metadata:
-            self.max_hidden_seq_length = self.decoder.run_method("max_hidden_seq_length")[0]
+            self.max_hidden_seq_length = self.model.run_method("max_hidden_seq_length")[0]
         if "decoder_start_token_id" in metadata:
-            self.decoder_start_token_id = self.decoder.run_method("decoder_start_token_id")[0]
+            self.decoder_start_token_id = self.model.run_method("decoder_start_token_id")[0]
     def forward(
         self,
@@ -1056,10 +1060,13 @@ def forward(
         is_first_prediction = encoder_outputs is None
         self.stats.on_model_execution_start()
         if is_first_prediction:
-            encoder_outputs = self.encoder.forward((input_features,))[0]
+            encoder_outputs = self.model.run_method("encoder", (input_features,))[0]
             self.stats.on_prompt_eval_end()

-        result = (self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs)
+        result = (
+            self.model.run_method("text_decoder", (decoder_input_ids, encoder_outputs, cache_position))[0],
+            encoder_outputs,
+        )
         self.stats.on_model_execution_end()
         return result
diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py
index eec3d33..ebe396b 100644
--- a/optimum/exporters/executorch/integrations.py
+++ b/optimum/exporters/executorch/integrations.py
@@ -22,13 +22,18 @@ from transformers import (
     AutoConfig,
     AutoProcessor,
+    DynamicCache,
+    EncoderDecoderCache,
     PreTrainedModel,
     StaticCache,
     T5ForConditionalGeneration,
     WhisperForConditionalGeneration,
 )
 from transformers.generation.configuration_utils import GenerationConfig
-from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM, sdpa_mask_without_vmap
+from transformers.integrations.executorch import (
+    TorchExportableModuleForDecoderOnlyLM,
+    sdpa_mask_without_vmap,
+)
 from transformers.masking_utils import AttentionMaskInterface
 from transformers.modeling_utils import AttentionInterface
@@ -50,7 +55,10 @@ def prepare_export_inputs(self):
             {
                 "role": "user",
                 "content": [
-                    {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
+                    {
+                        "type": "image",
+                        "url": "https://llava-vl.github.io/static/images/view.jpg",
+                    },
                 ],
             },
         ]
@@ -330,7 +338,10 @@ def export(
                 mutated_gm,
                 args=(),
                 # For the ET runner, it's important to have cache position as the 2nd arg.
-                kwargs={"inputs_embeds": inputs_embeds, "cache_position": cache_position},
+                kwargs={
+                    "inputs_embeds": inputs_embeds,
+                    "cache_position": cache_position,
+                },
                 dynamic_shapes=dynamic_shapes,
                 strict=True,
             )
@@ -390,7 +401,12 @@ class CausalLMExportableModule(torch.nn.Module):
     """

     def __init__(
-        self, model, max_seq_len=2048, use_custom_kv_cache=False, use_custom_sdpa=False, disable_dynamic_shapes=False
+        self,
+        model,
+        max_seq_len=2048,
+        use_custom_kv_cache=False,
+        use_custom_sdpa=False,
+        disable_dynamic_shapes=False,
     ):
         super().__init__()
         self.model = model
@@ -487,7 +503,10 @@ def export(
         with torch.no_grad():
             exported_program = exportable_module.export(
-                input_ids=input_ids, cache_position=cache_position, dynamic_shapes=dynamic_shapes, strict=strict
+                input_ids=input_ids,
+                cache_position=cache_position,
+                dynamic_shapes=dynamic_shapes,
+                strict=strict,
             )
         # Apply RemoveTransposes pass to remove
         # any back-to-back transpose ops that are not needed
@@ -643,18 +662,14 @@ def __init__(self, model, max_static_cache_length, batch_size):
             device="cpu",
             dtype=torch.float32,
         )
-
-        # Register cache buffers to make them exportable
-        for i in range(len(self.static_cache.key_cache)):
-            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
-            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
+        self.cache = EncoderDecoderCache(self.static_cache, DynamicCache())

     def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
         # Get outputs from decoder
         outputs = self.decoder(
             input_ids=decoder_input_ids,
             encoder_hidden_states=encoder_hidden_states,
-            past_key_values=self.static_cache,
+            past_key_values=self.cache,
             use_cache=True,
             cache_position=cache_position,
         )
diff --git a/tests/models/test_modeling_whisper.py b/tests/models/test_modeling_whisper.py
index f139815..a784a2c 100644
--- a/tests/models/test_modeling_whisper.py
+++ b/tests/models/test_modeling_whisper.py
@@ -49,8 +49,7 @@ def test_whisper_export_to_executorch(self):
                 shell=True,
                 check=True,
             )
-            self.assertTrue(os.path.exists(f"{tempdir}/executorch/encoder.pte"))
-            self.assertTrue(os.path.exists(f"{tempdir}/executorch/decoder.pte"))
+            self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte"))
             model = ExecuTorchModelForSpeechSeq2Seq.from_pretrained(f"{tempdir}/executorch")
             self._test_whisper_transcription(model_id, model)
@@ -59,16 +58,17 @@ def _test_whisper_transcription(self, model_id: str, model: ExecuTorchModelForSp
         processor = AutoProcessor.from_pretrained(model_id)

         self.assertIsInstance(model, ExecuTorchModelForSpeechSeq2Seq)
-        self.assertTrue(hasattr(model, "encoder"))
-        self.assertIsInstance(model.encoder, ExecuTorchModule)
-        self.assertTrue(hasattr(model, "text_decoder"))
-        self.assertIsInstance(model.decoder, ExecuTorchModule)
+        self.assertTrue(hasattr(model, "model"))
+        self.assertIsInstance(model.model, ExecuTorchModule)

         dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
         sample = dataset[0]["audio"]
         input_features = processor(
-            sample["array"], return_tensors="pt", truncation=False, sampling_rate=sample["sampling_rate"]
+            sample["array"],
+            return_tensors="pt",
+            truncation=False,
+            sampling_rate=sample["sampling_rate"],
         ).input_features
         # Current implementation of the transcibe method accepts up to 30 seconds of audio, therefore I trim the audio here.
         input_features_trimmed = input_features[:, :, :3000].contiguous()
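For reference, a minimal sketch of how the consolidated model.pte could be driven directly through the pybindings, mirroring the "encoder"/"text_decoder" calls and metadata methods used in ExecuTorchModelForSpeechSeq2Seq.forward() above. This is illustrative only and not part of the diff: the artifact path, the processor checkpoint, the 448-token cap, and the availability of the "decoder_start_token_id"/"get_eos_id" metadata methods are assumptions.

    import torch
    from datasets import load_dataset
    from transformers import AutoProcessor
    from executorch.extension.pybindings.portable_lib import _load_for_executorch

    # Placeholder path and checkpoint id; substitute the artifacts produced by the export CLI.
    program = _load_for_executorch("executorch/model.pte")
    processor = AutoProcessor.from_pretrained("openai/whisper-tiny")

    sample = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")[0]["audio"]
    input_features = processor(
        sample["array"], return_tensors="pt", truncation=False, sampling_rate=sample["sampling_rate"]
    ).input_features[:, :, :3000].contiguous()  # trim to ~30 s, as in the test above

    # Encode once, then greedily decode against the cached encoder output.
    encoder_outputs = program.run_method("encoder", (input_features,))[0]
    tokens = [int(program.run_method("decoder_start_token_id")[0])]
    eos_id = int(program.run_method("get_eos_id")[0])
    for step in range(448):  # illustrative cap on generated tokens
        decoder_input_ids = torch.tensor([[tokens[-1]]], dtype=torch.long)
        cache_position = torch.tensor([step], dtype=torch.long)
        logits = program.run_method("text_decoder", (decoder_input_ids, encoder_outputs, cache_position))[0]
        next_token = int(torch.argmax(logits[:, -1, :], dim=-1))
        tokens.append(next_token)
        if next_token == eos_id:
            break
    print(processor.batch_decode([tokens], skip_special_tokens=True)[0])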