Skip to content

Commit 79a66cc

Browse files
author
Zonglin Peng
committed
fix minor issues and only gen binary
1 parent b8538ac commit 79a66cc

File tree

3 files changed

+20
-9
lines changed

3 files changed

+20
-9
lines changed

backends/cadence/aot/export_example.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def export_model(
6060
model: nn.Module,
6161
example_inputs: Tuple[Any, ...],
6262
file_name: str = "CadenceDemoModel",
63+
run_and_compare: bool = True,
6364
):
6465
# create work directory for outputs and model binary
6566
working_dir = tempfile.mkdtemp(dir="/tmp")
@@ -112,9 +113,10 @@ def export_model(
112113
)
113114

114115
# TODO: move to test infra
115-
runtime.run_and_compare(
116-
executorch_prog=exec_prog,
117-
inputs=example_inputs,
118-
ref_outputs=ref_outputs,
119-
working_dir=working_dir,
120-
)
116+
if run_and_compare:
117+
runtime.run_and_compare(
118+
executorch_prog=exec_prog,
119+
inputs=example_inputs,
120+
ref_outputs=ref_outputs,
121+
working_dir=working_dir,
122+
)

backends/cadence/aot/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ def print_ops_info(
162162

163163
# Print the final ops and their counts in a tabular format
164164
logging.info(
165-
tabulate(
165+
"\n"
166+
+ tabulate(
166167
sorted_ops_count,
167168
headers=[
168169
"Final Operators ", # one character longer than the longest op name

examples/cadence/operators/add_op.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
5757
Y = torch.randn(Yshape)
5858

5959
model.eval()
60-
export_model(model, (X, Y))
60+
export_model(
61+
model, (X, Y), file_name=self._testMethodName, run_and_compare=False
62+
)
6163

6264
@parameterized.expand(
6365
[
@@ -104,4 +106,10 @@ def forward(self, x: torch.Tensor, y: float):
104106
Y = 2.34
105107

106108
model.eval()
107-
export_model(model, (X, Y))
109+
export_model(
110+
model, (X, Y), file_name=self._testMethodName, run_and_compare=False
111+
)
112+
113+
114+
if __name__ == "__main__":
115+
unittest.main()

0 commit comments

Comments
 (0)