Skip to content

Commit 6eba97f

Browse files
committed
[Backend Tester] Report delegation statistics
ghstack-source-id: 0604da3 ghstack-comment-id: 3115647824 Pull-Request: #12846
1 parent 7981ed1 commit 6eba97f

File tree

7 files changed

+152
-9
lines changed

7 files changed

+152
-9
lines changed

backends/qualcomm/tests/tester.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ def __init__(
5252
default_partitioner_cls=QnnPartitioner,
5353
)
5454

55-
def run(self, artifact: ExportedProgram, inputs=None) -> None:
55+
def run(
56+
self, artifact: ExportedProgram, inputs=None, generate_etrecord: bool = False
57+
) -> None:
5658
ep = QnnPassManager().transform_for_export_pipeline(artifact)
5759
transform_passes = QnnPassManager().get_to_edge_transform_passes(ep)
5860

@@ -61,6 +63,7 @@ def run(self, artifact: ExportedProgram, inputs=None) -> None:
6163
transform_passes=transform_passes,
6264
partitioner=self.partitioners,
6365
compile_config=self.edge_compile_conf,
66+
generate_etrecord=generate_etrecord,
6467
)
6568

6669

backends/test/harness/stages/to_edge_transform_and_lower.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
to_edge_transform_and_lower,
88
)
99
from executorch.exir.backend.partitioner import Partitioner
10+
11+
from sympy.ntheory import generate
1012
from torch.export import ExportedProgram
1113

1214

@@ -24,11 +26,14 @@ def __init__(
2426
def stage_type(self) -> StageType:
2527
return StageType.TO_EDGE_TRANSFORM_AND_LOWER
2628

27-
def run(self, artifact: ExportedProgram, inputs=None) -> None:
29+
def run(
30+
self, artifact: ExportedProgram, inputs=None, generate_etrecord: bool = False
31+
) -> None:
2832
self.edge_dialect_program = to_edge_transform_and_lower(
2933
artifact,
3034
compile_config=self.edge_compile_conf,
3135
partitioner=self.partitioners,
36+
generate_etrecord=generate_etrecord,
3237
)
3338

3439
@property

backends/test/harness/tester.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,10 @@ def _post(self, stage):
183183
assert stage_type in self.stages
184184
self.stages[stage_type] = stage
185185

186-
def _run_stage(self, stage_instance, inputs=None):
186+
def _run_stage(self, stage_instance, inputs=None, *args, **kwargs):
187187
assert isinstance(stage_instance, Stage)
188188
prev_stage_artifact = self._pre(stage_instance)
189-
stage_instance.run(prev_stage_artifact, inputs=inputs)
189+
stage_instance.run(prev_stage_artifact, inputs=inputs, *args, **kwargs)
190190
self._post(stage_instance)
191191
return self
192192

@@ -213,11 +213,14 @@ def to_edge(self, to_edge_stage: Optional[ToEdge] = None):
213213
return res
214214

215215
def to_edge_transform_and_lower(
216-
self, to_edge_and_transform_stage: Optional[ToEdgeTransformAndLower] = None
216+
self,
217+
to_edge_and_transform_stage: Optional[ToEdgeTransformAndLower] = None,
218+
generate_etrecord: bool = False,
217219
):
218220
return self._run_stage(
219221
to_edge_and_transform_stage
220-
or self._get_default_stage(StageType.TO_EDGE_TRANSFORM_AND_LOWER)
222+
or self._get_default_stage(StageType.TO_EDGE_TRANSFORM_AND_LOWER),
223+
generate_etrecord=generate_etrecord,
221224
)
222225

223226
def run_passes(self, run_passes_stage: Optional[RunPasses] = None):

backends/test/suite/reporting.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
11
import csv
2+
23
from collections import Counter
34
from dataclasses import dataclass
45
from datetime import timedelta
56
from enum import IntEnum
67
from functools import reduce
7-
from typing import TextIO
8+
from typing import Any, TextIO
89

910
from executorch.backends.test.harness.error_statistics import ErrorStatistics
11+
from torch.export import ExportedProgram
12+
13+
14+
# Operators that are excluded from the counts returned by count_ops. These are used to
15+
# exclude operations that are not logically relevant or delegatable to backends.
16+
OP_COUNT_IGNORED_OPS = {
17+
"executorch_call_delegate",
18+
"getitem",
19+
}
1020

1121

1222
class TestResult(IntEnum):
@@ -115,6 +125,12 @@ class TestCaseSummary:
115125
lower_time: timedelta | None = None
116126
""" The total runtime of the to_edge_transform_and_lower stage, or None if the test did not run the lowering stage. """
117127

128+
delegated_op_counts: Counter | None = None
129+
""" The number of delegated occurrences of each operator in the graph. """
130+
131+
undelegated_op_counts: Counter | None = None
132+
""" The number of undelegated occurrences of each operator in the graph. """
133+
118134

119135
class TestSessionState:
120136
test_case_summaries: list[TestCaseSummary]
@@ -164,6 +180,40 @@ def from_session(cls, session: TestSessionState) -> "RunSummary":
164180
_active_session: TestSessionState | None = None
165181

166182

183+
def _get_target_name(target: Any) -> str:
184+
"""Retrieve a string representation of a node target."""
185+
if isinstance(target, str):
186+
return target
187+
elif hasattr(target, "name"):
188+
return target.name() # Op overloads have this
189+
elif hasattr(target, "__name__"):
190+
return target.__name__ # Some builtins have this
191+
else:
192+
return str(target)
193+
194+
195+
def _count_ops(program: ExportedProgram) -> Counter:
196+
op_names = (
197+
_get_target_name(n.target)
198+
for n in program.graph.nodes
199+
if n.op == "call_function"
200+
)
201+
202+
return Counter(op for op in op_names if op not in OP_COUNT_IGNORED_OPS)
203+
204+
205+
def count_ops(program: dict[str, ExportedProgram] | ExportedProgram) -> Counter:
206+
if isinstance(program, ExportedProgram):
207+
return _count_ops(program)
208+
else:
209+
# Sum op counts for all methods in the program.
210+
return reduce(
211+
lambda a, b: a + b,
212+
(_count_ops(p) for p in program.values()),
213+
Counter(),
214+
)
215+
216+
167217
def begin_test_session():
168218
global _active_session
169219

@@ -188,6 +238,24 @@ def complete_test_session() -> RunSummary:
188238
return summary
189239

190240

241+
def _sum_op_counts(counter: Counter | None) -> int | None:
242+
"""
243+
A utility function to count the total number of nodes in an op count dict.
244+
"""
245+
return sum(counter.values()) if counter is not None else None
246+
247+
248+
def _serialize_op_counts(counter: Counter | None) -> str:
249+
"""
250+
A utility function to serialize op counts to a string, for the purpose of including
251+
in the test report.
252+
"""
253+
if counter is not None:
254+
return str(dict(sorted(counter.items())))
255+
else:
256+
return ""
257+
258+
191259
def generate_csv_report(summary: RunSummary, output: TextIO):
192260
"""Write a run summary report to a file in CSV format."""
193261

@@ -228,6 +296,14 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
228296
f"Output {i} SQNR",
229297
]
230298
)
299+
field_names.extend(
300+
[
301+
"Delegated Nodes",
302+
"Undelegated Nodes",
303+
"Delegated Ops",
304+
"Undelegated Ops",
305+
]
306+
)
231307

