Skip to content
Merged
12 changes: 7 additions & 5 deletions docs/source/using-executorch-export.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,16 @@ To generate a `model.pte`, `model.ptd` pair with the weights inside `model.ptd`,

```python
from executorch.exir.passes.external_constants_pass import (
delegate_external_constants_pass,
delegate_external_constants_pass_unlifted,
)
partial_function = partial(
delegate_external_constants_pass,
ep=exported_program,
# Tag the unlifted ep.module().
tagged_module = exported_program.module()
delegate_external_constants_pass_unlifted(
tagged_module,
gen_tag_fn=lambda x: "model", # This is the filename the weights will be saved to. In this case, weights will be saved as "model.ptd"
)

# Re-export to get the EP.
exported_program = export(tagged_module, inputs, dynamic_shapes=dynamic_shapes)
executorch_program = to_edge_transform_and_lower(
exported_program,
transform_passes = [partial_function],
Expand Down
31 changes: 0 additions & 31 deletions exir/passes/external_constants_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,37 +88,6 @@ def external_mutable_weights_pass(
return PassResult(gm, mutated)


def delegate_external_constants_pass(
gm: GraphModule,
ep: ExportedProgram,
gen_tag_fn: Optional[Callable[[torch.fx.Node], str]] = None,
) -> PassResult:
"""
Tag external constants before to_backend.
Note: this pass must be run after run_decompositions(), as tags on
constants are removed then.
Args:
gm: GraphModule to tag.
ep: ExportedProgram, to distinguish if a node is a constant.
gen_tag_fn: node -> str callable indicating the tag for the node.
Returns:
PassResult: The resulting gm, and if it was mutated or not.
"""
mutated = False
for module in gm.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
if node.op == "placeholder" and is_param_node(ep, node):
if gen_tag_fn is not None:
node.meta.setdefault("custom", {})
node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
mutated = True
return PassResult(gm, mutated)


# Note: this pass must be run on an unlifted graph, e.g. ep.module(),
# and not on a lifted graph, e.g. ep.graph_module.
# This is using 'get_attr' to tag constants, which only appears in
Expand Down
13 changes: 5 additions & 8 deletions test/models/export_delegated_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
ExecutorBackend,
)
from executorch.exir.passes.external_constants_pass import (
delegate_external_constants_pass,
delegate_external_constants_pass_unlifted,
)
from executorch.exir.program import ExecutorchProgramManager
from torch import nn
Expand Down Expand Up @@ -172,18 +172,15 @@ def forward(self, *args, **kwargs):
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
XnnpackPartitioner,
)

transform_passes = []
if external_constants:
partial_function = partial(
delegate_external_constants_pass,
ep=exported_program,
tagged_module = exported_program.module()
delegate_external_constants_pass_unlifted(
gm=tagged_module,
gen_tag_fn=lambda x: module_class.__name__,
)
transform_passes.append(partial_function)
exported_program = export(tagged_module, args=inputs, strict=True)
executorch_program = to_edge_transform_and_lower(
exported_program,
transform_passes=transform_passes,
compile_config=edge_config,
partitioner=[XnnpackPartitioner()],
).to_executorch(config=et_config)
Expand Down
Loading