Skip to content

Commit 2a6376e

Browse files
kimishpatelfacebook-github-bot
authored andcommitted
add ability to compare intermedidate outputs (#13482)
Summary: Add a path to leverage intermediate output comparator to eager_backbone_compare script Reviewed By: YIWENX14, JacobSzwejbka Differential Revision: D80118857
1 parent bbc281f commit 2a6376e

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -857,16 +857,16 @@ def _to_edge_and_lower_llama_xnnpack(
857857

858858
# TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
859859
if generate_etrecord:
860-
raise NotImplementedError(
861-
"export_llama does not support XNNPack and generating ETRecord at the moment."
862-
)
860+
builder_exported.generate_etrecord = True
863861

864862
builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
865863
partitioners
866864
)
867865
if verbose:
868866
print_delegation_info(builder.edge_manager.exported_program().graph_module)
869867

868+
# we need builder.export_program
869+
870870
return builder.to_executorch(passes=additional_passes)
871871

872872

extension/llm/export/builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def __init__(
9696
metadata: Optional[dict] = None,
9797
dynamic_shapes: Optional[Any] = None,
9898
save_exported_program: bool = False,
99+
generate_etrecord: bool = False,
99100
):
100101
# Store necessary constructor arguments.
101102
self.model = model
@@ -116,6 +117,7 @@ def __init__(
116117
self.metadata = metadata
117118
self.dynamic_shapes = dynamic_shapes
118119
self.save_exported_program = save_exported_program
120+
self.generate_etrecord = generate_etrecord
119121

120122
# Note: treat this as the source of truth for the result of
121123
# torch.export'ing a model. If the overall ExportedProgram is needed,
@@ -481,6 +483,7 @@ def to_edge_transform_and_lower(
481483
partitioner=partitioners,
482484
compile_config=edge_config,
483485
constant_methods=self.metadata,
486+
generate_etrecord=self.generate_etrecord,
484487
)
485488
if self.verbose:
486489
logging.info(f"Exported graph:\n{self.edge_manager.exported_program()}")

0 commit comments

Comments
 (0)