232308
writer = csv.DictWriter(output, field_names)
233309
writer.writeheader()
@@ -256,4 +332,9 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
256332
row[f"Output {output_idx} Error L2"] = error_stats.error_l2_norm
257333
row[f"Output {output_idx} SQNR"] = error_stats.sqnr
258334

335+
row["Delegated Nodes"] = _sum_op_counts(record.delegated_op_counts)
336+
row["Undelegated Nodes"] = _sum_op_counts(record.undelegated_op_counts)
337+
row["Delegated Ops"] = _serialize_op_counts(record.delegated_op_counts)
338+
row["Undelegated Ops"] = _serialize_op_counts(record.undelegated_op_counts)
339+
259340
writer.writerow(row)

backends/test/suite/runner.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
from executorch.backends.test.suite.reporting import (
1717
begin_test_session,
1818
complete_test_session,
19+
count_ops,
1920
generate_csv_report,
2021
RunSummary,
2122
TestCaseSummary,
2223
TestResult,
2324
)
25+
from executorch.exir import EdgeProgramManager
2426

2527

2628
# A list of all runnable test suites and the corresponding python package.
@@ -98,14 +100,25 @@ def build_result(
98100

99101
lower_start_time = time.perf_counter()
100102
try:
101-
tester.to_edge_transform_and_lower()
103+
tester.to_edge_transform_and_lower(generate_etrecord=True)
102104
elapsed = time.perf_counter() - lower_start_time
103105
extra_stats["lower_time"] = timedelta(seconds=elapsed)
104106
except Exception as e:
105107
elapsed = time.perf_counter() - lower_start_time
106108
extra_stats["lower_time"] = timedelta(seconds=elapsed)
107109
return build_result(TestResult.LOWER_FAIL, e)
108110

111+
# Compute delegation statistics. Use the ETRecord to access the edge dialect graph between
112+
# to_edge and delegation. Note that ETRecord only stores the edge dialect graph for a single
113+
# method currently and assumes it is called "forward".
114+
edge_manager: EdgeProgramManager = tester.get_artifact()
115+
edge_op_counts = count_ops({"forward": edge_manager._etrecord.edge_dialect_program})
116+
undelegated_op_counts = count_ops(edge_manager._edge_programs)
117+
delegated_op_counts = edge_op_counts - undelegated_op_counts
118+
119+
extra_stats["delegated_op_counts"] = delegated_op_counts
120+
extra_stats["undelegated_op_counts"] = undelegated_op_counts
121+
109122
is_delegated = any(
110123
n.target == torch._higher_order_ops.executorch_call_delegate
111124
for n in tester.stages[tester.cur].graph_module.graph.nodes
@@ -127,7 +140,7 @@ def build_result(
127140
try:
128141
tester.run_method_and_compare_outputs(
129142
inputs=None if generate_random_test_inputs else inputs,
130-
statistics_callback=lambda stats: error_statistics.append(stats)
143+
statistics_callback=lambda stats: error_statistics.append(stats),
131144
)
132145
except AssertionError as e:
133146
return build_result(TestResult.OUTPUT_MISMATCH_FAIL, e)

backends/test/suite/tests/test_reporting.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55

66
import torch
77

8+
from executorch.exir import to_edge
9+
810
from ..reporting import (
11+
count_ops,
912
generate_csv_report,
1013
RunSummary,
1114
TestCaseSummary,
@@ -23,6 +26,7 @@
2326
params=None,
2427
result=TestResult.SUCCESS,
2528
error=None,
29+
tensor_error_statistics=[],
2630
),
2731
TestCaseSummary(
2832
backend="backend2",
@@ -32,6 +36,7 @@
3236
params=None,
3337
result=TestResult.LOWER_FAIL,
3438
error=None,
39+
tensor_error_statistics=[],
3540
),
3641
TestCaseSummary(
3742
backend="backend1",
@@ -41,6 +46,7 @@
4146
params={"dtype": torch.float32},
4247
result=TestResult.SUCCESS_UNDELEGATED,
4348
error=None,
49+
tensor_error_statistics=[],
4450
),
4551
TestCaseSummary(
4652
backend="backend2",
@@ -50,6 +56,7 @@
5056
params={"use_dynamic_shapes": True},
5157
result=TestResult.EXPORT_FAIL,
5258
error=None,
59+
tensor_error_statistics=[],
5360
),
5461
]
5562

@@ -104,3 +111,32 @@ def test_csv_report_simple(self):
104111
self.assertEqual(records[3]["Result"], "Fail (Export)")
105112
self.assertEqual(records[3]["Dtype"], "")
106113
self.assertEqual(records[3]["Use_dynamic_shapes"], "True")
114+
115+
def test_count_ops(self):
116+
"""
117+
Verify that the count_ops function correctly counts operator occurrences in the edge graph.
118+
"""
119+
120+
class Model1(torch.nn.Module):
121+
def forward(self, x, y):
122+
return x + y
123+
124+
class Model2(torch.nn.Module):
125+
def forward(self, x, y):
126+
return x + y * y
127+
128+
args = (torch.randn(2), torch.randn(2))
129+
ep1 = torch.export.export(Model1(), args)
130+
ep2 = torch.export.export(Model2(), args)
131+
132+
ep = to_edge({"forward1": ep1, "forward2": ep2})
133+
134+
op_counts = count_ops(ep._edge_programs)
135+
136+
self.assertEqual(
137+
op_counts,
138+
{
139+
"aten::add.Tensor": 2,
140+
"aten::mul.Tensor": 1,
141+
},
142+
)

pytest.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ addopts =
4848
# is stable and signal to noise ratio is good (no irrelevant failures).
4949
# See https://github.com/pytorch/executorch/discussions/11140
5050
--ignore=backends/test
51+
backends/test/harness/tests
52+
backends/test/suite/tests
5153
# backends/xnnpack
5254
backends/xnnpack/test/ops
5355
--ignore=backends/xnnpack/test/ops/test_bmm.py

0 commit comments

Comments
 (0)