Skip to content

Commit d69aa0f

Browse files
committed
move undelegated constants
1 parent fe84495 commit d69aa0f

File tree

3 files changed

+12
-2
lines changed

3 files changed

+12
-2
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,6 +1086,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
10861086

10871087
from executorch.exir.passes.external_constants_pass import (
10881088
delegate_external_constants_pass_unlifted,
1089+
external_constants_pass,
10891090
)
10901091

10911092
assert (
@@ -1096,6 +1097,11 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
10961097
gen_tag_fn=gen_tag_fn,
10971098
)
10981099

1100+
# Also add a pass for 'to_executorch' to tag weights that aren't delegated.
1101+
additional_passes.append(
1102+
partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
1103+
)
1104+
10991105
builder = _to_edge_and_lower_llama_xnnpack(
11001106
builder_exported,
11011107
modelname,

exir/passes/external_constants_pass.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
2727

2828
def external_constants_pass(
2929
gm: GraphModule,
30+
gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None,
3031
) -> PassResult:
3132
"""
3233
Move all non-lifted constants to external file.
@@ -42,7 +43,10 @@ def external_constants_pass(
4243
if (node.op == "placeholder") and ("_lifted_tensor" not in node.name):
4344
spec = node.meta.get("spec")
4445
if isinstance(spec, TensorSpec) and spec.const:
45-
node.meta["constant_tag"] = "_default_external_constant"
46+
if gen_tag_fn is not None:
47+
node.meta["constant_tag"] = gen_tag_fn(node)
48+
else:
49+
node.meta["constant_tag"] = "_default_external_constant"
4650
mutated = True
4751
return PassResult(gm, mutated)
4852

exir/program/_program.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -808,13 +808,13 @@ def edge_to_executorch_passes(
808808
Get the pre memory planning passes based on the method name, if the pass is not in the dict, use the default pass.
809809
"""
810810
passes: List[PassType] = [
811-
*config.passes,
812811
SpecPropPass(),
813812
# ExecuTorch backend ops are unable to handle unbacked symints. So after
814813
# this pass, passes cannot be Interpreter-based, because it will fail if
815814
# there exists an unbacked symint operation.
816815
EdgeToBackendOpsPass(),
817816
RemoveGraphAssertsPass(),
817+
*config.passes,
818818
] + pre_memory_planning_passes(config, name)
819819

820820
return passes

0 commit comments

Comments
 (0)