diff --git a/optimum/commands/export/executorch.py b/optimum/commands/export/executorch.py
index 12398da..e877114 100644
--- a/optimum/commands/export/executorch.py
+++ b/optimum/commands/export/executorch.py
@@ -185,9 +185,7 @@ def run(self):
                     "--qlinear_packing_format can only be used when --device is set to CUDA (e.g., 'cuda', 'cuda:0', etc.)"
                 )
             if not self.args.qlinear or self.args.qlinear != "4w":
-                raise ValueError(
-                    "--qlinear_packing_format can only be used when --qlinear is set to '4w'"
-                )
+                raise ValueError("--qlinear_packing_format can only be used when --qlinear is set to '4w'")
         qlinear_encoder_packing_format = getattr(self.args, "qlinear_encoder_packing_format", None)
         if qlinear_encoder_packing_format:
             if not device or not device.startswith("cuda"):
diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py
index 93a0034..a14b97c 100644
--- a/optimum/executorch/modeling.py
+++ b/optimum/executorch/modeling.py
@@ -40,7 +40,10 @@
 from transformers.processing_utils import ProcessorMixin
 from transformers.utils import is_offline_mode

-from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch
+from executorch.extension.pybindings.portable_lib import (
+    ExecuTorchModule,
+    _load_for_executorch,
+)
 from executorch.kernels import quantized  # noqa

 from ..exporters import TasksManager
@@ -460,7 +463,7 @@ def __init__(
         if not hasattr(self, "encoder"):
             raise AttributeError("Expected attribute 'encoder' not found in the instance.")
         if not hasattr(self, "text_decoder"):
-            raise AttributeError("Expected attribute 'decoder' not found in the instance.")
+            raise AttributeError("Expected attribute 'text_decoder' not found in the instance.")
         metadata = self.decoder.method_names()
         if "use_kv_cache" in metadata:
             self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0]
@@ -495,7 +498,10 @@ def forward(
             encoder_outputs = self.encoder.forward((input_ids,))[0]
             self.stats.on_prompt_eval_end()

-        result = (self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs)
+        result = (
+            self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0],
+            encoder_outputs,
+        )
         self.stats.on_model_execution_end()
         return result

@@ -1022,29 +1028,27 @@ def __init__(
         config: "PretrainedConfig",
     ):
         super().__init__(models=models, config=config)
-        if not hasattr(self, "encoder"):
-            raise AttributeError("Expected attribute 'encoder' not found in the instance.")
-        if not hasattr(self, "text_decoder"):
-            raise AttributeError("Expected attribute 'decoder' not found in the instance.")
-        metadata = self.decoder.method_names()
+        if not hasattr(self, "model"):
+            raise AttributeError("Expected attribute 'model' not found in the instance.")
+        metadata = self.model.method_names()
         if "use_kv_cache" in metadata:
-            self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0]
+            self.use_kv_cache = self.model.run_method("use_kv_cache")[0]
         if "get_max_seq_len" in metadata:
-            self.max_cache_size = self.decoder.run_method("get_max_seq_len")[0]
+            self.max_cache_size = self.model.run_method("get_max_seq_len")[0]
         if "get_max_batch_size" in metadata:
-            self.max_batch_size = self.decoder.run_method("get_max_batch_size")[0]
+            self.max_batch_size = self.model.run_method("get_max_batch_size")[0]
         if "get_dtype" in metadata:
-            self.dtype = self.decoder.run_method("get_dtype")[0]
+            self.dtype = self.model.run_method("get_dtype")[0]
         if "get_bos_id" in metadata:
-            self.bos_token_id = self.decoder.run_method("get_bos_id")[0]
+            self.bos_token_id = self.model.run_method("get_bos_id")[0]
         if "get_eos_id" in metadata:
-            self.eos_token_id = self.decoder.run_method("get_eos_id")[0]
+            self.eos_token_id = self.model.run_method("get_eos_id")[0]
         if "get_vocab_size" in metadata:
-            self.vocab_size = self.decoder.run_method("get_vocab_size")[0]
+            self.vocab_size = self.model.run_method("get_vocab_size")[0]
         if "max_hidden_seq_length" in metadata:
-            self.max_hidden_seq_length = self.decoder.run_method("max_hidden_seq_length")[0]
+            self.max_hidden_seq_length = self.model.run_method("max_hidden_seq_length")[0]
         if "decoder_start_token_id" in metadata:
-            self.decoder_start_token_id = self.decoder.run_method("decoder_start_token_id")[0]
+            self.decoder_start_token_id = self.model.run_method("decoder_start_token_id")[0]

     def forward(
         self,
@@ -1056,10 +1060,13 @@ def forward(
         is_first_prediction = encoder_outputs is None
         self.stats.on_model_execution_start()
         if is_first_prediction:
-            encoder_outputs = self.encoder.forward((input_features,))[0]
+            encoder_outputs = self.model.run_method("encoder", (input_features,))[0]
             self.stats.on_prompt_eval_end()

-        result = (self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs)
+        result = (
+            self.model.run_method("text_decoder", (decoder_input_ids, encoder_outputs, cache_position))[0],
+            encoder_outputs,
+        )
         self.stats.on_model_execution_end()
         return result

@@ -1117,6 +1124,7 @@ def generate(
             if not first_token_generated:
                 self.stats.on_first_token()
                 first_token_generated = True
+            # Get next token
             next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
             generated_ids.append(next_token)

diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py
index 77c4ef0..b54ea66 100644
--- a/optimum/exporters/executorch/integrations.py
+++ b/optimum/exporters/executorch/integrations.py
@@ -22,13 +22,17 @@
 from transformers import (
     AutoConfig,
     AutoProcessor,
+    DynamicCache,
+    EncoderDecoderCache,
     PreTrainedModel,
     StaticCache,
     T5ForConditionalGeneration,
     WhisperForConditionalGeneration,
 )
-from transformers.generation.configuration_utils import GenerationConfig
-from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM, sdpa_mask_without_vmap
+from transformers.integrations.executorch import (
+    TorchExportableModuleForDecoderOnlyLM,
+    sdpa_mask_without_vmap,
+)
 from transformers.masking_utils import AttentionMaskInterface
 from transformers.modeling_utils import AttentionInterface

@@ -50,7 +54,10 @@ def prepare_export_inputs(self):
             {
                 "role": "user",
                 "content": [
-                    {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
+                    {
+                        "type": "image",
+                        "url": "https://llava-vl.github.io/static/images/view.jpg",
+                    },
                 ],
             },
         ]
@@ -337,7 +344,10 @@ def export(
             mutated_gm,
             args=(),
             # For the ET runner, it's important to have cache position as the 2nd arg.
-            kwargs={"inputs_embeds": inputs_embeds, "cache_position": cache_position},
+            kwargs={
+                "inputs_embeds": inputs_embeds,
+                "cache_position": cache_position,
+            },
             dynamic_shapes=dynamic_shapes,
             strict=True,
         )
@@ -400,7 +410,12 @@ class CausalLMExportableModule(torch.nn.Module):
     """

     def __init__(
-        self, model, max_seq_len=2048, use_custom_kv_cache=False, use_custom_sdpa=False, disable_dynamic_shapes=False
+        self,
+        model,
+        max_seq_len=2048,
+        use_custom_kv_cache=False,
+        use_custom_sdpa=False,
+        disable_dynamic_shapes=False,
     ):
         super().__init__()
         self.model = model
@@ -497,7 +512,10 @@ def export(

         with torch.no_grad():
             exported_program = exportable_module.export(
-                input_ids=input_ids, cache_position=cache_position, dynamic_shapes=dynamic_shapes, strict=strict
+                input_ids=input_ids,
+                cache_position=cache_position,
+                dynamic_shapes=dynamic_shapes,
+                strict=strict,
             )
             # Apply RemoveTransposes pass to remove
             # any back-to-back transpose ops that are not needed
@@ -645,26 +663,38 @@ def __init__(self, model, max_static_cache_length, batch_size):
         self.proj_out = model.lm_head
         self.config = model.config

-        # Initialize static cache
-        self.static_cache = StaticCache(
+        # Initialize self attention cache
+        self.self_attention_cache = StaticCache(
             config=self.config,
             max_batch_size=batch_size,
             max_cache_len=max_static_cache_length,
-            device="cpu",
+            device=model.device,
             dtype=torch.float32,
         )
-
-        # Register cache buffers to make them exportable
-        for i in range(len(self.static_cache.key_cache)):
-            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
-            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
+        head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads)
+        num_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
+        self.self_attention_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, model.device)
+
+        # Initialize cross attention cache
+        self.dynamic_cache = DynamicCache(config=self.config)
+        self.cache = EncoderDecoderCache(self.self_attention_cache, self.dynamic_cache)
+
+        # Register cache buffers to make them exportable.
+        # Cross attention cache buffer is not registered since it's not actually being used atm.
+        for i in range(len(self.self_attention_cache)):
+            self.register_buffer(
+                f"self_attention_key_cache_{i}", self.self_attention_cache.layers[i].keys, persistent=False
+            )
+            self.register_buffer(
+                f"self_attention_value_cache_{i}", self.self_attention_cache.layers[i].values, persistent=False
+            )

     def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
         # Get outputs from decoder
         outputs = self.decoder(
             input_ids=decoder_input_ids,
             encoder_hidden_states=encoder_hidden_states,
-            past_key_values=self.static_cache,
+            past_key_values=self.cache,
             use_cache=True,
             cache_position=cache_position,
         )
@@ -679,26 +709,18 @@ def __init__(
         self,
         model: PreTrainedModel,
         batch_size=1,
-        max_hidden_seq_length=4096,
-        cache_implementation="static",
-        max_cache_length=1024,
+        max_seq_len=1024,
+        max_hidden_seq_len=4096,
     ):
         super().__init__()
-        self.full_model = model
+        self.model = model
         self.encoder = model.get_encoder()
         self.config = model.config
-        self.max_hidden_seq_length = max_hidden_seq_length
-        self.generation_config = GenerationConfig(
-            use_cache=True,
-            max_length=max_cache_length,
-            cache_implementation=cache_implementation,
-            cache_config={
-                "batch_size": batch_size,
-                "max_cache_len": max_cache_length,
-            },
-        )
-        if isinstance(self.full_model, WhisperForConditionalGeneration):
+        self.max_hidden_seq_len = max_hidden_seq_len
+        self.batch_size = batch_size
+        self.max_seq_len = max_seq_len
+        if isinstance(self.model, WhisperForConditionalGeneration):
             self._processor = AutoProcessor.from_pretrained(model.config._name_or_path)
             self._expected_encoder_input_shape = torch.Size(
                 (
@@ -707,14 +729,8 @@ def __init__(
                     self._processor.feature_extractor.nb_max_frames,
                 )
             )
-        additional_configs = {}
-        additional_configs["max_hidden_seq_length"] = max_hidden_seq_length
         # Metadata to be recorded in the pte model file
-        self.metadata = save_config_to_constant_methods(
-            self.config,
-            self.generation_config,
-            **additional_configs,
-        )
+        self.metadata = save_config_to_constant_methods(self.config, get_max_seq_len=max_seq_len)
         self.exported_encoder = None
         self.exported_decoder = None

@@ -722,18 +738,18 @@ def _export_encoder(self, encoder_input_ids):
         wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval()

         # Define dynamic sequence length for encoder
-        if isinstance(self.full_model, WhisperForConditionalGeneration):
+        if isinstance(self.model, WhisperForConditionalGeneration):
             assert (
                 encoder_input_ids.shape == self._expected_encoder_input_shape
             ), f"""This version of Whisper only accepts encoder input of shape {self._expected_encoder_input_shape},
                 passed shape: {encoder_input_ids.shape}. For more infromation, please refer to the Whisper preprocessor config."""
             dynamic_shapes = None
-        elif isinstance(self.full_model, T5ForConditionalGeneration):
-            encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
+        elif isinstance(self.model, T5ForConditionalGeneration):
+            encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_len)
             dynamic_shapes = {"input_ids": {1: encoder_seq_len_dim}}
         else:
             raise ValueError(
-                f"Unsupported model type {type(self.full_model)} for Seq2SeqLMExportableModule encoder export."
+                f"Unsupported model type {type(self.model)} for Seq2SeqLMExportableModule encoder export."
             )

         # Export the encoder
@@ -749,19 +765,19 @@ def _export_encoder(self, encoder_input_ids):
     def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position):
         wrapped_decoder = (
             Seq2SeqLMDecoderExportableModuleWithStaticCache(
-                model=self.full_model,
-                max_static_cache_length=self.generation_config.cache_config.get("max_cache_len"),
-                batch_size=self.generation_config.cache_config.get("batch_size"),
+                model=self.model,
+                max_static_cache_length=self.max_seq_len,
+                batch_size=self.batch_size,
             )
             .to("cpu")
             .eval()
         )

-        if isinstance(self.full_model, WhisperForConditionalGeneration):
+        if isinstance(self.model, WhisperForConditionalGeneration):
             dynamic_shapes = None
-        elif isinstance(self.full_model, T5ForConditionalGeneration):
+        elif isinstance(self.model, T5ForConditionalGeneration):
             # Define dynamic dimension for encoder output sequence length
-            encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
+            encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_len)
             dynamic_shapes = {
                 "decoder_input_ids": None,
                 "encoder_hidden_states": {1: encoder_seq_len_dim},
@@ -769,7 +785,7 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi
             }
         else:
             raise ValueError(
-                f"Unsupported model type {type(self.full_model)} for Seq2SeqLMExportableModule decoder export."
+                f"Unsupported model type {type(self.model)} for Seq2SeqLMExportableModule decoder export."
             )

         # Export the decoder
@@ -791,7 +807,7 @@ def export(
         cache_position=None,
     ) -> Dict[str, ExportedProgram]:
         if encoder_input_ids is None:
-            if isinstance(self.full_model, WhisperForConditionalGeneration):
+            if isinstance(self.model, WhisperForConditionalGeneration):
                 example_encoder_input_ids = torch.rand(self._expected_encoder_input_shape)
             else:
                 example_encoder_input_ids = torch.ones((1, 10), dtype=torch.long)
diff --git a/optimum/exporters/executorch/recipes/portable.py b/optimum/exporters/executorch/recipes/portable.py
index f6faebb..4d4fa97 100644
--- a/optimum/exporters/executorch/recipes/portable.py
+++ b/optimum/exporters/executorch/recipes/portable.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
 from typing import Dict, Union

 from torch.export import ExportedProgram
@@ -58,24 +57,22 @@ def _lower_to_executorch(
         exported_programs: Dict[str, ExportedProgram],
         metadata=None,
     ) -> Dict[str, ExecutorchProgram]:
-        et_progs = {}
+        # If just one exported program, the method name in the .pte for it should be "forward".
+        if len(exported_programs) == 1:
+            exported_programs = {"forward": next(iter(exported_programs.values()))}

-        for pte_name, exported_program in exported_programs.items():
-            logging.debug(f"\nExported program for {pte_name}.pte: {exported_program}")
-            et_progs[pte_name] = to_edge_transform_and_lower(
-                exported_program,
-                partitioner=[],
-                compile_config=EdgeCompileConfig(
-                    _check_ir_validity=False,
-                    _skip_dim_order=True,
-                ),
-                constant_methods=metadata,
-                transform_passes=[RemovePaddingIdxEmbeddingPass()],
-            ).to_executorch()
-            logging.debug(
-                f"\nExecuTorch program for {pte_name}.pte: {et_progs[pte_name].exported_program().graph_module}"
-            )
-        return et_progs
+        et_prog = to_edge_transform_and_lower(
+            exported_programs,
+            partitioner=[],
+            compile_config=EdgeCompileConfig(
+                _check_ir_validity=False,
+                _skip_dim_order=True,
+            ),
+            constant_methods=metadata,
+            transform_passes=[RemovePaddingIdxEmbeddingPass()],
+        ).to_executorch()
+        pte_name = "model"
+        return {pte_name: et_prog}

     exported_progs = model.export()
diff --git a/optimum/exporters/executorch/tasks/asr.py b/optimum/exporters/executorch/tasks/asr.py
index bc20bdc..ccf1a7a 100644
--- a/optimum/exporters/executorch/tasks/asr.py
+++ b/optimum/exporters/executorch/tasks/asr.py
@@ -46,13 +46,13 @@ def load_seq2seq_speech_model(model_name_or_path: str, **kwargs) -> Seq2SeqLMExp
     """
     device = "cpu"
     batch_size = 1
-    max_hidden_seq_length = kwargs.get("max_hidden_seq_length", 4096)
-    max_cache_length = kwargs.get("max_cache_length", 1024)
+    max_hidden_seq_len = kwargs.get("max_hidden_seq_len", 4096)
+    max_seq_len = kwargs.get("max_seq_len", 1024)

     full_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name_or_path).to(device).eval()
     return Seq2SeqLMExportableModule(
         full_model,
         batch_size=batch_size,
-        max_hidden_seq_length=max_hidden_seq_length,
-        max_cache_length=max_cache_length,
+        max_seq_len=max_seq_len,
+        max_hidden_seq_len=max_hidden_seq_len,
     )
diff --git a/tests/models/test_modeling_whisper.py b/tests/models/test_modeling_whisper.py
index f139815..a784a2c 100644
--- a/tests/models/test_modeling_whisper.py
+++ b/tests/models/test_modeling_whisper.py
@@ -49,8 +49,7 @@ def test_whisper_export_to_executorch(self):
                 shell=True,
                 check=True,
             )
-            self.assertTrue(os.path.exists(f"{tempdir}/executorch/encoder.pte"))
-            self.assertTrue(os.path.exists(f"{tempdir}/executorch/decoder.pte"))
+            self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte"))
             model = ExecuTorchModelForSpeechSeq2Seq.from_pretrained(f"{tempdir}/executorch")
             self._test_whisper_transcription(model_id, model)

@@ -59,16 +58,17 @@ def _test_whisper_transcription(self, model_id: str, model: ExecuTorchModelForSp
         processor = AutoProcessor.from_pretrained(model_id)

         self.assertIsInstance(model, ExecuTorchModelForSpeechSeq2Seq)
-        self.assertTrue(hasattr(model, "encoder"))
-        self.assertIsInstance(model.encoder, ExecuTorchModule)
-        self.assertTrue(hasattr(model, "text_decoder"))
-        self.assertIsInstance(model.decoder, ExecuTorchModule)
+        self.assertTrue(hasattr(model, "model"))
+        self.assertIsInstance(model.model, ExecuTorchModule)

         dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
         sample = dataset[0]["audio"]

         input_features = processor(
-            sample["array"], return_tensors="pt", truncation=False, sampling_rate=sample["sampling_rate"]
+            sample["array"],
+            return_tensors="pt",
+            truncation=False,
+            sampling_rate=sample["sampling_rate"],
         ).input_features
         # Current implementation of the transcibe method accepts up to 30 seconds of audio, therefore I trim the audio here.
         input_features_trimmed = input_features[:, :, :3000].contiguous()
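
Usage note (not part of the patch): with this change the seq2seq export writes a single multi-method program, so a runner loads one model.pte whose methods include "encoder" and "text_decoder" instead of separate encoder.pte/decoder.pte files. Below is a minimal sketch of driving such a file directly with the pybindings used in modeling.py above; the path, feature shape, and start-token id are illustrative assumptions, not values taken from this diff.

import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Hypothetical export location; any Whisper export produced by this recipe is laid out the same way.
module = _load_for_executorch("whisper_out/model.pte")
print(module.method_names())  # expected to include "encoder" and "text_decoder"

# Log-mel features: (batch, num_mel_bins, nb_max_frames); (1, 80, 3000) fits whisper-small, adjust per model.
input_features = torch.zeros((1, 80, 3000), dtype=torch.float32)
encoder_outputs = module.run_method("encoder", (input_features,))[0]

# One decode step against the static KV cache: token ids plus the current cache position.
decoder_input_ids = torch.tensor([[50258]], dtype=torch.long)  # assumed decoder_start_token_id, model-dependent
cache_position = torch.tensor([0], dtype=torch.long)
logits = module.run_method("text_decoder", (decoder_input_ids, encoder_outputs, cache_position))[0]
next_token = torch.argmax(logits[:, -1, :], dim=-1).item()

In practice the higher-level path exercised in test_modeling_whisper.py, ExecuTorchModelForSpeechSeq2Seq.from_pretrained(...), wraps exactly these run_method calls.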