Skip to content

Commit 60624a5

Browse files
committed
Use unlifted export pass to tag delegated constants
Use the unlifted pass to tag constants for delegates. Implications: - Tagging must happen on the unlifted ep.module(), before going into to_edge_transform_and_lower/to_edge. Why? - The unlifted graph contains constants in getattr nodes, which is a convenient way to isolate constants. - After going into to_edge_transform_and_lower/to_edge, transforms happen on the graph_module, which is lifted. - The lifted graph requires the ep graph signature to differentiate constants via the `is_param` function. However, in to_edge.transform, we do not have access to the ep. Baking the ep as an argument via partial function doesn't work, as the ep from earlier may be outdated. This means we are comparing an older ep to a newer graph_module, which may not have corresponding graph signatures etc. Differential Revision: [D79736684](https://our.internmc.facebook.com/intern/diff/D79736684/) [ghstack-poisoned]
1 parent b64b1af commit 60624a5

File tree

3 files changed

+12
-44
lines changed

3 files changed

+12
-44
lines changed

docs/source/using-executorch-export.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,16 @@ To generate a `model.pte`, `model.ptd` pair with the weights inside `model.ptd`,
129129

130130
```python
131131
from executorch.exir.passes.external_constants_pass import (
132-
delegate_external_constants_pass,
132+
delegate_external_constants_pass_unlifted,
133133
)
134-
partial_function = partial(
135-
delegate_external_constants_pass,
136-
ep=exported_program,
134+
# Tag the unlifted ep.module().
135+
tagged_module = exported_program.module()
136+
delegate_external_constants_pass_unlifted(
137+
tagged_module,
137138
gen_tag_fn=lambda x: "model", # This is the filename the weights will be saved to. In this case, weights will be saved as "model.ptd"
138139
)
139-
140+
# Re-export to get the EP.
141+
exported_program = export(tagged_module, inputs, dynamic_shapes=dynamic_shapes)
140142
executorch_program = to_edge_transform_and_lower(
141143
exported_program,
142144
transform_passes = [partial_function],

exir/passes/external_constants_pass.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -88,37 +88,6 @@ def external_mutable_weights_pass(
8888
return PassResult(gm, mutated)
8989

9090

91-
def delegate_external_constants_pass(
92-
gm: GraphModule,
93-
ep: ExportedProgram,
94-
gen_tag_fn: Optional[Callable[[torch.fx.Node], str]] = None,
95-
) -> PassResult:
96-
"""
97-
Tag external constants before to_backend.
98-
99-
Note: this pass must be run after run_decompositions(), as tags on
100-
constants are removed then.
101-
102-
Args:
103-
gm: GraphModule to tag.
104-
ep: ExportedProgram, to distinguish if a node is a constant.
105-
gen_tag_fn: node -> str callable indicating the tag for the node.
106-
Returns:
107-
PassResult: The resulting gm, and if it was mutated or not.
108-
"""
109-
mutated = False
110-
for module in gm.modules():
111-
if not isinstance(module, torch.fx.GraphModule):
112-
continue
113-
for node in module.graph.nodes:
114-
if node.op == "placeholder" and is_param_node(ep, node):
115-
if gen_tag_fn is not None:
116-
node.meta.setdefault("custom", {})
117-
node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
118-
mutated = True
119-
return PassResult(gm, mutated)
120-
121-
12291
# Note: this pass must be run on an unlifted graph, e.g. ep.module(),
12392
# and not on a lifted graph, e.g. ep.graph_module.
12493
# This is using 'get_attr' to tag constants, which only appears in

test/models/export_delegated_program.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
ExecutorBackend,
2929
)
3030
from executorch.exir.passes.external_constants_pass import (
31-
delegate_external_constants_pass,
31+
delegate_external_constants_pass_unlifted,
3232
)
3333
from executorch.exir.program import ExecutorchProgramManager
3434
from torch import nn
@@ -172,18 +172,15 @@ def forward(self, *args, **kwargs):
172172
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
173173
XnnpackPartitioner,
174174
)
175-
176-
transform_passes = []
177175
if external_constants:
178-
partial_function = partial(
179-
delegate_external_constants_pass,
180-
ep=exported_program,
176+
tagged_module = exported_program.module()
177+
delegate_external_constants_pass_unlifted(
178+
gm=tagged_module,
181179
gen_tag_fn=lambda x: module_class.__name__,
182180
)
183-
transform_passes.append(partial_function)
181+
exported_program = export(tagged_module, args=inputs, strict=True)
184182
executorch_program = to_edge_transform_and_lower(
185183
exported_program,
186-
transform_passes=transform_passes,
187184
compile_config=edge_config,
188185
partitioner=[XnnpackPartitioner()],
189186
).to_executorch(config=et_config)

0 commit comments

Comments
 (0)