Skip to content

Commit f125d74

Browse files
authored
Only check unwanted ops
1 parent c5ed931 commit f125d74

File tree

1 file changed

+2
-33
lines changed

1 file changed

+2
-33
lines changed

examples/models/llama/tests/test_export_llama_lib.py

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,42 +11,15 @@
1111
from executorch.devtools.backend_debug import get_delegation_info
1212
from executorch.examples.models.llama.export_llama_lib import _export_llama, build_args_parser
1313

14-
15-
# Ops expected to be found in the default exported llama_transformer. Obtained through
16-
# print_delegation_info from the backend_debug module, which is displayed with
17-
# export_llama under --verbose.
18-
BASE_EXPECTED_OPS = {
19-
"sym_size": 1,
20-
"alloc": 288,
21-
"aten_embedding_default": 1,
22-
"aten_select_copy_int": 12,
23-
"_local_scalar_dense": 11,
24-
"add": 1,
25-
"aten_slice_copy_tensor": 23,
26-
"aten_mul_tensor": 83,
27-
"aten_mean_dim": 11,
28-
"aten_add_tensor": 31,
29-
"aten_rsqrt_default": 11,
30-
"aten_linear_default": 36,
31-
"aten_view_copy_default": 40,
32-
"aten_squeeze_copy_dims": 20,
33-
"aten_sub_tensor": 10,
34-
"aten_unsqueeze_copy_default": 20,
35-
"aten_cat_default": 10,
36-
"update_cache": 10,
37-
"llama_custom_sdpa_default": 5,
38-
"aten_sigmoid_default": 5,
39-
}
40-
4114
UNWANTED_OPS = [
4215
"aten_permute_copy_default",
16+
"aten_transpose_copy_default",
4317
]
4418

4519
class ExportLlamaLibTest(unittest.TestCase):
4620
def test_has_expected_ops_and_op_counts(self):
4721
"""
48-
Tests that the presence of expected ops and counts for each op are
49-
do not change.
22+
Checks the presence of unwanted expensive ops.
5023
5124
Serves as a proxy for a performance regression test, as performance
5225
is directly tied to which and how many of each ops are in the graph.
@@ -73,9 +46,5 @@ def test_has_expected_ops_and_op_counts(self):
7346
delegation_info = get_delegation_info(graph_module)
7447

7548
for op, op_info in delegation_info.delegation_by_operator.items():
76-
self.assertTrue(op in BASE_EXPECTED_OPS)
7749
self.assertTrue(op not in UNWANTED_OPS)
78-
self.assertEqual(op_info.non_delegated, BASE_EXPECTED_OPS[op])
79-
80-
self.assertEqual(len(delegation_info.delegation_by_operator), len(BASE_EXPECTED_OPS))
8150

0 commit comments

Comments
 (0)