Skip to content

Commit 5563051

Browse files
authored
[Backend Tester] Report quantization and lowering times (#12838)
Track and report the time taken to quantize and lower in the backend test flow. Include this information in the generated report for each test case. Example output (from testing add operator): Test ID | Test Case | Backend | Flow | Result | Quantize Time (s) | Lowering Time (s) -- | -- | -- | -- | -- | -- | -- test_add_dtype_float32_coreml | test_add_dtype | coreml | coreml | Success (Delegated) |   | 0.69 test_add_dtype_float32_coreml_static_int8 | test_add_dtype | coreml | coreml_static_int8 | Success (Delegated) | 8.73 | 0.88
1 parent 010b800 commit 5563051

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

backends/test/suite/reporting.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import csv
22
from collections import Counter
33
from dataclasses import dataclass
4+
from datetime import timedelta
45
from enum import IntEnum
56
from functools import reduce
67
from typing import TextIO
@@ -108,6 +109,12 @@ class TestCaseSummary:
108109
a single output tensor.
109110
"""
110111

112+
quantize_time: timedelta | None = None
113+
""" The total runtime of the quantization stage, or none, if the test did not run the quantize stage. """
114+
115+
lower_time: timedelta | None = None
116+
""" The total runtime of the to_edge_transform_and_lower stage, or none, if the test did not run the quantize stage. """
117+
111118

112119
class TestSessionState:
113120
test_case_summaries: list[TestCaseSummary]
@@ -190,6 +197,8 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
190197
"Backend",
191198
"Flow",
192199
"Result",
200+
"Quantize Time (s)",
201+
"Lowering Time (s)",
193202
]
194203

195204
# Tests can have custom parameters. We'll want to report them here, so we need
@@ -230,6 +239,12 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
230239
"Backend": record.backend,
231240
"Flow": record.flow,
232241
"Result": record.result.display_name(),
242+
"Quantize Time (s)": (
243+
record.quantize_time.total_seconds() if record.quantize_time else None
244+
),
245+
"Lowering Time (s)": (
246+
record.lower_time.total_seconds() if record.lower_time else None
247+
),
233248
}
234249
if record.params is not None:
235250
row.update({k.capitalize(): v for k, v in record.params.items()})

backends/test/suite/runner.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import argparse
22
import importlib
33
import re
4+
import time
45
import unittest
56

7+
from datetime import timedelta
68
from typing import Any
79

810
import torch
@@ -44,6 +46,7 @@ def run_test( # noqa: C901
4446
"""
4547

4648
error_statistics: list[ErrorStatistics] = []
49+
extra_stats = {}
4750

4851
# Helper method to construct the summary.
4952
def build_result(
@@ -58,6 +61,7 @@ def build_result(
5861
result=result,
5962
error=error,
6063
tensor_error_statistics=error_statistics,
64+
**extra_stats,
6165
)
6266

6367
# Ensure the model can run in eager mode.
@@ -72,11 +76,16 @@ def build_result(
7276
return build_result(TestResult.UNKNOWN_FAIL, e)
7377

7478
if flow.quantize:
79+
start_time = time.perf_counter()
7580
try:
7681
tester.quantize(
7782
flow.quantize_stage_factory() if flow.quantize_stage_factory else None
7883
)
84+
elapsed = time.perf_counter() - start_time
85+
extra_stats["quantize_time"] = timedelta(seconds=elapsed)
7986
except Exception as e:
87+
elapsed = time.perf_counter() - start_time
88+
extra_stats["quantize_time"] = timedelta(seconds=elapsed)
8089
return build_result(TestResult.QUANTIZE_FAIL, e)
8190

8291
try:
@@ -87,9 +96,14 @@ def build_result(
8796
except Exception as e:
8897
return build_result(TestResult.EXPORT_FAIL, e)
8998

99+
lower_start_time = time.perf_counter()
90100
try:
91101
tester.to_edge_transform_and_lower()
102+
elapsed = time.perf_counter() - lower_start_time
103+
extra_stats["lower_time"] = timedelta(seconds=elapsed)
92104
except Exception as e:
105+
elapsed = time.perf_counter() - lower_start_time
106+
extra_stats["lower_time"] = timedelta(seconds=elapsed)
93107
return build_result(TestResult.LOWER_FAIL, e)
94108

95109
is_delegated = any(
@@ -185,6 +199,7 @@ def parse_args():
185199
"--report",
186200
nargs="?",
187201
help="A file to write the test report to, in CSV format.",
202+
default="backend_test_report.csv",
188203
)
189204
return parser.parse_args()
190205

0 commit comments

Comments
 (0)