Skip to content

Commit 1606420

Browse files
kimishpatelfacebook-github-bot
authored andcommitted
add ability to compare intermedidate outputs
Summary: Add a path to leverage intermediate output comparator to eager_backbone_compare script Reviewed By: YIWENX14 Differential Revision: D80118857
1 parent 56f24c6 commit 1606420

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -863,16 +863,16 @@ def _to_edge_and_lower_llama_xnnpack(
863863

864864
# TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
865865
if generate_etrecord:
866-
raise NotImplementedError(
867-
"export_llama does not support XNNPack and generating ETRecord at the moment."
868-
)
866+
builder_exported.generate_etrecord=True
869867

870868
builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
871869
partitioners
872870
)
873871
if verbose:
874872
print_delegation_info(builder.edge_manager.exported_program().graph_module)
875873

874+
# we need builder.export_program
875+
876876
return builder.to_executorch(passes=additional_passes)
877877

878878

extension/llm/export/builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def __init__(
9696
metadata: Optional[dict] = None,
9797
dynamic_shapes: Optional[Any] = None,
9898
save_exported_program: bool = False,
99+
generate_etrecord: bool = False,
99100
):
100101
# Store necessary constructor arguments.
101102
self.model = model
@@ -116,6 +117,7 @@ def __init__(
116117
self.metadata = metadata
117118
self.dynamic_shapes = dynamic_shapes
118119
self.save_exported_program = save_exported_program
120+
self.generate_etrecord = generate_etrecord
119121

120122
# Note: treat this as the source of truth for the result of
121123
# torch.export'ing a model. If the overall ExportedProgram is needed,
@@ -481,6 +483,7 @@ def to_edge_transform_and_lower(
481483
partitioner=partitioners,
482484
compile_config=edge_config,
483485
constant_methods=self.metadata,
486+
generate_etrecord=self.generate_etrecord,
484487
)
485488
if self.verbose:
486489
logging.info(f"Exported graph:\n{self.edge_manager.exported_program()}")

0 commit comments

Comments
 (0)