Skip to content

Commit 330e3ce

Browse files
committed
[Backend Tester] Report quantization and lowering times
ghstack-source-id: cc237ca ghstack-comment-id: 3115355420 Pull-Request: #12838
1 parent 10f1d01 commit 330e3ce

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

backends/test/suite/reporting.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import csv
22
from collections import Counter
33
from dataclasses import dataclass
4+
from datetime import timedelta
45
from enum import IntEnum
56
from functools import reduce
67
from typing import TextIO
@@ -108,6 +109,12 @@ class TestCaseSummary:
108109
a single output tensor.
109110
"""
110111

112+
quantize_time: timedelta | None = None
113+
""" The total runtime of the quantization stage, or none, if the test did not run the quantize stage. """
114+
115+
lower_time: timedelta | None = None
116+
""" The total runtime of the to_edge_transform_and_lower stage, or none, if the test did not run the quantize stage. """
117+
111118

112119
class TestSessionState:
113120
test_case_summaries: list[TestCaseSummary]
@@ -190,6 +197,8 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
190197
"Backend",
191198
"Flow",
192199
"Result",
200+
"Quantize Time (s)",
201+
"Lowering Time (s)",
193202
]
194203

195204
# Tests can have custom parameters. We'll want to report them here, so we need
@@ -230,6 +239,12 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
230239
"Backend": record.backend,
231240
"Flow": record.flow,
232241
"Result": record.result.display_name(),
242+
"Quantize Time (s)": (
243+
record.quantize_time.total_seconds() if record.quantize_time else None
244+
),
245+
"Lowering Time (s)": (
246+
record.lower_time.total_seconds() if record.lower_time else None
247+
),
233248
}
234249
if record.params is not None:
235250
row.update({k.capitalize(): v for k, v in record.params.items()})

backends/test/suite/runner.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import argparse
22
import importlib
33
import re
4+
import time
45
import unittest
56

7+
from datetime import timedelta
68
from typing import Any
79

810
import torch
@@ -44,6 +46,7 @@ def run_test( # noqa: C901
4446
"""
4547

4648
error_statistics: list[ErrorStatistics] = []
49+
extra_stats = {}
4750

4851
# Helper method to construct the summary.
4952
def build_result(
@@ -58,6 +61,7 @@ def build_result(
5861
result=result,
5962
error=error,
6063
tensor_error_statistics=error_statistics,
64+
**extra_stats,
6165
)
6266

6367
# Ensure the model can run in eager mode.
@@ -72,11 +76,16 @@ def build_result(
7276
return build_result(TestResult.UNKNOWN_FAIL, e)
7377

7478
if flow.quantize:
79+
start_time = time.perf_counter()
7580
try:
7681
tester.quantize(
7782
flow.quantize_stage_factory() if flow.quantize_stage_factory else None
7883
)
84+
elapsed = time.perf_counter() - start_time
85+
extra_stats["quantize_time"] = timedelta(seconds=elapsed)
7986
except Exception as e:
87+
elapsed = time.perf_counter() - start_time
88+
extra_stats["quantize_time"] = timedelta(seconds=elapsed)
8089
return build_result(TestResult.QUANTIZE_FAIL, e)
8190

8291
try:
@@ -87,9 +96,14 @@ def build_result(
8796
except Exception as e:
8897
return build_result(TestResult.EXPORT_FAIL, e)
8998

99+
lower_start_time = time.perf_counter()
90100
try:
91101
tester.to_edge_transform_and_lower()
102+
elapsed = time.perf_counter() - lower_start_time
103+
extra_stats["lower_time"] = timedelta(seconds=elapsed)
92104
except Exception as e:
105+
elapsed = time.perf_counter() - lower_start_time
106+
extra_stats["lower_time"] = timedelta(seconds=elapsed)
93107
return build_result(TestResult.LOWER_FAIL, e)
94108

95109
is_delegated = any(
@@ -185,6 +199,7 @@ def parse_args():
185199
"--report",
186200
nargs="?",
187201
help="A file to write the test report to, in CSV format.",
202+
default="backend_test_report.csv",
188203
)
189204
return parser.parse_args()
190205

0 commit comments

Comments
 (0)