Skip to content

[Backend Tester] Report quantization and lowering times #12838

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: gh/GregoryComer/91/head
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions backends/test/suite/reporting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import csv
from collections import Counter
from dataclasses import dataclass
from datetime import timedelta
from enum import IntEnum
from functools import reduce
from typing import TextIO
Expand Down Expand Up @@ -108,6 +109,12 @@ class TestCaseSummary:
a single output tensor.
"""

quantize_time: timedelta | None = None
""" The total runtime of the quantization stage, or none, if the test did not run the quantize stage. """

lower_time: timedelta | None = None
""" The total runtime of the to_edge_transform_and_lower stage, or none, if the test did not run the quantize stage. """


class TestSessionState:
test_case_summaries: list[TestCaseSummary]
Expand Down Expand Up @@ -190,6 +197,8 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
"Backend",
"Flow",
"Result",
"Quantize Time (s)",
"Lowering Time (s)",
]

# Tests can have custom parameters. We'll want to report them here, so we need
Expand Down Expand Up @@ -230,6 +239,12 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
"Backend": record.backend,
"Flow": record.flow,
"Result": record.result.display_name(),
"Quantize Time (s)": (
record.quantize_time.total_seconds() if record.quantize_time else None
),
"Lowering Time (s)": (
record.lower_time.total_seconds() if record.lower_time else None
),
}
if record.params is not None:
row.update({k.capitalize(): v for k, v in record.params.items()})
Expand Down
15 changes: 15 additions & 0 deletions backends/test/suite/runner.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import argparse
import importlib
import re
import time
import unittest

from datetime import timedelta
from typing import Any

import torch
Expand Down Expand Up @@ -44,6 +46,7 @@ def run_test( # noqa: C901
"""

error_statistics: list[ErrorStatistics] = []
extra_stats = {}

# Helper method to construct the summary.
def build_result(
Expand All @@ -58,6 +61,7 @@ def build_result(
result=result,
error=error,
tensor_error_statistics=error_statistics,
**extra_stats,
)

# Ensure the model can run in eager mode.
Expand All @@ -72,11 +76,16 @@ def build_result(
return build_result(TestResult.UNKNOWN_FAIL, e)

if flow.quantize:
start_time = time.perf_counter()
try:
tester.quantize(
flow.quantize_stage_factory() if flow.quantize_stage_factory else None
)
elapsed = time.perf_counter() - start_time
extra_stats["quantize_time"] = timedelta(seconds=elapsed)
except Exception as e:
elapsed = time.perf_counter() - start_time
extra_stats["quantize_time"] = timedelta(seconds=elapsed)
return build_result(TestResult.QUANTIZE_FAIL, e)

try:
Expand All @@ -87,9 +96,14 @@ def build_result(
except Exception as e:
return build_result(TestResult.EXPORT_FAIL, e)

lower_start_time = time.perf_counter()
try:
tester.to_edge_transform_and_lower()
elapsed = time.perf_counter() - lower_start_time
extra_stats["lower_time"] = timedelta(seconds=elapsed)
except Exception as e:
elapsed = time.perf_counter() - lower_start_time
extra_stats["lower_time"] = timedelta(seconds=elapsed)
return build_result(TestResult.LOWER_FAIL, e)

is_delegated = any(
Expand Down Expand Up @@ -185,6 +199,7 @@ def parse_args():
"--report",
nargs="?",
help="A file to write the test report to, in CSV format.",
default="backend_test_report.csv",
)
return parser.parse_args()

Expand Down
Loading