Commit 8f84f24

Tag constants after quantization
1 parent d4129b7 commit 8f84f24

1 file changed (+23, -21 lines)

examples/models/llama/export_llama_lib.py

Lines changed: 23 additions & 21 deletions
@@ -854,6 +854,7 @@ def _to_edge_and_lower_llama_xnnpack(
     xnnpack_extended_ops: bool = False,
     generate_etrecord: bool = False,
     verbose: bool = False,
+    gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None,
 ) -> LLMEdgeManager:  # noqa: C901
     partitioners = []
 
@@ -876,9 +877,26 @@ def _to_edge_and_lower_llama_xnnpack(
     if generate_etrecord:
         builder_exported.generate_etrecord = True
 
-    builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
-        partitioners
-    )
+    builder = builder_exported.pt2e_quantize(quantizers)
+    if gen_tag_fn is not None:
+        from executorch.exir.passes.external_constants_pass import (
+            delegate_external_constants_pass_unlifted,
+            external_constants_pass,
+        )
+        assert (
+            builder_exported.pre_autograd_graph_module is not None
+        ), "pre_autograd_graph_module shouldn't be None here"
+        delegate_external_constants_pass_unlifted(
+            module=builder_exported.pre_autograd_graph_module,
+            gen_tag_fn=gen_tag_fn,
+        )
+
+        # Also add a pass for 'to_executorch' to tag weights that aren't delegated.
+        additional_passes.append(
+            partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
+        )
+
+    builder = builder.to_edge_transform_and_lower(partitioners)
     if verbose:
         print_delegation_info(builder.edge_manager.exported_program().graph_module)
 
@@ -1088,31 +1106,14 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         llm_config.backend.xnnpack.enabled = True
 
     if llm_config.backend.xnnpack.enabled:
+        gen_tag_fn = None
         if llm_config.export.foundation_weights_file is not None:
             gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
                 llm_config.export.foundation_weights_file
                 if "lora" not in x.name
                 else None
             )
 
-            from executorch.exir.passes.external_constants_pass import (
-                delegate_external_constants_pass_unlifted,
-                external_constants_pass,
-            )
-
-            assert (
-                builder_exported.pre_autograd_graph_module is not None
-            ), "pre_autograd_graph_module shouldn't be None here"
-            delegate_external_constants_pass_unlifted(
-                module=builder_exported.pre_autograd_graph_module,
-                gen_tag_fn=gen_tag_fn,
-            )
-
-            # Also add a pass for 'to_executorch' to tag weights that aren't delegated.
-            additional_passes.append(
-                partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
-            )
-
         builder = _to_edge_and_lower_llama_xnnpack(
             builder_exported,
             modelname,
@@ -1123,6 +1124,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
             xnnpack_extended_ops=llm_config.backend.xnnpack.extended_ops,
             generate_etrecord=llm_config.debug.generate_etrecord,
             verbose=llm_config.debug.verbose,
+            gen_tag_fn=gen_tag_fn,
         )
     else:
         builder = _to_edge_and_lower_llama(
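For reference, a minimal sketch of the tagging callback that the new gen_tag_fn parameter expects, mirroring the lambda in _export_llama above. The "foundation.ptd" value is a hypothetical stand-in for llm_config.export.foundation_weights_file, and the traced Linear module exists only to exercise the callback; in export_llama_lib the callback is instead handed to _to_edge_and_lower_llama_xnnpack, which now applies it after pt2e_quantize and before to_edge_transform_and_lower.

    from typing import Callable, Optional

    import torch

    # Non-LoRA constants are tagged with the foundation-weights file name;
    # LoRA weights return None and stay untagged.
    gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda node: (
        "foundation.ptd" if "lora" not in node.name else None
    )

    # Exercise the callback on a tiny traced graph to show what it returns
    # for each node (placeholders, get_attr weights/bias, the linear call).
    gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 4))
    for node in gm.graph.nodes:
        print(node.name, "->", gen_tag_fn(node))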
