From f355798e8f5eccfa1a0a3d01967687b979c59580 Mon Sep 17 00:00:00 2001 From: lucylq Date: Wed, 13 Aug 2025 21:18:17 -0700 Subject: [PATCH] Use unlifted export pass to tag delegated constants Pull Request resolved: https://github.com/pytorch/executorch/pull/13163 Use the unlifted pass to tag constants for delegates. Implications: - Tagging must happen on the unlifted ep.module(), before going into to_edge_transform_and_lower/to_edge. Why? - The unlifted graph contains constants in getattr nodes, which is a convenient way to isolate constants. - After going into to_edge_transform_and_lower/to_edge, transforms happen on the graph_module, which is lifted. - The lifted graph requires the ep graph signature to differentiate constants via the `is_param` function. However, in to_edge.transform, we do not have access to the ep. Baking the ep as an argument via partial function doesn't work, as the ep from earlier may be outdated. This means we are comparing an older ep to a newer graph_module, which may not have corresponding graph signatures etc. ghstack-source-id: 302925208 Differential Revision: [D79736684](https://our.internmc.facebook.com/intern/diff/D79736684/) --- docs/source/using-executorch-export.md | 12 ++++--- examples/models/llama/export_llama_lib.py | 7 ++-- exir/passes/external_constants_pass.py | 43 ++++------------------- test/models/export_delegated_program.py | 13 +++---- 4 files changed, 23 insertions(+), 52 deletions(-) diff --git a/docs/source/using-executorch-export.md b/docs/source/using-executorch-export.md index da9cadf3ec2..51347e3a3dc 100644 --- a/docs/source/using-executorch-export.md +++ b/docs/source/using-executorch-export.md @@ -129,14 +129,16 @@ To generate a `model.pte`, `model.ptd` pair with the weights inside `model.ptd`, ```python from executorch.exir.passes.external_constants_pass import ( - delegate_external_constants_pass, + delegate_external_constants_pass_unlifted, ) -partial_function = partial( - delegate_external_constants_pass, - ep=exported_program, +# Tag the unlifted ep.module(). +tagged_module = exported_program.module() +delegate_external_constants_pass_unlifted( + module=tagged_module, gen_tag_fn=lambda x: "model", # This is the filename the weights will be saved to. In this case, weights will be saved as "model.ptd" ) - +# Re-export to get the EP. +exported_program = export(tagged_module, inputs, dynamic_shapes=dynamic_shapes) executorch_program = to_edge_transform_and_lower( exported_program, transform_passes = [partial_function], diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index ca940adb687..3a1801f063c 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -1079,7 +1079,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 if llm_config.backend.xnnpack.enabled: if llm_config.export.foundation_weights_file is not None: - gen_tag_fn: Callable[[torch.fx.Node], str] = lambda x: ( + gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: ( llm_config.export.foundation_weights_file if "lora" not in x.name else None @@ -1089,8 +1089,11 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 delegate_external_constants_pass_unlifted, ) + assert ( + builder_exported.pre_autograd_graph_module is not None + ), "pre_autograd_graph_module shouldn't be None here" delegate_external_constants_pass_unlifted( - gm=builder_exported.pre_autograd_graph_module, + module=builder_exported.pre_autograd_graph_module, gen_tag_fn=gen_tag_fn, ) diff --git a/exir/passes/external_constants_pass.py b/exir/passes/external_constants_pass.py index 414e131d6f5..1038af2ac7f 100644 --- a/exir/passes/external_constants_pass.py +++ b/exir/passes/external_constants_pass.py @@ -88,53 +88,22 @@ def external_mutable_weights_pass( return PassResult(gm, mutated) -def delegate_external_constants_pass( - gm: GraphModule, - ep: ExportedProgram, - gen_tag_fn: Optional[Callable[[torch.fx.Node], str]] = None, -) -> PassResult: - """ - Tag external constants before to_backend. - - Note: this pass must be run after run_decompositions(), as tags on - constants are removed then. - - Args: - gm: GraphModule to tag. - ep: ExportedProgram, to distinguish if a node is a constant. - gen_tag_fn: node -> str callable indicating the tag for the node. - Returns: - PassResult: The resulting gm, and if it was mutated or not. - """ - mutated = False - for module in gm.modules(): - if not isinstance(module, torch.fx.GraphModule): - continue - for node in module.graph.nodes: - if node.op == "placeholder" and is_param_node(ep, node): - if gen_tag_fn is not None: - node.meta.setdefault("custom", {}) - node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node) - mutated = True - return PassResult(gm, mutated) - - # Note: this pass must be run on an unlifted graph, e.g. ep.module(), # and not on a lifted graph, e.g. ep.graph_module. # This is using 'get_attr' to tag constants, which only appears in # unlifted graphs. def delegate_external_constants_pass_unlifted( - gm: GraphModule, - gen_tag_fn: Optional[Callable[[torch.fx.Node], str]] = None, + module: torch.nn.Module, + gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None, ) -> PassResult: mutated = False - for module in gm.modules(): - if not isinstance(module, torch.fx.GraphModule): + for m in module.modules(): + if not isinstance(m, torch.fx.GraphModule): continue - for node in module.graph.nodes: + for node in m.graph.nodes: if node.op == "get_attr": if gen_tag_fn is not None: node.meta.setdefault("custom", {}) node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node) mutated = True - return PassResult(gm, mutated) + return PassResult(module, mutated) diff --git a/test/models/export_delegated_program.py b/test/models/export_delegated_program.py index cbfdfaedab3..8f7c388d7ad 100644 --- a/test/models/export_delegated_program.py +++ b/test/models/export_delegated_program.py @@ -11,7 +11,6 @@ import os import sys -from functools import partial from typing import Dict, final, Optional, Sequence, Type import executorch.exir as exir @@ -28,7 +27,7 @@ ExecutorBackend, ) from executorch.exir.passes.external_constants_pass import ( - delegate_external_constants_pass, + delegate_external_constants_pass_unlifted, ) from executorch.exir.program import ExecutorchProgramManager from torch import nn @@ -173,17 +172,15 @@ def forward(self, *args, **kwargs): XnnpackPartitioner, ) - transform_passes = [] if external_constants: - partial_function = partial( - delegate_external_constants_pass, - ep=exported_program, + tagged_module = exported_program.module() + delegate_external_constants_pass_unlifted( + module=tagged_module, gen_tag_fn=lambda x: module_class.__name__, ) - transform_passes.append(partial_function) + exported_program = export(tagged_module, args=inputs, strict=True) executorch_program = to_edge_transform_and_lower( exported_program, - transform_passes=transform_passes, compile_config=edge_config, partitioner=[XnnpackPartitioner()], ).to_executorch(config=et_config)