diff --git a/docs/source/using-executorch-export.md b/docs/source/using-executorch-export.md
index da9cadf3ec2..51347e3a3dc 100644
--- a/docs/source/using-executorch-export.md
+++ b/docs/source/using-executorch-export.md
@@ -129,14 +129,16 @@
 To generate a `model.pte`, `model.ptd` pair with the weights inside `model.ptd`,
 ```python
 from executorch.exir.passes.external_constants_pass import (
-    delegate_external_constants_pass,
+    delegate_external_constants_pass_unlifted,
 )
-partial_function = partial(
-    delegate_external_constants_pass,
-    ep=exported_program,
+# Tag the unlifted ep.module().
+tagged_module = exported_program.module()
+delegate_external_constants_pass_unlifted(
+    module=tagged_module,
     gen_tag_fn=lambda x: "model", # This is the filename the weights will be saved to. In this case, weights will be saved as "model.ptd"
 )
-
+# Re-export to get the EP.
+exported_program = export(tagged_module, inputs, dynamic_shapes=dynamic_shapes)
 executorch_program = to_edge_transform_and_lower(
     exported_program,
     transform_passes = [partial_function],
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index ca940adb687..3a1801f063c 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -1079,7 +1079,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
 
     if llm_config.backend.xnnpack.enabled:
         if llm_config.export.foundation_weights_file is not None:
-            gen_tag_fn: Callable[[torch.fx.Node], str] = lambda x: (
+            gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
                 llm_config.export.foundation_weights_file
                 if "lora" not in x.name
                 else None
@@ -1089,8 +1089,11 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
                 delegate_external_constants_pass_unlifted,
             )
 
+            assert (
+                builder_exported.pre_autograd_graph_module is not None
+            ), "pre_autograd_graph_module shouldn't be None here"
             delegate_external_constants_pass_unlifted(
-                gm=builder_exported.pre_autograd_graph_module,
+                module=builder_exported.pre_autograd_graph_module,
                 gen_tag_fn=gen_tag_fn,
             )
diff --git a/exir/passes/external_constants_pass.py b/exir/passes/external_constants_pass.py
index 414e131d6f5..1038af2ac7f 100644
--- a/exir/passes/external_constants_pass.py
+++ b/exir/passes/external_constants_pass.py
@@ -88,53 +88,22 @@ def external_mutable_weights_pass(
     return PassResult(gm, mutated)
 
 
-def delegate_external_constants_pass(
-    gm: GraphModule,
-    ep: ExportedProgram,
-    gen_tag_fn: Optional[Callable[[torch.fx.Node], str]] = None,
-) -> PassResult:
-    """
-    Tag external constants before to_backend.
-
-    Note: this pass must be run after run_decompositions(), as tags on
-    constants are removed then.
-
-    Args:
-        gm: GraphModule to tag.
-        ep: ExportedProgram, to distinguish if a node is a constant.
-        gen_tag_fn: node -> str callable indicating the tag for the node.
-    Returns:
-        PassResult: The resulting gm, and if it was mutated or not.
-    """
-    mutated = False
-    for module in gm.modules():
-        if not isinstance(module, torch.fx.GraphModule):
-            continue
-        for node in module.graph.nodes:
-            if node.op == "placeholder" and is_param_node(ep, node):
-                if gen_tag_fn is not None:
-                    node.meta.setdefault("custom", {})
-                    node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
-                    mutated = True
-    return PassResult(gm, mutated)
-
-
 # Note: this pass must be run on an unlifted graph, e.g. ep.module(),
 # and not on a lifted graph, e.g. ep.graph_module.
 # This is using 'get_attr' to tag constants, which only appears in
 # unlifted graphs.
 def delegate_external_constants_pass_unlifted(
-    gm: GraphModule,
-    gen_tag_fn: Optional[Callable[[torch.fx.Node], str]] = None,
+    module: torch.nn.Module,
+    gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None,
 ) -> PassResult:
     mutated = False
-    for module in gm.modules():
-        if not isinstance(module, torch.fx.GraphModule):
+    for m in module.modules():
+        if not isinstance(m, torch.fx.GraphModule):
             continue
-        for node in module.graph.nodes:
+        for node in m.graph.nodes:
             if node.op == "get_attr":
                 if gen_tag_fn is not None:
                     node.meta.setdefault("custom", {})
                     node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
                     mutated = True
-    return PassResult(gm, mutated)
+    return PassResult(module, mutated)
diff --git a/test/models/export_delegated_program.py b/test/models/export_delegated_program.py
index cbfdfaedab3..8f7c388d7ad 100644
--- a/test/models/export_delegated_program.py
+++ b/test/models/export_delegated_program.py
@@ -11,7 +11,6 @@
 import os
 import sys
 
-from functools import partial
 from typing import Dict, final, Optional, Sequence, Type
 
 import executorch.exir as exir
@@ -28,7 +27,7 @@
     ExecutorBackend,
 )
 from executorch.exir.passes.external_constants_pass import (
-    delegate_external_constants_pass,
+    delegate_external_constants_pass_unlifted,
 )
 from executorch.exir.program import ExecutorchProgramManager
 from torch import nn
@@ -173,17 +172,15 @@ def forward(self, *args, **kwargs):
             XnnpackPartitioner,
         )
 
-        transform_passes = []
         if external_constants:
-            partial_function = partial(
-                delegate_external_constants_pass,
-                ep=exported_program,
+            tagged_module = exported_program.module()
+            delegate_external_constants_pass_unlifted(
+                module=tagged_module,
                 gen_tag_fn=lambda x: module_class.__name__,
             )
-            transform_passes.append(partial_function)
+            exported_program = export(tagged_module, args=inputs, strict=True)
         executorch_program = to_edge_transform_and_lower(
             exported_program,
-            transform_passes=transform_passes,
             compile_config=edge_config,
             partitioner=[XnnpackPartitioner()],
        ).to_executorch(config=et_config)
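
Note for reviewers: below is a minimal, self-contained sketch of the new unlifted tagging flow that this diff documents, under stated assumptions. The toy `Linear` module, `example_inputs`, and the `write_tensor_data_to_file(".")` call used to emit the `.ptd` file are illustrative assumptions and are not code from this diff; only the tag-then-re-export sequence and the `delegate_external_constants_pass_unlifted` signature come from the changes above.

```python
# Sketch of the unlifted external-constants flow (assumed example, not repo code).
import torch
from torch.export import export

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower
from executorch.exir.passes.external_constants_pass import (
    delegate_external_constants_pass_unlifted,
)


class Linear(torch.nn.Module):  # illustrative model, not from this PR
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


example_inputs = (torch.randn(1, 8),)
exported_program = export(Linear(), example_inputs, strict=True)

# Tag constants on the unlifted module; gen_tag_fn returns the .ptd file name
# (returning None would leave a constant untagged, i.e. embedded in the .pte).
tagged_module = exported_program.module()
delegate_external_constants_pass_unlifted(
    module=tagged_module,
    gen_tag_fn=lambda x: "model",  # weights grouped under the "model" tag -> model.ptd
)

# Re-export the tagged module to get an ExportedProgram again, then lower.
exported_program = export(tagged_module, example_inputs, strict=True)
executorch_program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[XnnpackPartitioner()],
).to_executorch()

# Save the program; the tagged weights are emitted as separate named data.
with open("model.pte", "wb") as f:
    f.write(executorch_program.buffer)
# Assumed helper for writing the external data file(s), e.g. model.ptd.
executorch_program.write_tensor_data_to_file(".")
```

Compared with the removed lifted pass, no `transform_passes` hook is needed any more: tagging happens on `ep.module()` before re-export, so the lowering call only takes the partitioner and configs.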