diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 6d8ba54c010..b1cfb058882 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -874,6 +874,7 @@ def _to_edge_and_lower_llama_xnnpack(
     xnnpack_extended_ops: bool = False,
     generate_etrecord: bool = False,
     verbose: bool = False,
+    gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None,
 ) -> LLMEdgeManager:  # noqa: C901
     partitioners = []
 
@@ -896,9 +897,27 @@ def _to_edge_and_lower_llama_xnnpack(
     if generate_etrecord:
         builder_exported.generate_etrecord = True
 
-    builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
-        partitioners
-    )
+    builder = builder_exported.pt2e_quantize(quantizers)
+    if gen_tag_fn is not None:
+        from executorch.exir.passes.external_constants_pass import (
+            delegate_external_constants_pass_unlifted,
+            external_constants_pass,
+        )
+
+        assert (
+            builder_exported.pre_autograd_graph_module is not None
+        ), "pre_autograd_graph_module shouldn't be None here"
+        delegate_external_constants_pass_unlifted(
+            module=builder_exported.pre_autograd_graph_module,
+            gen_tag_fn=gen_tag_fn,
+        )
+
+        # Also add a pass for 'to_executorch' to tag weights that aren't delegated.
+        additional_passes.append(
+            partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
+        )
+
+    builder = builder.to_edge_transform_and_lower(partitioners)
     if verbose:
         print_delegation_info(builder.edge_manager.exported_program().graph_module)
 
@@ -1136,6 +1155,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         llm_config.backend.xnnpack.enabled = True
 
     if llm_config.backend.xnnpack.enabled:
+        gen_tag_fn = None
         if (
             llm_config.export.foundation_weights_file is not None
             or llm_config.export.lora_weights_file is not None
@@ -1145,24 +1165,6 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
                 if "lora" not in x.name
                 else llm_config.export.lora_weights_file
             )
-            from executorch.exir.passes.external_constants_pass import (
-                delegate_external_constants_pass_unlifted,
-                external_constants_pass,
-            )
-
-            assert (
-                builder_exported.pre_autograd_graph_module is not None
-            ), "pre_autograd_graph_module shouldn't be None here"
-            delegate_external_constants_pass_unlifted(
-                module=builder_exported.pre_autograd_graph_module,
-                gen_tag_fn=gen_tag_fn,
-            )
-
-            # Also add a pass for 'to_executorch' to tag weights that aren't delegated.
-            additional_passes.append(
-                partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
-            )
-
         builder = _to_edge_and_lower_llama_xnnpack(
             builder_exported,
             modelname,
@@ -1173,6 +1175,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
             xnnpack_extended_ops=llm_config.backend.xnnpack.extended_ops,
             generate_etrecord=llm_config.debug.generate_etrecord,
             verbose=llm_config.debug.verbose,
+            gen_tag_fn=gen_tag_fn,
         )
     elif llm_config.backend.openvino.enabled:
         builder = _to_edge_and_lower_llama_openvino(
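
Note (illustrative sketch, not part of the diff): any caller-supplied gen_tag_fn
must match the new Optional[Callable[[torch.fx.Node], Optional[str]]] parameter,
returning a tag (here, an external weights file name) for a node, or None. A
minimal example with hypothetical file names, mirroring the lambda that
_export_llama builds from llm_config.export.foundation_weights_file and
llm_config.export.lora_weights_file:

    from typing import Optional

    import torch

    def gen_tag_fn(x: torch.fx.Node) -> Optional[str]:
        # Route LoRA weights to one external file and all other weights to
        # the foundation-weights file. File names below are placeholders.
        return "lora_weights.ptd" if "lora" in x.name else "foundation_weights.ptd"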