Skip to content

Commit b782bb5

Browse files
Update
[ghstack-poisoned]
2 parents 16d863c + 6e0c9f6 commit b782bb5

File tree

2 files changed

+26
-23
lines changed

2 files changed

+26
-23
lines changed

backends/cadence/aot/compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
ExecutorchProgramManager,
3939
)
4040
from executorch.exir.passes import ToOutVarPass
41-
from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
41+
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
4242
from executorch.exir.program._program import to_edge
4343

4444
from torch.export.exported_program import ExportedProgram
@@ -460,7 +460,7 @@ def _lower_ep_to_cadence_gen_etrecord(
460460
emit_stacktrace=False,
461461
to_out_var_pass=ToOutVarPass(),
462462
extract_delegate_segments=False,
463-
sym_shape_eval_pass=HintBasedSymShapeEvalPass(),
463+
sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
464464
),
465465
)
466466

examples/models/llama/export_llama_lib.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,7 @@ def _to_edge_and_lower_llama_xnnpack(
874874
xnnpack_extended_ops: bool = False,
875875
generate_etrecord: bool = False,
876876
verbose: bool = False,
877+
gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None,
877878
) -> LLMEdgeManager: # noqa: C901
878879
partitioners = []
879880

@@ -896,9 +897,27 @@ def _to_edge_and_lower_llama_xnnpack(
896897
if generate_etrecord:
897898
builder_exported.generate_etrecord = True
898899

899-
builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
900-
partitioners
901-
)
900+
builder = builder_exported.pt2e_quantize(quantizers)
901+
if gen_tag_fn is not None:
902+
from executorch.exir.passes.external_constants_pass import (
903+
delegate_external_constants_pass_unlifted,
904+
external_constants_pass,
905+
)
906+
907+
assert (
908+
builder_exported.pre_autograd_graph_module is not None
909+
), "pre_autograd_graph_module shouldn't be None here"
910+
delegate_external_constants_pass_unlifted(
911+
module=builder_exported.pre_autograd_graph_module,
912+
gen_tag_fn=gen_tag_fn,
913+
)
914+
915+
# Also add a pass for 'to_executorch' to tag weights that aren't delegated.
916+
additional_passes.append(
917+
partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
918+
)
919+
920+
builder = builder.to_edge_transform_and_lower(partitioners)
902921
if verbose:
903922
print_delegation_info(builder.edge_manager.exported_program().graph_module)
904923

@@ -1136,6 +1155,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
11361155
llm_config.backend.xnnpack.enabled = True
11371156

11381157
if llm_config.backend.xnnpack.enabled:
1158+
gen_tag_fn = None
11391159
if (
11401160
llm_config.export.foundation_weights_file is not None
11411161
or llm_config.export.lora_weights_file is not None
@@ -1145,24 +1165,6 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
11451165
if "lora" not in x.name
11461166
else llm_config.export.lora_weights_file
11471167
)
1148-
from executorch.exir.passes.external_constants_pass import (
1149-
delegate_external_constants_pass_unlifted,
1150-
external_constants_pass,
1151-
)
1152-
1153-
assert (
1154-
builder_exported.pre_autograd_graph_module is not None
1155-
), "pre_autograd_graph_module shouldn't be None here"
1156-
delegate_external_constants_pass_unlifted(
1157-
module=builder_exported.pre_autograd_graph_module,
1158-
gen_tag_fn=gen_tag_fn,
1159-
)
1160-
1161-
# Also add a pass for 'to_executorch' to tag weights that aren't delegated.
1162-
additional_passes.append(
1163-
partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
1164-
)
1165-
11661168
builder = _to_edge_and_lower_llama_xnnpack(
11671169
builder_exported,
11681170
modelname,
@@ -1173,6 +1175,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
11731175
xnnpack_extended_ops=llm_config.backend.xnnpack.extended_ops,
11741176
generate_etrecord=llm_config.debug.generate_etrecord,
11751177
verbose=llm_config.debug.verbose,
1178+
gen_tag_fn=gen_tag_fn,
11761179
)
11771180
elif llm_config.backend.openvino.enabled:
11781181
builder = _to_edge_and_lower_llama_openvino(

0 commit comments

Comments
 (0)