Skip to content

Commit 9832f31

Browse files
committed
Update on "Use unlifted export pass to tag delegated constants"
Use the unlifted pass to tag constants for delegates. Implications: - Tagging must happen on the unlifted ep.module(), before going into to_edge_transform_and_lower/to_edge. Why? - The unlifted graph contains constants in getattr nodes, which is a convenient way to isolate constants. - After going into to_edge_transform_and_lower/to_edge, transforms happen on the graph_module, which is lifted. - The lifted graph requires the ep graph signature to differentiate constants via the `is_param` function. However, in to_edge.transform, we do not have access to the ep. Baking the ep as an argument via partial function doesn't work, as the ep from earlier may be outdated. This means we are comparing an older ep to a newer graph_module, which may not have corresponding graph signatures etc. Differential Revision: [D79736684](https://our.internmc.facebook.com/intern/diff/D79736684/) [ghstack-poisoned]
2 parents d21ae81 + 880f674 commit 9832f31

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1088,7 +1088,10 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
10881088
from executorch.exir.passes.external_constants_pass import (
10891089
delegate_external_constants_pass_unlifted,
10901090
)
1091-
assert builder_exported.pre_autograd_graph_module is not None, "pre_autograd_graph_module shouldn't be None here"
1091+
1092+
assert (
1093+
builder_exported.pre_autograd_graph_module is not None
1094+
), "pre_autograd_graph_module shouldn't be None here"
10921095
delegate_external_constants_pass_unlifted(
10931096
module=builder_exported.pre_autograd_graph_module,
10941097
gen_tag_fn=gen_tag_fn,

test/models/export_delegated_program.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def forward(self, *args, **kwargs):
171171
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
172172
XnnpackPartitioner,
173173
)
174+
174175
if external_constants:
175176
tagged_module = exported_program.module()
176177
delegate_external_constants_pass_unlifted(

0 commit comments

Comments
 (0)