@@ -854,6 +854,7 @@ def _to_edge_and_lower_llama_xnnpack(
     xnnpack_extended_ops: bool = False,
     generate_etrecord: bool = False,
     verbose: bool = False,
+    gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None,
 ) -> LLMEdgeManager:  # noqa: C901
     partitioners = []
 
@@ -876,9 +877,26 @@ def _to_edge_and_lower_llama_xnnpack(
     if generate_etrecord:
         builder_exported.generate_etrecord = True
 
-    builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
-        partitioners
-    )
+    builder = builder_exported.pt2e_quantize(quantizers)
+    if gen_tag_fn is not None:
+        from executorch.exir.passes.external_constants_pass import (
+            delegate_external_constants_pass_unlifted,
+            external_constants_pass,
+        )
+        assert (
+            builder_exported.pre_autograd_graph_module is not None
+        ), "pre_autograd_graph_module shouldn't be None here"
+        delegate_external_constants_pass_unlifted(
+            module=builder_exported.pre_autograd_graph_module,
+            gen_tag_fn=gen_tag_fn,
+        )
+
+        # Also add a pass for 'to_executorch' to tag weights that aren't delegated.
+        additional_passes.append(
+            partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
+        )
+
+    builder = builder.to_edge_transform_and_lower(partitioners)
     if verbose:
         print_delegation_info(builder.edge_manager.exported_program().graph_module)
 
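For reference, a minimal sketch of how the gen_tag_fn hook added above splits the tagging work, assuming the executorch package is available. The import path and the partial() usage are taken from the diff itself; the file name "foundation.ptd" and the variable undelegated_tagger are illustrative, not from the source.

    from functools import partial
    from typing import Optional

    import torch
    from executorch.exir.passes.external_constants_pass import external_constants_pass

    def gen_tag_fn(node: torch.fx.Node) -> Optional[str]:
        # Constants mapped to a file name are routed into that external data
        # file; returning None keeps the constant inside the main program.
        # ("foundation.ptd" is a made-up example file name.)
        return "foundation.ptd" if "lora" not in node.name else None

    # Delegated constants are tagged on the unlifted module before lowering
    # (delegate_external_constants_pass_unlifted above); constants XNNPACK
    # does not consume are tagged later, at to_executorch, by this pass:
    undelegated_tagger = partial(external_constants_pass, gen_tag_fn=gen_tag_fn)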
@@ -1088,31 +1106,14 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         llm_config.backend.xnnpack.enabled = True
 
     if llm_config.backend.xnnpack.enabled:
+        gen_tag_fn = None
         if llm_config.export.foundation_weights_file is not None:
             gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
                 llm_config.export.foundation_weights_file
                 if "lora" not in x.name
                 else None
             )
 
-        from executorch.exir.passes.external_constants_pass import (
-            delegate_external_constants_pass_unlifted,
-            external_constants_pass,
-        )
-
-        assert (
-            builder_exported.pre_autograd_graph_module is not None
-        ), "pre_autograd_graph_module shouldn't be None here"
-        delegate_external_constants_pass_unlifted(
-            module=builder_exported.pre_autograd_graph_module,
-            gen_tag_fn=gen_tag_fn,
-        )
-
-        # Also add a pass for 'to_executorch' to tag weights that aren't delegated.
-        additional_passes.append(
-            partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
-        )
-
         builder = _to_edge_and_lower_llama_xnnpack(
             builder_exported,
             modelname,
@@ -1123,6 +1124,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
             xnnpack_extended_ops=llm_config.backend.xnnpack.extended_ops,
             generate_etrecord=llm_config.debug.generate_etrecord,
             verbose=llm_config.debug.verbose,
+            gen_tag_fn=gen_tag_fn,
         )
     else:
         builder = _to_edge_and_lower_llama(
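At the call site, gen_tag_fn now defaults to None and is only built when llm_config.export.foundation_weights_file is set, so weight tagging stays opt-in. A runnable micro-check of the predicate the lambda uses (the node names and the file name below are hypothetical):

    from typing import Optional

    def tag(name: str, foundation_file: str = "foundation.ptd") -> Optional[str]:
        # Mirrors the lambda above: anything with "lora" in its name is
        # left untagged and stays in the main program.
        return foundation_file if "lora" not in name else None

    assert tag("layers_0_attention_wq_weight") == "foundation.ptd"  # tagged
    assert tag("lora_a_weight") is None  # LoRA weight, not tagged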