diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index bced97beef0..574a2327472 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -1086,6 +1086,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
 
             from executorch.exir.passes.external_constants_pass import (
                 delegate_external_constants_pass_unlifted,
+                external_constants_pass,
             )
 
             assert (
@@ -1096,6 +1097,11 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
                 gen_tag_fn=gen_tag_fn,
             )
 
+            # Also add a pass for 'to_executorch' to tag weights that aren't delegated.
+            additional_passes.append(
+                partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
+            )
+
         builder = _to_edge_and_lower_llama_xnnpack(
             builder_exported,
             modelname,
diff --git a/exir/passes/external_constants_pass.py b/exir/passes/external_constants_pass.py
index 1038af2ac7f..ba1d85adb4f 100644
--- a/exir/passes/external_constants_pass.py
+++ b/exir/passes/external_constants_pass.py
@@ -27,6 +27,7 @@ def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
 
 def external_constants_pass(
     gm: GraphModule,
+    gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None,
 ) -> PassResult:
     """
     Move all non-lifted constants to external file.
@@ -42,7 +43,10 @@ def external_constants_pass(
         if (node.op == "placeholder") and ("_lifted_tensor" not in node.name):
             spec = node.meta.get("spec")
             if isinstance(spec, TensorSpec) and spec.const:
-                node.meta["constant_tag"] = "_default_external_constant"
+                if gen_tag_fn is not None:
+                    node.meta["constant_tag"] = gen_tag_fn(node)
+                else:
+                    node.meta["constant_tag"] = "_default_external_constant"
                 mutated = True
     return PassResult(gm, mutated)
diff --git a/exir/program/_program.py b/exir/program/_program.py
index 8df41bed200..ea03eaa51b6 100644
--- a/exir/program/_program.py
+++ b/exir/program/_program.py
@@ -808,13 +808,13 @@ def edge_to_executorch_passes(
     Get the pre memory planning passes based on the method name, if the pass is not in the dict, use the default pass.
    """
     passes: List[PassType] = [
-        *config.passes,
         SpecPropPass(),
         # ExecuTorch backend ops are unable to handle unbacked symints. So after
         # this pass, passes cannot be Interpreter-based, because it will fail if
         # there exists an unbacked symint operation.
         EdgeToBackendOpsPass(),
         RemoveGraphAssertsPass(),
+        *config.passes,
     ] + pre_memory_planning_passes(config, name)
     return passes
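
For illustration only, below is a minimal sketch of how the gen_tag_fn hook added above could be used. The file name "model_weights.ptd", the "lora" substring check, and the pass-list wiring mentioned in the comments are assumptions made for this example, not part of the diff.

    # Minimal sketch (assumed names): tag non-delegated constants into one external file.
    from functools import partial
    from typing import Optional

    import torch

    from executorch.exir.passes.external_constants_pass import external_constants_pass


    def gen_tag_fn(node: torch.fx.Node) -> Optional[str]:
        # A returned string becomes the node's "constant_tag"; None (allowed by the
        # Optional[str] return type) presumably leaves the weight without an external tag.
        return "model_weights.ptd" if "lora" not in node.name else None


    # Bound this way, the pass can be appended to a pass list (e.g. the export
    # script's additional_passes, or the passes field of the backend config) so
    # it runs during to_executorch and tags weights the backend did not delegate.
    tagging_pass = partial(external_constants_pass, gen_tag_fn=gen_tag_fn)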