11 changes: 8 additions & 3 deletions examples/models/llama/export_llama_lib.py
@@ -879,16 +879,21 @@ def _to_edge_and_lower_llama_xnnpack(
builder_exported.generate_etrecord = True

builder = builder_exported.pt2e_quantize(quantizers)

# Run decompositions; this re-exports internally, so no separate
# re-export is needed here before tagging.
builder = builder.run_decompositions()
# Tag the exported program's constants for external storage.
if gen_tag_fn is not None:
from executorch.exir.passes.external_constants_pass import (
delegate_external_constants_pass_unlifted,
delegate_external_constants_pass_lifted,
)

assert (
builder_exported.pre_autograd_graph_module is not None
), "pre_autograd_graph_module shouldn't be None here"
delegate_external_constants_pass_unlifted(
module=builder_exported.pre_autograd_graph_module,
delegate_external_constants_pass_lifted(
ep=builder_exported.exported_module,
gen_tag_fn=gen_tag_fn,
)
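
For reference, a minimal sketch of a tag function that could be supplied as gen_tag_fn; the function name and the single shared tag value are illustrative assumptions, not part of this change.

```python
# Hypothetical tag function for illustration only; the tag value is an assumption.
from typing import Optional

import torch


def example_gen_tag_fn(node: torch.fx.Node) -> Optional[str]:
    # Route every tagged constant into one shared external data segment.
    return "model_weights"
```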

20 changes: 20 additions & 0 deletions exir/passes/external_constants_pass.py
@@ -111,3 +111,23 @@ def delegate_external_constants_pass_unlifted(
node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
mutated = True
return PassResult(module, mutated)

def delegate_external_constants_pass_lifted(
ep: ExportedProgram,
gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None,
) -> PassResult:
"""
Tag constants in an ExportedProgram for external storage.
Works on the lifted graph directly; no re-export is needed.
"""
mutated = False
gm = ep.graph_module

for node in gm.graph.nodes:
if node.op == "placeholder" and is_param_node(ep, node):
if gen_tag_fn is not None:
node.meta.setdefault("custom", {})
node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
mutated = True

return PassResult(gm, mutated)
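
A minimal usage sketch of the new lifted pass, assuming a toy nn.Linear model; the lambda tag value is illustrative.

```python
import torch

from executorch.exir.passes.external_constants_pass import (
    delegate_external_constants_pass_lifted,
)

# Export a toy model; the pass tags parameter placeholders in the lifted graph.
ep = torch.export.export(torch.nn.Linear(4, 4), (torch.randn(1, 4),))
result = delegate_external_constants_pass_lifted(
    ep=ep,
    gen_tag_fn=lambda node: "linear_weights",  # illustrative tag
)
# Tagged nodes now carry node.meta["custom"]["delegate_constant_tag"].
print(result.modified)
```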
39 changes: 20 additions & 19 deletions extension/llm/export/builder.py
@@ -20,7 +20,7 @@
DuplicateDynamicQuantChainPass,
)
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
from executorch.exir import EdgeProgramManager, to_edge, to_edge_transform_and_lower
from executorch.exir.backend.partitioner import Partitioner

from executorch.exir.backend.utils import format_delegated_graph
@@ -125,6 +125,7 @@ def __init__(
# make sure to re-export this graph module to persist any changes. See
# https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
self.exported_module: Optional[torch.export.ExportedProgram] = None
self.edge_manager: Optional[EdgeProgramManager] = None
self.canonical_passes = [
RemoveRedundantTransposes()
@@ -235,17 +236,23 @@ def export(self) -> "LLMEdgeManager":
The full torch.export() if called later on during to_edge() or
to_edge_transform_and_lower().
"""
exported_module = self._export()
self.exported_module = self._export()
# Need to store the graph module to record transformation passes.
# Persisting those changes back to an ExportedProgram will require
# an additional export().
self.pre_autograd_graph_module = exported_module.module()
self.pre_autograd_graph_module = self.exported_module.module()
if self.save_exported_program:
export_output = f"{self.modelname}.pt2"
logging.info(f"Saving torch.export() result to {export_output}")
torch.export.save(self.exported_module, export_output)
return self

def run_decompositions(self) -> "LLMEdgeManager":
# Re-export to capture any pending changes to pre_autograd_graph_module
self.exported_module = self._export(self.pre_autograd_graph_module)
self.exported_module = self.exported_module.run_decompositions({})
return self
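
For context, a sketch of the re-export-then-decompose pattern this helper follows, using torch.export directly on a toy module in place of the builder's _export (whose body is not shown in this diff).

```python
import torch

example_inputs = (torch.randn(1, 8),)
# Unlifted module, analogous to pre_autograd_graph_module.
gm = torch.export.export(torch.nn.Linear(8, 8), example_inputs).module()
# ... in-place graph passes would mutate `gm` here ...
ep = torch.export.export(gm, example_inputs)  # re-export to persist the changes
ep = ep.run_decompositions({})  # empty table: re-trace without extra decompositions
```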

def run_canonical_optimizations(self):
"""
Run canonical optimizations (at the moment removing redundant permutes) on the model.
@@ -256,6 +263,8 @@ def run_canonical_optimizations(self):
res = pass_instance(self.pre_autograd_graph_module)
assert res.graph_module is not None, "Pass returned None"
self.pre_autograd_graph_module = res.graph_module
# Re-export to capture changes to pre_autograd_graph_module
self.exported_module = self._export(self.pre_autograd_graph_module)

def pt2e_calibrate(
self,
@@ -389,6 +398,8 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
m = convert_pt2e(m)
DuplicateDynamicQuantChainPass()(m)
self.pre_autograd_graph_module = m
# Re-export to capture changes to pre_autograd_graph_module
self.exported_module = self._export(self.pre_autograd_graph_module)
return self
else:
logging.info("No quantizer provided, passing...")
@@ -398,7 +409,6 @@ def export_to_edge(self) -> "LLMEdgeManager":
"""
Export the model to Edge dialect and retrieve a LLMEdgeManager.
"""
dynamic_shape = self._get_dynamic_shape()
edge_config = self._get_edge_config()

# 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
@@ -408,17 +418,11 @@ def export_to_edge(self) -> "LLMEdgeManager":
# Run export() if it didn't run
self.export()

override_export_behaviour = contextlib.nullcontext()
with override_export_behaviour:
self.edge_manager = export_to_edge(
self.pre_autograd_graph_module, # pyre-fixme[6]
self.example_inputs,
example_kwarg_inputs=self.example_kwarg_inputs,
dynamic_shapes=dynamic_shape,
edge_constant_methods=self.metadata,
edge_compile_config=edge_config,
verbose=self.verbose,
)
self.edge_manager = to_edge(
self.exported_module,
constant_methods=self.metadata,
compile_config=edge_config,
)
return self
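
For context, a sketch of calling to_edge directly on a stored ExportedProgram, as this method now does; the toy model, the metadata entry, and the relaxed IR check are illustrative.

```python
import torch

from executorch.exir import EdgeCompileConfig, to_edge

ep = torch.export.export(torch.nn.Linear(4, 4), (torch.randn(1, 4),))
edge_manager = to_edge(
    ep,
    constant_methods={"get_vocab_size": 32000},  # illustrative metadata method
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
print(edge_manager.exported_program().graph_module.graph)
```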

def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
@@ -457,12 +461,9 @@ def to_edge_transform_and_lower(
if partitioners is None:
logging.info("No partitioner provided, skipping backend lowering...")

# Need to construct ExportedProgram with the new transformed graph module.
exported_module = self._export(self.pre_autograd_graph_module)

edge_config = self._get_edge_config()
self.edge_manager = to_edge_transform_and_lower(
exported_module,
self.exported_module,
partitioner=partitioners,
compile_config=edge_config,
constant_methods=self.metadata,