diff --git a/optimum/exporters/executorch/convert.py b/optimum/exporters/executorch/convert.py
index d659fa4a..07f1ee82 100644
--- a/optimum/exporters/executorch/convert.py
+++ b/optimum/exporters/executorch/convert.py
@@ -16,10 +16,12 @@
 import logging
 import os
+import executorch
 from pathlib import Path
 from typing import Union
 
 from transformers.modeling_utils import AttentionInterface
 
+from executorch.backends.xnnpack import get_xnnpack_recipe
 from optimum.executorch.attentions.custom_sdpa import custom_sdpa_with_start_pos_forward
 
 
@@ -64,18 +66,14 @@ def export_to_executorch(
         - The exported model is stored in the specified output directory with the fixed filename `model.pte`.
         - The resulting ExecuTorch program is serialized and saved to the output directory.
     """
-
-    # Dynamically discover and import registered recipes
-    discover_recipes()
-
-    # Export and lower the model to ExecuTorch with the recipe
-    try:
-        recipe_func = recipe_registry.get(recipe)
-    except KeyError as e:
-        raise RuntimeError(f"The recipe '{recipe}' isn't registered. Detailed error: {e}")
-
-    executorch_progs = recipe_func(model, **kwargs)
-
+
+    executorch_progs = {}
+    models = model.get_exportable_model_and_inputs()
+    for name, model_dict in models.items():
+        eager_model = model_dict["model"]
+        session = executorch.export.export(eager_model, [model_dict["inputs"]], get_xnnpack_recipe("FP32_CPU_ACCELERATED_RECIPE"), dynamic_shapes=model_dict.get("dynamic_shapes"), constant_methods=model.metadata)
+        executorch_progs[name] = session.get_executorch_program_manager()
+
     for name, prog in executorch_progs.items():
         full_path = os.path.join(f"{output_dir}", f"{name}.pte")
         with open(full_path, "wb") as f:
diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py
index 88365d23..66065486 100644
--- a/optimum/exporters/executorch/integrations.py
+++ b/optimum/exporters/executorch/integrations.py
@@ -27,7 +27,7 @@
 from transformers.generation.configuration_utils import GenerationConfig
 
 from optimum.utils.import_utils import is_transformers_version
-
+from transformers.integrations.executorch import TorchExportableModuleWithStaticCache
 from .utils import save_config_to_constant_methods
 
 
@@ -43,6 +43,20 @@ def __init__(self, model):
         self.config = model.config
         self.metadata = save_config_to_constant_methods(model.config, model.generation_config)
 
+    def get_exportable_model_and_inputs(self, input_ids=None, cache_position=None):
+        example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
+        example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
+
+        if is_transformers_version(">=", "4.52.0.dev0"):
+            from transformers.integrations.executorch import (
+                TorchExportableModuleForDecoderOnlyLM,
+            )
+            max_batch_size = 1
+            max_cache_len = 4094
+            return {"model": {"model": TorchExportableModuleForDecoderOnlyLM(self.model, max_batch_size, max_cache_len), "inputs": (example_input_ids, example_cache_position), "dynamic_shapes": None}}
+        else:
+            return {"model": {"model": TorchExportableModuleWithStaticCache(self.model), "inputs": (example_input_ids, example_cache_position), "dynamic_shapes": None}}
+
     def export(self, input_ids=None, cache_position=None) -> Dict[str, ExportedProgram]:
         example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
         example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
@@ -148,13 +162,7 @@ def export(self, input_ids=None, attention_mask=None) -> Dict[str, ExportedProgr
         # Export the model with dynamic dimensions
         with torch.no_grad():
             return {
-                "model": torch.export.export(
-                    self.model,
-                    args=(dummy_input_ids,),
-                    kwargs={"attention_mask": dummy_attention_mask},
-                    dynamic_shapes=dynamic_shapes,
-                    strict=True,
-                )
+                "model": (self.model, (dummy_input_ids, {"attention_mask": dummy_attention_mask}), dynamic_shapes)
             }
 
 
diff --git a/optimum/exporters/executorch/recipes/xnnpack.py b/optimum/exporters/executorch/recipes/xnnpack.py
index ce0220dd..f6409d07 100644
--- a/optimum/exporters/executorch/recipes/xnnpack.py
+++ b/optimum/exporters/executorch/recipes/xnnpack.py
@@ -19,7 +19,6 @@
 from tabulate import tabulate
 from torch.export import ExportedProgram
 
-from executorch import version as executorch_version
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
 from executorch.devtools.backend_debug import get_delegation_info
 from executorch.exir import (
@@ -35,7 +34,8 @@
     Seq2SeqLMExportableModule,
 )
 from ..recipe_registry import register_recipe
-
+from executorch.backends.xnnpack import get_xnnpack_recipe
+from executorch.export import export
 
 @register_recipe("xnnpack")
 def export_to_executorch_with_xnnpack(
@@ -67,8 +67,7 @@ def _lower_to_executorch(
         backend_config_dict = {
             "extract_delegate_segments": True,
         }
-        if parse(executorch_version.__version__).base_version > "0.6.0":
-            backend_config_dict["do_quant_fusion_and_const_prop"] = True
+        backend_config_dict["do_quant_fusion_and_const_prop"] = True
 
         for pte_name, exported_program in exported_programs.items():
             logging.debug(f"\nExported program for {pte_name}.pte: {exported_program}")
@@ -93,7 +92,8 @@ def _lower_to_executorch(
        return et_progs
 
    exported_progs = model.export()
-
+
+    """
    if model.config._attn_implementation == "custom_sdpa":
        # Sanity check to make sure the exported program contains the custom sdpa operator.
        if not any(
@@ -102,5 +102,16 @@
            for node in exported_program.graph_module.graph.nodes
        ):
            raise ValueError("'custom_sdpa' not found in the graph.")
 
+    """
+
+    executorch_programs = {}
+    for name, e_model in exported_progs.items():
+        eager_model = e_model[0]
+        example_inputs = [e_model[1]]
+        dynamic_shapes = e_model[2]
+
+        session = export(eager_model, example_inputs, get_xnnpack_recipe("8A4W_CPU_ACCELERATED_RECIPE"), dynamic_shapes=dynamic_shapes, constant_methods=model.metadata)
+        executorch_programs[name] = session.get_executorch_program_manager()
+        session.print_delegation_info()
-    return _lower_to_executorch(exported_progs, model.metadata)
+    return executorch_programs
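
Reviewer note: the change replaces the registry-based recipe dispatch (discover_recipes/recipe_registry) with the executorch.export session API, so a single export call now performs torch.export, XNNPACK partitioning, and lowering. A minimal sketch of the intended end-to-end flow, assuming the APIs exactly as they are used in this diff (get_xnnpack_recipe, session.get_executorch_program_manager(), session.print_delegation_info()); the helper name lower_with_xnnpack and the variable names are illustrative only, not part of the change:

    import executorch.export
    from executorch.backends.xnnpack import get_xnnpack_recipe

    def lower_with_xnnpack(eager_model, example_inputs, metadata, dynamic_shapes=None):
        # One call handles torch.export, partitioning, and lowering; the recipe
        # object selects the XNNPACK backend configuration (FP32 here).
        session = executorch.export.export(
            eager_model,
            [example_inputs],  # a list of example input tuples
            get_xnnpack_recipe("FP32_CPU_ACCELERATED_RECIPE"),
            dynamic_shapes=dynamic_shapes,
            constant_methods=metadata,  # config/generation metadata baked into the .pte
        )
        session.print_delegation_info()  # replaces the get_delegation_info()/tabulate report
        return session.get_executorch_program_manager()

    # Usage sketch: serialize the resulting program manager to a .pte file,
    # mirroring the write loop kept in convert.py (write_to_file is the
    # ExecutorchProgramManager serialization method).
    prog = lower_with_xnnpack(wrapped_model, (input_ids, cache_position), wrapped_model.metadata)
    with open("model.pte", "wb") as f:
        prog.write_to_file(f)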