
Commit 8e48a75

Export recipes integration in export_llama
1 parent: 04bf288

File tree: 3 files changed, +76 −0 lines changed

examples/models/llama/export_llama_lib.py (31 additions, 0 deletions)

@@ -21,6 +21,8 @@
 from pathlib import Path
 from typing import Callable, List, Optional, Union
 
+import executorch
+
 import pkg_resources
 import torch
 
@@ -50,6 +52,7 @@
     get_qnn_quantizer,
     get_vulkan_quantizer,
 )
+from executorch.extension.llm.export.recipes import get_llm_recipe
 from executorch.util.activation_memory_profiler import generate_memory_trace
 
 from ..model_factory import EagerModelFactory
@@ -546,6 +549,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="If true, stops right after torch.export() and saves the exported model.",
     )
+
+    parser.add_argument(
+        "--recipe_flow",
+        default=False,
+        action="store_true",
+        help="Experimental feature, this will use the executorch.export + recipe based flow",
+    )
     return parser
 
 
@@ -610,6 +620,9 @@ def export_llama(args) -> str:
            "Please run `pip install snakeviz` to install required dependencies for cProfiler flamegraph."
        )
        return ""
+    elif args.recipe_flow:
+        filename = _recipe_based_export_llama(args)
+        return filename
    else:
        builder = _export_llama(args)
        assert (
@@ -1102,6 +1115,24 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
    return builder
 
 
+def _recipe_based_export_llama(args) -> str:  # noqa: C901
+    _validate_args(args)
+    assert args.xnnpack, "Recipe based flow only supports xnnpack backend currently."
+
+    builder = _prepare_for_llama_export(args)
+    session = executorch.export.export(
+        builder.model,
+        [builder.example_inputs],
+        get_llm_recipe(args),
+        dynamic_shapes=builder._get_dynamic_shape(),
+    )
+
+    session.print_delegation_info()
+    session.save_to_pte(builder.modelname)
+
+    return builder.modelname
+
+
 def _load_llama_model_metadata(
     weight_type: WeightType,
     use_kv_cache: bool,
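
With this change, the experimental path is opted into from the existing CLI parser. A minimal sketch of exercising it in Python, assuming the module is importable under the path shown in the diff and that the parser accepts this short argument list; only --xnnpack and --recipe_flow below come from this commit, everything else (checkpoint, params, quantization flags) would be supplied in a real run:

    from examples.models.llama.export_llama_lib import build_args_parser, export_llama

    # Hypothetical minimal argument list; real invocations also pass checkpoint,
    # params, dtype, and quantization options.
    args = build_args_parser().parse_args(["--xnnpack", "--recipe_flow"])

    # Because args.recipe_flow is set, export_llama() dispatches to
    # _recipe_based_export_llama() and returns the name of the saved .pte file.
    pte_name = export_llama(args)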

extension/llm/export/export_passes.py (9 additions, 0 deletions)

@@ -3,6 +3,7 @@
 import torch
 
 from executorch.exir.pass_base import ExportPass
+from executorch.exir.program._program import _update_exported_program_graph_module
 from torch._subclasses import FakeTensor
 from torch.fx.passes.infra.pass_base import PassResult
 
@@ -99,6 +100,14 @@ def call(self, graph_module: torch.fx.GraphModule):
        return PassResult(graph_module, graph_changed)
 
 
+def remove_redundant_transposes(
+    ep: torch.export.ExportedProgram,
+) -> torch.export.ExportedProgram:
+    res = RemoveRedundantTransposes()(ep.graph_module)
+    assert res is not None
+    return _update_exported_program_graph_module(ep, res.graph_module)
+
+
 class ReplaceSDPAWithCustomSDPAPass(ExportPass):
     """
     This pass replaces aten.scaled_dot_product_attention.default with llama.custom_sdpa.default.
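
The new remove_redundant_transposes wrapper lifts the existing graph-module pass to the ExportedProgram level so it can be listed in a recipe's pre_edge_transform_passes. A self-contained sketch of calling it directly; the toy module below is illustrative and not part of this commit, but the pair of cancelling transposes is the kind of pattern the pass is named for:

    import torch
    from executorch.extension.llm.export.export_passes import remove_redundant_transposes

    class Toy(torch.nn.Module):
        def forward(self, x):
            # Two back-to-back transposes of the same dims are a no-op.
            return x.transpose(1, 2).transpose(1, 2) + 1

    ep = torch.export.export(Toy(), (torch.randn(2, 3, 4),))
    ep = remove_redundant_transposes(ep)  # returns an updated ExportedProgram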

extension/llm/export/recipes.py (new file: 36 additions, 0 deletions)

@@ -0,0 +1,36 @@
+from executorch.export.recipe import ExportRecipe, QuantizationRecipe
+from executorch.exir import EdgeCompileConfig
+from executorch.extension.llm.export.quantizer_lib import get_quantizer_and_quant_params
+from executorch.extension.llm.export.export_passes import remove_redundant_transposes
+from executorch.extension.llm.export.partitioner_lib import (
+    get_coreml_partitioner,
+    get_mps_partitioner,
+    get_qnn_partitioner,
+    get_vulkan_partitioner,
+    get_xnnpack_partitioner,
+)
+
+def get_llm_recipe(args) -> ExportRecipe:
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+
+    if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
+        # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
+        args.xnnpack = True
+
+    quant_recipe = QuantizationRecipe(
+        quantizers=quantizers,
+    )
+
+    partitioners = []
+    if args.xnnpack:
+        partitioners.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=True))
+        if args.xnnpack_extended_ops:
+            partitioners.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=False))
+
+    return ExportRecipe(
+        quantization_recipe=quant_recipe,
+        edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
+        pre_edge_transform_passes=[remove_redundant_transposes],
+        edge_transform_passes=[],
+        partitioners=partitioners,
+    )
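
get_llm_recipe bundles the quantizers, partitioners, edge compile config, and the pre-edge pass into a single ExportRecipe, which is what _recipe_based_export_llama hands to executorch.export.export. A hedged sketch of driving that call directly with the recipe; the tiny module, its inputs, and the output name are placeholders for the Llama model that _prepare_for_llama_export builds in the real flow, and it assumes the parser accepts this minimal argument list:

    import executorch
    import torch
    from examples.models.llama.export_llama_lib import build_args_parser
    from executorch.extension.llm.export.recipes import get_llm_recipe

    class Tiny(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x)

    # Hypothetical minimal arguments; real runs carry the full Llama export config.
    args = build_args_parser().parse_args(["--xnnpack", "--recipe_flow"])

    session = executorch.export.export(
        Tiny(),
        [(torch.randn(1, 8),)],  # mirrors the [builder.example_inputs] list in the diff
        get_llm_recipe(args),
    )
    session.print_delegation_info()
    session.save_to_pte("tiny_model")  # placeholder output name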
