Skip to content

Commit dd62a07

Browse files
committed
[Backend Tester] Report delegation statistics
ghstack-source-id: cc7e564 ghstack-comment-id: 3115647824 Pull-Request: pytorch#12846
1 parent 1809933 commit dd62a07

File tree

5 files changed

+139
-2
lines changed

5 files changed

+139
-2
lines changed

backends/test/suite/reporting.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
11
import csv
2+
23
from collections import Counter
34
from dataclasses import dataclass
45
from datetime import timedelta
56
from enum import IntEnum
67
from functools import reduce
7-
from typing import TextIO
8+
from typing import Any, TextIO
89

910
from executorch.backends.test.harness.error_statistics import ErrorStatistics
11+
from torch.export import ExportedProgram
12+
13+
14+
# Operators that are excluded from the counts returned by count_ops. These are used to
15+
# exclude operatations that are not logically relevant or delegatable to backends.
16+
OP_COUNT_IGNORED_OPS = {
17+
"executorch_call_delegate",
18+
"getitem",
19+
}
1020

1121

1222
class TestResult(IntEnum):
@@ -115,6 +125,12 @@ class TestCaseSummary:
115125
lower_time: timedelta | None = None
116126
""" The total runtime of the to_edge_transform_and_lower stage, or none, if the test did not run the quantize stage. """
117127

128+
delegated_op_counts: Counter | None = None
129+
""" The number of delegated occurances of each operator in the graph. """
130+
131+
undelegated_op_counts: Counter | None = None
132+
""" The number of undelegated occurances of each operator in the graph. """
133+
118134

119135
class TestSessionState:
120136
test_case_summaries: list[TestCaseSummary]
@@ -164,6 +180,40 @@ def from_session(cls, session: TestSessionState) -> "RunSummary":
164180
_active_session: TestSessionState | None = None
165181

166182

183+
def _get_target_name(target: Any) -> str:
184+
"""Retrieve a string representation of a node target."""
185+
if isinstance(target, str):
186+
return target
187+
elif hasattr(target, "name"):
188+
return target.name() # Op overloads have this
189+
elif hasattr(target, "__name__"):
190+
return target.__name__ # Some builtins have this
191+
else:
192+
return str(target)
193+
194+
195+
def _count_ops(program: ExportedProgram) -> Counter:
196+
op_names = (
197+
_get_target_name(n.target)
198+
for n in program.graph.nodes
199+
if n.op == "call_function"
200+
)
201+
202+
return Counter(op for op in op_names if op not in OP_COUNT_IGNORED_OPS)
203+
204+
205+
def count_ops(program: dict[str, ExportedProgram] | ExportedProgram) -> Counter:
206+
if isinstance(program, ExportedProgram):
207+
return _count_ops(program)
208+
else:
209+
# Sum op counts for all methods in the program.
210+
return reduce(
211+
lambda a, b: a + b,
212+
(_count_ops(p) for p in program.values()),
213+
Counter(),
214+
)
215+
216+
167217
def begin_test_session():
168218
global _active_session
169219

@@ -188,6 +238,24 @@ def complete_test_session() -> RunSummary:
188238
return summary
189239

190240

241+
def _sum_op_counts(counter: Counter | None) -> int | None:
242+
"""
243+
A utility function to count the total number of nodes in an op count dict.
244+
"""
245+
return sum(counter.values()) if counter is not None else None
246+
247+
248+
def _serialize_op_counts(counter: Counter | None) -> str:
249+
"""
250+
A utility function to serialize op counts to a string, for the purpose of including
251+
in the test report.
252+
"""
253+
if counter is not None:
254+
return str(dict(sorted(counter.items())))
255+
else:
256+
return ""
257+
258+
191259
def generate_csv_report(summary: RunSummary, output: TextIO):
192260
"""Write a run summary report to a file in CSV format."""
193261

@@ -228,6 +296,14 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
228296
f"Output {i} SQNR",
229297
]
230298
)
299+
field_names.extend(
300+
[
301+
"Delegated Nodes",
302+
"Undelegated Nodes",
303+
"Delegated Ops",
304+
"Undelegated Ops",
305+
]
306+
)
231307

232308
writer = csv.DictWriter(output, field_names)
233309
writer.writeheader()
@@ -256,4 +332,9 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
256332
row[f"Output {output_idx} Error L2"] = error_stats.error_l2_norm
257333
row[f"Output {output_idx} SQNR"] = error_stats.sqnr
258334

335+
row["Delegated Nodes"] = _sum_op_counts(record.delegated_op_counts)
336+
row["Undelegated Nodes"] = _sum_op_counts(record.undelegated_op_counts)
337+
row["Delegated Ops"] = _serialize_op_counts(record.delegated_op_counts)
338+
row["Undelegated Ops"] = _serialize_op_counts(record.undelegated_op_counts)
339+
259340
writer.writerow(row)

backends/test/suite/runner.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
from executorch.backends.test.suite.reporting import (
1717
begin_test_session,
1818
complete_test_session,
19+
count_ops,
1920
generate_csv_report,
2021
RunSummary,
2122
TestCaseSummary,
2223
TestResult,
2324
)
25+
from executorch.exir import EdgeProgramManager
2426

