Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions docs/source/using-executorch-export.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,16 @@ To generate a `model.pte`, `model.ptd` pair with the weights inside `model.ptd`,

```python
from executorch.exir.passes.external_constants_pass import (
delegate_external_constants_pass,
delegate_external_constants_pass_unlifted,
)
partial_function = partial(
delegate_external_constants_pass,
ep=exported_program,
# Tag the unlifted ep.module().
tagged_module = exported_program.module()
delegate_external_constants_pass_unlifted(
module=tagged_module,
gen_tag_fn=lambda x: "model", # This is the filename the weights will be saved to. In this case, weights will be saved as "model.ptd"
)

# Re-export to get the EP.
exported_program = export(tagged_module, inputs, dynamic_shapes=dynamic_shapes)
executorch_program = to_edge_transform_and_lower(
exported_program,
transform_passes = [partial_function],
Expand Down
7 changes: 5 additions & 2 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,7 +1079,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901

if llm_config.backend.xnnpack.enabled:
if llm_config.export.foundation_weights_file is not None:
gen_tag_fn: Callable[[torch.fx.Node], str] = lambda x: (
gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
llm_config.export.foundation_weights_file
if "lora" not in x.name
else None
Expand All @@ -1089,8 +1089,11 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
delegate_external_constants_pass_unlifted,
)

assert (
builder_exported.pre_autograd_graph_module is not None
), "pre_autograd_graph_module shouldn't be None here"
delegate_external_constants_pass_unlifted(
gm=builder_exported.pre_autograd_graph_module,
module=builder_exported.pre_autograd_graph_module,
gen_tag_fn=gen_tag_fn,
)

Expand Down
43 changes: 6 additions & 37 deletions exir/passes/external_constants_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,53 +88,22 @@ def external_mutable_weights_pass(
return PassResult(gm, mutated)


def delegate_external_constants_pass(
gm: GraphModule,
ep: ExportedProgram,
gen_tag_fn: Optional[Callable[[torch.fx.Node], str]] = None,
) -> PassResult:
"""
Tag external constants before to_backend.

Note: this pass must be run after run_decompositions(), as tags on
constants are removed then.

Args:
gm: GraphModule to tag.
ep: ExportedProgram, to distinguish if a node is a constant.
gen_tag_fn: node -> str callable indicating the tag for the node.
Returns:
PassResult: The resulting gm, and if it was mutated or not.
"""
mutated = False
for module in gm.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
if node.op == "placeholder" and is_param_node(ep, node):
if gen_tag_fn is not None:
node.meta.setdefault("custom", {})
node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
mutated = True
return PassResult(gm, mutated)


# Note: this pass must be run on an unlifted graph, e.g. ep.module(),
# and not on a lifted graph, e.g. ep.graph_module.
# This is using 'get_attr' to tag constants, which only appears in
# unlifted graphs.
def delegate_external_constants_pass_unlifted(
gm: GraphModule,
gen_tag_fn: Optional[Callable[[torch.fx.Node], str]] = None,
module: torch.nn.Module,
gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None,
) -> PassResult:
mutated = False
for module in gm.modules():
if not isinstance(module, torch.fx.GraphModule):
for m in module.modules():
if not isinstance(m, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
for node in m.graph.nodes:
if node.op == "get_attr":
if gen_tag_fn is not None:
node.meta.setdefault("custom", {})
node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
mutated = True
return PassResult(gm, mutated)
return PassResult(module, mutated)
13 changes: 5 additions & 8 deletions test/models/export_delegated_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import os
import sys

from functools import partial
from typing import Dict, final, Optional, Sequence, Type

import executorch.exir as exir
Expand All @@ -28,7 +27,7 @@
ExecutorBackend,
)
from executorch.exir.passes.external_constants_pass import (
delegate_external_constants_pass,
delegate_external_constants_pass_unlifted,
)
from executorch.exir.program import ExecutorchProgramManager
from torch import nn
Expand Down Expand Up @@ -173,17 +172,15 @@ def forward(self, *args, **kwargs):
XnnpackPartitioner,
)

transform_passes = []
if external_constants:
partial_function = partial(
delegate_external_constants_pass,
ep=exported_program,
tagged_module = exported_program.module()
delegate_external_constants_pass_unlifted(
module=tagged_module,
gen_tag_fn=lambda x: module_class.__name__,
)
transform_passes.append(partial_function)
exported_program = export(tagged_module, args=inputs, strict=True)
executorch_program = to_edge_transform_and_lower(
exported_program,
transform_passes=transform_passes,
compile_config=edge_config,
partitioner=[XnnpackPartitioner()],
).to_executorch(config=et_config)
Expand Down
Loading