Commit afec819

Tag constants after quantization
1 parent d4129b7 commit afec819


examples/models/llama/export_llama_lib.py

Lines changed: 24 additions & 21 deletions
@@ -854,6 +854,7 @@ def _to_edge_and_lower_llama_xnnpack(
     xnnpack_extended_ops: bool = False,
     generate_etrecord: bool = False,
     verbose: bool = False,
+    gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None,
 ) -> LLMEdgeManager:  # noqa: C901
     partitioners = []
 
@@ -876,9 +877,27 @@ def _to_edge_and_lower_llama_xnnpack(
     if generate_etrecord:
         builder_exported.generate_etrecord = True
 
-    builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
-        partitioners
-    )
+    builder = builder_exported.pt2e_quantize(quantizers)
+    if gen_tag_fn is not None:
+        from executorch.exir.passes.external_constants_pass import (
+            delegate_external_constants_pass_unlifted,
+            external_constants_pass,
+        )
+
+        assert (
+            builder_exported.pre_autograd_graph_module is not None
+        ), "pre_autograd_graph_module shouldn't be None here"
+        delegate_external_constants_pass_unlifted(
+            module=builder_exported.pre_autograd_graph_module,
+            gen_tag_fn=gen_tag_fn,
+        )
+
+        # Also add a pass for 'to_executorch' to tag weights that aren't delegated.
+        additional_passes.append(
+            partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
+        )
+
+    builder = builder.to_edge_transform_and_lower(partitioners)
     if verbose:
         print_delegation_info(builder.edge_manager.exported_program().graph_module)
 
@@ -1088,31 +1107,14 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         llm_config.backend.xnnpack.enabled = True
 
     if llm_config.backend.xnnpack.enabled:
+        gen_tag_fn = None
         if llm_config.export.foundation_weights_file is not None:
             gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
                 llm_config.export.foundation_weights_file
                 if "lora" not in x.name
                 else None
             )
 
-            from executorch.exir.passes.external_constants_pass import (
-                delegate_external_constants_pass_unlifted,
-                external_constants_pass,
-            )
-
-            assert (
-                builder_exported.pre_autograd_graph_module is not None
-            ), "pre_autograd_graph_module shouldn't be None here"
-            delegate_external_constants_pass_unlifted(
-                module=builder_exported.pre_autograd_graph_module,
-                gen_tag_fn=gen_tag_fn,
-            )
-
-            # Also add a pass for 'to_executorch' to tag weights that aren't delegated.
-            additional_passes.append(
-                partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
-            )
-
         builder = _to_edge_and_lower_llama_xnnpack(
             builder_exported,
             modelname,
@@ -1123,6 +1125,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
             xnnpack_extended_ops=llm_config.backend.xnnpack.extended_ops,
             generate_etrecord=llm_config.debug.generate_etrecord,
             verbose=llm_config.debug.verbose,
+            gen_tag_fn=gen_tag_fn,
         )
     else:
         builder = _to_edge_and_lower_llama(
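
In short, the commit threads an optional gen_tag_fn into _to_edge_and_lower_llama_xnnpack so that the external-constant tagging passes run after pt2e quantization (presumably so the quantized constants, rather than the original floating-point weights, are the ones tagged), while _export_llama now only builds the tag function. As a rough, non-authoritative sketch of how such a tag function could be constructed (make_gen_tag_fn is a hypothetical helper, not part of this commit; it mirrors the lambda in _export_llama above):

    from typing import Callable, Optional

    import torch


    def make_gen_tag_fn(
        foundation_weights_file: str,
    ) -> Callable[[torch.fx.Node], Optional[str]]:
        """Hypothetical helper mirroring the lambda in _export_llama."""

        def gen_tag_fn(node: torch.fx.Node) -> Optional[str]:
            # Tag every non-LoRA constant with the external foundation-weights
            # file name; LoRA constants return None and stay in the main program.
            return foundation_weights_file if "lora" not in node.name else None

        return gen_tag_fn

With gen_tag_fn passed through, delegate_external_constants_pass_unlifted tags constants on the already-quantized pre-autograd graph before delegation, and external_constants_pass is appended as an additional to_executorch pass so weights that are not delegated get tagged as well, as the comment in the diff notes.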