2527

2628
# A list of all runnable test suites and the corresponding python package.
@@ -106,6 +108,14 @@ def build_result(
106108
extra_stats["lower_time"] = timedelta(seconds=elapsed)
107109
return build_result(TestResult.LOWER_FAIL, e)
108110

111+
edge_manager: EdgeProgramManager = tester.get_artifact()
112+
edge_op_counts = count_ops(edge_manager.original_edge_programs)
113+
undelegated_op_counts = count_ops(edge_manager._edge_programs)
114+
delegated_op_counts = edge_op_counts - undelegated_op_counts
115+
116+
extra_stats["delegated_op_counts"] = delegated_op_counts
117+
extra_stats["undelegated_op_counts"] = undelegated_op_counts
118+
109119
is_delegated = any(
110120
n.target == torch._higher_order_ops.executorch_call_delegate
111121
for n in tester.stages[tester.cur].graph_module.graph.nodes

backends/test/suite/tests/test_reporting.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55

66
import torch
77

8+
from executorch.exir import to_edge
9+
810
from ..reporting import (
11+
count_ops,
912
generate_csv_report,
1013
RunSummary,
1114
TestCaseSummary,
@@ -23,6 +26,7 @@
2326
params=None,
2427
result=TestResult.SUCCESS,
2528
error=None,
29+
tensor_error_statistics=[],
2630
),
2731
TestCaseSummary(
2832
backend="backend2",
@@ -32,6 +36,7 @@
3236
params=None,
3337
result=TestResult.LOWER_FAIL,
3438
error=None,
39+
tensor_error_statistics=[],
3540
),
3641
TestCaseSummary(
3742
backend="backend1",
@@ -41,6 +46,7 @@
4146
params={"dtype": torch.float32},
4247
result=TestResult.SUCCESS_UNDELEGATED,
4348
error=None,
49+
tensor_error_statistics=[],
4450
),
4551
TestCaseSummary(
4652
backend="backend2",
@@ -50,6 +56,7 @@
5056
params={"use_dynamic_shapes": True},
5157
result=TestResult.EXPORT_FAIL,
5258
error=None,
59+
tensor_error_statistics=[],
5360
),
5461
]
5562

@@ -104,3 +111,32 @@ def test_csv_report_simple(self):
104111
self.assertEqual(records[3]["Result"], "Fail (Export)")
105112
self.assertEqual(records[3]["Dtype"], "")
106113
self.assertEqual(records[3]["Use_dynamic_shapes"], "True")
114+
115+
def test_count_ops(self):
116+
"""
117+
Verify that the count_ops function correctly counts operator occurances in the edge graph.
118+
"""
119+
120+
class Model1(torch.nn.Module):
121+
def forward(self, x, y):
122+
return x + y
123+
124+
class Model2(torch.nn.Module):
125+
def forward(self, x, y):
126+
return x + y * y
127+
128+
args = (torch.randn(2), torch.randn(2))
129+
ep1 = torch.export.export(Model1(), args)
130+
ep2 = torch.export.export(Model2(), args)
131+
132+
ep = to_edge({"forward1": ep1, "forward2": ep2})
133+
134+
op_counts = count_ops(ep._edge_programs)
135+
136+
self.assertEqual(
137+
op_counts,
138+
{
139+
"aten::add.Tensor": 2,
140+
"aten::mul.Tensor": 1,
141+
},
142+
)

exir/program/_program.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1179,6 +1179,7 @@ def _gen_edge_manager_for_partitioners(
11791179
config,
11801180
list(set().union(*ops_set_to_not_decompose_by_program.values())),
11811181
)
1182+
11821183
return edge_manager
11831184

11841185

@@ -1410,6 +1411,8 @@ class EdgeProgramManager:
14101411
Manages the second link in the lowering chain of ATen -> Edge -> ExecuTorch.
14111412
"""
14121413

1414+
original_edge_programs: dict[str, ExportedProgram] | None = None
1415+
14131416
def __init__(
14141417
self,
14151418
edge_programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
@@ -1558,12 +1561,17 @@ def to_backend(
15581561

15591562
new_edge_programs = to_backend(method_to_programs_and_partitioners)
15601563
config = EdgeCompileConfig(_check_ir_validity=False)
1561-
return EdgeProgramManager(
1564+
new_edge_manager = EdgeProgramManager(
15621565
new_edge_programs,
15631566
copy.deepcopy(self._config_methods),
15641567
config,
15651568
)
15661569

1570+
# Placeholder - not for land
1571+
new_edge_manager.original_edge_programs = copy.deepcopy(self._edge_programs)
1572+
1573+
return new_edge_manager
1574+
15671575
@et_logger("to_executorch")
15681576
def to_executorch(
15691577
self,

pytest.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ addopts =
4848
# is stable and signal to noise ratio is good (no irrelevant failures).
4949
# See https://github.com/pytorch/executorch/discussions/11140
5050
--ignore=backends/test
51+
backends/test/harness/tests
52+
backends/test/suite/tests
5153
# backends/xnnpack
5254
backends/xnnpack/test/ops
5355
--ignore=backends/xnnpack/test/ops/test_bmm.py

0 commit comments

Comments
 (0)