Skip to content

Commit 106c8f2

Browse files
committed
[Backend Tester] Report quantization and lowering times
ghstack-source-id: 22aa30d ghstack-comment-id: 3115355420 Pull-Request: #12838
1 parent bd04bcd commit 106c8f2

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

backends/test/suite/reporting.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import csv
22
from collections import Counter
33
from dataclasses import dataclass
4+
from datetime import timedelta
45
from enum import IntEnum
56
from functools import reduce
67
from typing import TextIO
@@ -108,6 +109,12 @@ class TestCaseSummary:
108109
a single output tensor.
109110
"""
110111

112+
quantize_time: timedelta | None = None
113+
""" The total runtime of the quantization stage, or none, if the test did not run the quantize stage. """
114+
115+
lower_time: timedelta | None = None
116+
""" The total runtime of the to_edge_transform_and_lower stage, or none, if the test did not run the quantize stage. """
117+
111118

112119
class TestSessionState:
113120
test_case_summaries: list[TestCaseSummary]
@@ -190,6 +197,8 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
190197
"Backend",
191198
"Flow",
192199
"Result",
200+
"Quantize Time (s)",
201+
"Lowering Time (s)",
193202
]
194203

195204
# Tests can have custom parameters. We'll want to report them here, so we need
@@ -230,6 +239,12 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
230239
"Backend": record.backend,
231240
"Flow": record.flow,
232241
"Result": record.result.display_name(),
242+
"Quantize Time (s)": (
243+
record.quantize_time.total_seconds() if record.quantize_time else None
244+
),
245+
"Lowering Time (s)": (
246+
record.lower_time.total_seconds() if record.lower_time else None
247+
),
233248
}
234249
if record.params is not None:
235250
row.update({k.capitalize(): v for k, v in record.params.items()})

backends/test/suite/runner.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import argparse
22
import importlib
33
import re
4+
import time
45
import unittest
56

7+
from datetime import timedelta
68
from typing import Any
79

810
import torch
@@ -43,6 +45,7 @@ def run_test( # noqa: C901
4345
"""
4446

4547
error_statistics: list[ErrorStatistics] = []
48+
extra_stats = {}
4649

4750
# Helper method to construct the summary.
4851
def build_result(
@@ -57,6 +60,7 @@ def build_result(
5760
result=result,
5861
error=error,
5962
tensor_error_statistics=error_statistics,
63+
**extra_stats,
6064
)
6165

6266
# Ensure the model can run in eager mode.
@@ -71,11 +75,16 @@ def build_result(
7175
return build_result(TestResult.UNKNOWN_FAIL, e)
7276

7377
if flow.quantize:
78+
start_time = time.perf_counter()
7479
try:
7580
tester.quantize(
7681
flow.quantize_stage_factory() if flow.quantize_stage_factory else None
7782
)
83+
elapsed = time.perf_counter() - start_time
84+
extra_stats["quantize_time"] = timedelta(seconds=elapsed)
7885
except Exception as e:
86+
elapsed = time.perf_counter() - start_time
87+
extra_stats["quantize_time"] = timedelta(seconds=elapsed)
7988
return build_result(TestResult.QUANTIZE_FAIL, e)
8089

8190
try:
@@ -86,9 +95,14 @@ def build_result(
8695
except Exception as e:
8796
return build_result(TestResult.EXPORT_FAIL, e)
8897

98+
lower_start_time = time.perf_counter()
8999
try:
90100
tester.to_edge_transform_and_lower()
101+
elapsed = time.perf_counter() - lower_start_time
102+
extra_stats["lower_time"] = timedelta(seconds=elapsed)
91103
except Exception as e:
104+
elapsed = time.perf_counter() - lower_start_time
105+
extra_stats["lower_time"] = timedelta(seconds=elapsed)
92106
return build_result(TestResult.LOWER_FAIL, e)
93107

94108
is_delegated = any(
@@ -183,6 +197,7 @@ def parse_args():
183197
"--report",
184198
nargs="?",
185199
help="A file to write the test report to, in CSV format.",
200+
default="backend_test_report.csv",
186201
)
187202
return parser.parse_args()
188203

0 commit comments

Comments
 (0)