Commit a3e47d0

lucylq authored and facebook-github-bot committed
run_decomp2
Differential Revision: D88223584
1 parent 93bf861 commit a3e47d0

File tree

3 files changed: +48, -22 lines


examples/models/llama/export_llama_lib.py

Lines changed: 8 additions & 3 deletions
@@ -879,16 +879,21 @@ def _to_edge_and_lower_llama_xnnpack(
         builder_exported.generate_etrecord = True

     builder = builder_exported.pt2e_quantize(quantizers)
+
+    # re-export required here?
+    # run decomps
+    builder = builder.run_decompositions()
+    # tag EP
     if gen_tag_fn is not None:
         from executorch.exir.passes.external_constants_pass import (
-            delegate_external_constants_pass_unlifted,
+            delegate_external_constants_pass_lifted,
         )

         assert (
             builder_exported.pre_autograd_graph_module is not None
         ), "pre_autograd_graph_module shouldn't be None here"
-        delegate_external_constants_pass_unlifted(
-            module=builder_exported.pre_autograd_graph_module,
+        delegate_external_constants_pass_lifted(
+            ep=builder_exported.exported_module,
             gen_tag_fn=gen_tag_fn,
         )

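Taken together, this hunk changes the XNNPACK lowering flow: quantize, then re-export with decompositions, then tag constants on the lifted ExportedProgram instead of the unlifted graph module. A minimal sketch of the resulting call sequence (names follow the diff; gen_tag_fn is whatever tagging callback the caller passes in):

from executorch.exir.passes.external_constants_pass import (
    delegate_external_constants_pass_lifted,
)

builder = builder_exported.pt2e_quantize(quantizers)  # PT2E quantization
builder = builder.run_decompositions()                # re-export + decompose
if gen_tag_fn is not None:
    # Tag parameters on the lifted ExportedProgram for external storage.
    delegate_external_constants_pass_lifted(
        ep=builder_exported.exported_module,
        gen_tag_fn=gen_tag_fn,
    )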
exir/passes/external_constants_pass.py

Lines changed: 20 additions & 0 deletions
@@ -111,3 +111,23 @@ def delegate_external_constants_pass_unlifted(
                 node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
                 mutated = True
     return PassResult(module, mutated)
+
+def delegate_external_constants_pass_lifted(
+    ep: ExportedProgram,
+    gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None,
+) -> PassResult:
+    """
+    Tag constants in an ExportedProgram for external storage.
+    Works on the lifted graph directly, no re-export needed.
+    """
+    mutated = False
+    gm = ep.graph_module
+
+    for node in gm.graph.nodes:
+        if node.op == "placeholder" and is_param_node(ep, node):
+            if gen_tag_fn is not None:
+                node.meta.setdefault("custom", {})
+                node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
+                mutated = True
+
+    return PassResult(gm, mutated)
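In a lifted ExportedProgram, parameters and buffers appear as placeholder inputs rather than get_attr nodes, which is why the new pass walks placeholders and checks is_param_node. A hedged usage sketch (the toy module and tag function here are hypothetical, for illustration only):

import torch

from executorch.exir.passes.external_constants_pass import (
    delegate_external_constants_pass_lifted,
)

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

ep = torch.export.export(Tiny(), (torch.randn(1, 4),))

# Route every parameter to a single external data file; the tags land in
# each placeholder's node.meta["custom"]["delegate_constant_tag"].
result = delegate_external_constants_pass_lifted(
    ep=ep,
    gen_tag_fn=lambda node: "model_weights",
)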

extension/llm/export/builder.py

Lines changed: 20 additions & 19 deletions
@@ -20,7 +20,7 @@
     DuplicateDynamicQuantChainPass,
 )
 from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
-from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
+from executorch.exir import EdgeProgramManager, to_edge, to_edge_transform_and_lower
 from executorch.exir.backend.partitioner import Partitioner

 from executorch.exir.backend.utils import format_delegated_graph
@@ -125,6 +125,7 @@ def __init__(
         # make sure to re-export this graph module to persist any changes. See
         # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
         self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
+        self.exported_module: Optional[torch.export.ExportedProgram] = None
         self.edge_manager: Optional[EdgeProgramManager] = None
         self.canonical_passes = [
             RemoveRedundantTransposes()
@@ -235,17 +236,23 @@ def export(self) -> "LLMEdgeManager":
         The full torch.export() if called later on during to_edge() or
         to_edge_transform_and_lower().
         """
-        exported_module = self._export()
+        self.exported_module = self._export()
         # Need to store the graph module to record transformation passes.
         # Persisting those changes back to an ExportedProgram will require
         # an additional export().
-        self.pre_autograd_graph_module = exported_module.module()
+        self.pre_autograd_graph_module = self.exported_module.module()
         if self.save_exported_program:
             export_output = f"{self.modelname}.pt2"
             logging.info(f"Saving torch.export() result to {export_output}")
             torch.export.save(exported_module, export_output)
         return self

+    def run_decompositions(self) -> "LLMEdgeManager":
+        # Re-export to capture any pending changes to pre_autograd_graph_module
+        self.exported_module = self._export(self.pre_autograd_graph_module)
+        self.exported_module = self.exported_module.run_decompositions({})
+        return self
+
     def run_canonical_optimizations(self):
         """
         Run canonical optimizations (at the moment removing redundant permutes) on the model.
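The new run_decompositions() first re-exports, since pre_autograd_graph_module may have been mutated by earlier passes, then calls ExportedProgram.run_decompositions with an empty table; passing {} leaves the traced op set essentially intact, whereas passing nothing would apply the default core ATen decompositions. A standalone sketch of the same pattern, assuming model and example_inputs are already defined:

import torch

ep = torch.export.export(model, example_inputs)  # lifted ExportedProgram
ep = ep.run_decompositions({})                   # empty decomp table; contrast
                                                 # ep.run_decompositions(), which
                                                 # uses the default core ATen table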
@@ -256,6 +263,8 @@ def run_canonical_optimizations(self):
             res = pass_instance(self.pre_autograd_graph_module)
             assert res.graph_module is not None, "Pass returned None"
             self.pre_autograd_graph_module = res.graph_module
+        # Re-export to capture changes to pre_autograd_graph_module
+        self.exported_module = self._export(self.pre_autograd_graph_module)

     def pt2e_calibrate(
         self,
@@ -389,6 +398,8 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
             m = convert_pt2e(m)
             DuplicateDynamicQuantChainPass()(m)
             self.pre_autograd_graph_module = m
+            # Re-export to capture changes to pre_autograd_graph_module
+            self.exported_module = self._export(self.pre_autograd_graph_module)
             return self
         else:
             logging.info("No quantizer provided, passing...")
@@ -398,7 +409,6 @@ def export_to_edge(self) -> "LLMEdgeManager":
         """
         Export the model to Edge dialect and retrieve a LLMEdgeManager.
         """
-        dynamic_shape = self._get_dynamic_shape()
         edge_config = self._get_edge_config()

         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
@@ -408,17 +418,11 @@ def export_to_edge(self) -> "LLMEdgeManager":
             # Run export() if it didn't run
             self.export()

-            override_export_behaviour = contextlib.nullcontext()
-            with override_export_behaviour:
-                self.edge_manager = export_to_edge(
-                    self.pre_autograd_graph_module,  # pyre-fixme[6]
-                    self.example_inputs,
-                    example_kwarg_inputs=self.example_kwarg_inputs,
-                    dynamic_shapes=dynamic_shape,
-                    edge_constant_methods=self.metadata,
-                    edge_compile_config=edge_config,
-                    verbose=self.verbose,
-                )
+            self.edge_manager = to_edge(
+                self.exported_module,
+                constant_methods=self.metadata,
+                compile_config=edge_config,
+            )
         return self

     def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
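export_to_edge() now hands the stored ExportedProgram straight to to_edge instead of re-tracing pre_autograd_graph_module through the old export_to_edge helper; dynamic shapes are already baked into self.exported_module by export(), which is why the _get_dynamic_shape() call could be dropped. The new call in isolation, assuming an ExportedProgram, metadata dict, and edge config are in hand:

from executorch.exir import to_edge

edge_manager = to_edge(
    exported_program,              # already carries dynamic shapes from export()
    constant_methods=metadata,     # optional dict of constant methods
    compile_config=edge_config,
)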
@@ -457,12 +461,9 @@ def to_edge_transform_and_lower(
         if partitioners is None:
             logging.info("No partitioner provided, skipping backend lowering...")

-        # Need to construct ExportedProgram with the new transformed graph module.
-        exported_module = self._export(self.pre_autograd_graph_module)
-
         edge_config = self._get_edge_config()
         self.edge_manager = to_edge_transform_and_lower(
-            exported_module,
+            self.exported_module,
             partitioner=partitioners,
             compile_config=edge_config,
             constant_methods=self.metadata,