Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions optimum/exporters/executorch/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@

import logging
import os
import executorch
from pathlib import Path
from typing import Union

from transformers.modeling_utils import AttentionInterface
from executorch.backends.xnnpack import get_xnnpack_recipe

from optimum.executorch.attentions.custom_sdpa import custom_sdpa_with_start_pos_forward

Expand Down Expand Up @@ -64,18 +66,14 @@ def export_to_executorch(
- The exported model is stored in the specified output directory with the fixed filename `model.pte`.
- The resulting ExecuTorch program is serialized and saved to the output directory.
"""

# Dynamically discover and import registered recipes
discover_recipes()

# Export and lower the model to ExecuTorch with the recipe
try:
recipe_func = recipe_registry.get(recipe)
except KeyError as e:
raise RuntimeError(f"The recipe '{recipe}' isn't registered. Detailed error: {e}")

executorch_progs = recipe_func(model, **kwargs)


executorch_progs = {}
models = model.get_exportable_model_and_inputs()
for name, model_dict in models.items():
eager_model = model_dict["model"]
session = executorch.export.export(eager_model, [model_dict["inputs"]], get_xnnpack_recipe("FP32_CPU_ACCELERATED_RECIPE"),dynamic_shapes=model_dict.get("dynamic_shapes"), constant_methods= model.metadata)
executorch_progs[name] = session.get_executorch_program_manager()

for name, prog in executorch_progs.items():
full_path = os.path.join(f"{output_dir}", f"{name}.pte")
with open(full_path, "wb") as f:
Expand Down
24 changes: 16 additions & 8 deletions optimum/exporters/executorch/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from transformers.generation.configuration_utils import GenerationConfig

from optimum.utils.import_utils import is_transformers_version

from transformers.integrations.executorch import TorchExportableModuleWithStaticCache
from .utils import save_config_to_constant_methods


Expand All @@ -43,6 +43,20 @@ def __init__(self, model):
self.config = model.config
self.metadata = save_config_to_constant_methods(model.config, model.generation_config)

def get_exportable_model_and_inputs(self, input_ids=None, cache_position=None, max_batch_size=1, max_cache_len=4094):
    """Return the export-ready wrapper around ``self.model`` plus example inputs.

    Selects the appropriate transformers ExecuTorch integration wrapper based on
    the installed transformers version, and pairs it with example inputs suitable
    for ``torch.export`` / ExecuTorch export.

    Args:
        input_ids: Optional example input ids tensor. Defaults to a single-token
            ``[[1]]`` long tensor when not provided.
        cache_position: Optional example cache-position tensor. Defaults to
            ``[0]`` (long) when not provided.
        max_batch_size: Static-cache batch capacity passed to
            ``TorchExportableModuleForDecoderOnlyLM`` (transformers >= 4.52 only).
        max_cache_len: Static-cache sequence capacity for the same wrapper.
            NOTE(review): 4094 looks like it may be intended as 4096 — confirm.

    Returns:
        dict: ``{"model": {"model": <wrapper>, "inputs": (input_ids, cache_position),
        "dynamic_shapes": None}}`` — one entry per exportable sub-model (only
        ``"model"`` here).
    """
    example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
    example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)

    if is_transformers_version(">=", "4.52.0.dev0"):
        # Newer transformers provide a decoder-only LM wrapper that owns the
        # static cache sizing; import lazily since it only exists on >= 4.52.
        from transformers.integrations.executorch import (
            TorchExportableModuleForDecoderOnlyLM,
        )

        exportable = TorchExportableModuleForDecoderOnlyLM(self.model, max_batch_size, max_cache_len)
    else:
        # Older transformers: static-cache wrapper takes no capacity arguments,
        # so max_batch_size / max_cache_len are intentionally unused here.
        exportable = TorchExportableModuleWithStaticCache(self.model)

    return {
        "model": {
            "model": exportable,
            "inputs": (example_input_ids, example_cache_position),
            "dynamic_shapes": None,
        }
    }

def export(self, input_ids=None, cache_position=None) -> Dict[str, ExportedProgram]:
example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
Expand Down Expand Up @@ -148,13 +162,7 @@ def export(self, input_ids=None, attention_mask=None) -> Dict[str, ExportedProgr
# Export the model with dynamic dimensions
with torch.no_grad():
return {
"model": torch.export.export(
self.model,
args=(dummy_input_ids,),
kwargs={"attention_mask": dummy_attention_mask},
dynamic_shapes=dynamic_shapes,
strict=True,
)
"model": (self.model,(dummy_input_ids,{"attention_mask": dummy_attention_mask}),dynamic_shapes,)
}


Expand Down
23 changes: 17 additions & 6 deletions optimum/exporters/executorch/recipes/xnnpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from tabulate import tabulate
from torch.export import ExportedProgram

from executorch import version as executorch_version
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.devtools.backend_debug import get_delegation_info
from executorch.exir import (
Expand All @@ -35,7 +34,8 @@
Seq2SeqLMExportableModule,
)
from ..recipe_registry import register_recipe

from executorch.backends.xnnpack import get_xnnpack_recipe
from executorch.export import export

@register_recipe("xnnpack")
def export_to_executorch_with_xnnpack(
Expand Down Expand Up @@ -67,8 +67,7 @@ def _lower_to_executorch(
backend_config_dict = {
"extract_delegate_segments": True,
}
if parse(executorch_version.__version__).base_version > "0.6.0":
backend_config_dict["do_quant_fusion_and_const_prop"] = True
backend_config_dict["do_quant_fusion_and_const_prop"] = True

for pte_name, exported_program in exported_programs.items():
logging.debug(f"\nExported program for {pte_name}.pte: {exported_program}")
Expand All @@ -93,7 +92,8 @@ def _lower_to_executorch(
return et_progs

exported_progs = model.export()


"""
if model.config._attn_implementation == "custom_sdpa":
# Sanity check to make sure the exported program contains the custom sdpa operator.
if not any(
Expand All @@ -102,5 +102,16 @@ def _lower_to_executorch(
for node in exported_program.graph_module.graph.nodes
):
raise ValueError("'custom_sdpa' not found in the graph.")
"""

executorch_programs = {}
for name, e_model in exported_progs.items():
eager_model = e_model[0]
example_inputs = [e_model[1]]
dynamic_shapes = e_model[2]

session = export(eager_model, example_inputs, get_xnnpack_recipe("8A4W_CPU_ACCELERATED_RECIPE"), dynamic_shapes = dynamic_shapes, constant_methods=model.metadata)
executorch_programs[name] = session.get_executorch_program_manager()
session.print_delegation_info()

return _lower_to_executorch(exported_progs, model.metadata)
return executorch_programs
Loading