@@ -874,6 +874,7 @@ def _to_edge_and_lower_llama_xnnpack(
     xnnpack_extended_ops: bool = False,
     generate_etrecord: bool = False,
     verbose: bool = False,
+    gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None,
 ) -> LLMEdgeManager:  # noqa: C901
     partitioners = []

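This hunk adds an optional per-node tagging hook to the XNNPACK lowering path. Below is a minimal sketch of a callback matching the new `Callable[[torch.fx.Node], Optional[str]]` signature, assuming (per the passes wired up in the next hunk) that the returned string names the external weights file a constant should be routed to and that `None` means "leave untagged". The function name and file names are hypothetical; the real callback built in `_export_llama` returns `llm_config.export.lora_weights_file` / `foundation_weights_file` instead.

```python
from typing import Optional

import torch.fx


def example_gen_tag_fn(node: torch.fx.Node) -> Optional[str]:
    # Hypothetical tag callback: route LoRA weights and base weights to
    # separate external data files, mirroring the lambda in _export_llama
    # (which keys off "lora" in the node name).
    if "lora" in node.name:
        return "lora_weights.ptd"  # hypothetical file name
    return "foundation_weights.ptd"  # hypothetical file name
```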
@@ -896,9 +897,27 @@ def _to_edge_and_lower_llama_xnnpack(
     if generate_etrecord:
         builder_exported.generate_etrecord = True

-    builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
-        partitioners
-    )
+    builder = builder_exported.pt2e_quantize(quantizers)
+    if gen_tag_fn is not None:
+        from executorch.exir.passes.external_constants_pass import (
+            delegate_external_constants_pass_unlifted,
+            external_constants_pass,
+        )
+
+        assert (
+            builder_exported.pre_autograd_graph_module is not None
+        ), "pre_autograd_graph_module shouldn't be None here"
+        delegate_external_constants_pass_unlifted(
+            module=builder_exported.pre_autograd_graph_module,
+            gen_tag_fn=gen_tag_fn,
+        )
+
+        # Also add a pass for 'to_executorch' to tag weights that aren't delegated.
+        additional_passes.append(
+            partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
+        )
+
+    builder = builder.to_edge_transform_and_lower(partitioners)
     if verbose:
         print_delegation_info(builder.edge_manager.exported_program().graph_module)

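With this hunk, the sequence inside `_to_edge_and_lower_llama_xnnpack` when `gen_tag_fn` is supplied is: `pt2e_quantize`, then `delegate_external_constants_pass_unlifted` on the pre-autograd module to tag weights that will be delegated, then queueing `external_constants_pass` in `additional_passes` so non-delegated weights get tagged at `to_executorch` time, and only then `to_edge_transform_and_lower`. Before this change, the two tagging steps ran in `_export_llama` before this helper was called at all, i.e. before quantization.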
@@ -1136,6 +1155,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         llm_config.backend.xnnpack.enabled = True

     if llm_config.backend.xnnpack.enabled:
+        gen_tag_fn = None
         if (
             llm_config.export.foundation_weights_file is not None
             or llm_config.export.lora_weights_file is not None
@@ -1145,24 +1165,6 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
                 if "lora" not in x.name
                 else llm_config.export.lora_weights_file
             )
-            from executorch.exir.passes.external_constants_pass import (
-                delegate_external_constants_pass_unlifted,
-                external_constants_pass,
-            )
-
-            assert (
-                builder_exported.pre_autograd_graph_module is not None
-            ), "pre_autograd_graph_module shouldn't be None here"
-            delegate_external_constants_pass_unlifted(
-                module=builder_exported.pre_autograd_graph_module,
-                gen_tag_fn=gen_tag_fn,
-            )
-
-            # Also add a pass for 'to_executorch' to tag weights that aren't delegated.
-            additional_passes.append(
-                partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
-            )
-
         builder = _to_edge_and_lower_llama_xnnpack(
             builder_exported,
             modelname,
@@ -1173,6 +1175,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
             xnnpack_extended_ops=llm_config.backend.xnnpack.extended_ops,
             generate_etrecord=llm_config.debug.generate_etrecord,
             verbose=llm_config.debug.verbose,
+            gen_tag_fn=gen_tag_fn,
         )
     elif llm_config.backend.openvino.enabled:
         builder = _to_edge_and_lower_llama_openvino(
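Net effect of the last three hunks: `_export_llama` now only builds `gen_tag_fn` (defaulting it to `None` for the XNNPACK path) and forwards it via the new keyword argument, while the import, assert, and pass registration move into `_to_edge_and_lower_llama_xnnpack`. The hunks themselves only show the relocation; presumably the point is that tagging now runs after `pt2e_quantize` rather than before the helper is invoked.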