Skip to content

Commit 168d750

Browse files
committed
move undelegated constants
1 parent 49bc664 commit 168d750

File tree

3 files changed

+14
-2
lines changed

3 files changed

+14
-2
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,6 +1080,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
10801080

10811081
from executorch.exir.passes.external_constants_pass import (
10821082
delegate_external_constants_pass_unlifted,
1083+
external_constants_pass,
10831084
)
10841085

10851086
assert (
@@ -1090,6 +1091,11 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
10901091
gen_tag_fn=gen_tag_fn,
10911092
)
10921093

1094+
# Also add a pass for 'to_executorch' to tag weights that aren't delegated.
1095+
additional_passes.append(
1096+
partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
1097+
)
1098+
10931099
builder = _to_edge_and_lower_llama_xnnpack(
10941100
builder_exported,
10951101
modelname,

exir/passes/external_constants_pass.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
2727

2828
def external_constants_pass(
2929
gm: GraphModule,
30+
gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None,
3031
) -> PassResult:
3132
"""
3233
Move all non-lifted constants to external file.
@@ -42,7 +43,10 @@ def external_constants_pass(
4243
if (node.op == "placeholder") and ("_lifted_tensor" not in node.name):
4344
spec = node.meta.get("spec")
4445
if isinstance(spec, TensorSpec) and spec.const:
45-
node.meta["constant_tag"] = "_default_external_constant"
46+
if gen_tag_fn is not None:
47+
node.meta["constant_tag"] = gen_tag_fn(node)
48+
else:
49+
node.meta["constant_tag"] = "_default_external_constant"
4650
mutated = True
4751
return PassResult(gm, mutated)
4852

exir/program/_program.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -808,11 +808,13 @@ def edge_to_executorch_passes(
808808
Get the pre memory planning passes based on the method name, if the pass is not in the dict, use the default pass.
809809
"""
810810
passes: List[PassType] = [
811-
*config.passes,
812811
SpecPropPass(),
813812
# ExecuTorch backend ops are unable to handle unbacked symints. So after
814813
# this pass, passes cannot be Interpreter-based, because it will fail if
815814
# there exists an unbacked symint operation.
815+
*config.passes,
816+
# config.passes may contain external_constants_pass. This pass has to
817+
# run after SpecPropPass, which populates tensor names.
816818
EdgeToBackendOpsPass(),
817819
RemoveGraphAssertsPass(),
818820
] + pre_memory_planning_passes(config, name)

0 commit comments

Comments
 (0)