[Backend Tester] Report quantization and lowering times #12838

Merged · 82 commits · Aug 12, 2025

99 changes: 99 additions & 0 deletions backends/test/harness/error_statistics.py
@@ -0,0 +1,99 @@
from dataclasses import dataclass

import torch
from torch.ao.ns.fx.utils import compute_sqnr


@dataclass
class TensorStatistics:
    """Contains summary statistics for a tensor."""

    shape: torch.Size
    """ The shape of the tensor. """

    numel: int
    """ The number of elements in the tensor. """

    median: float
    """ The median of the tensor. """

    mean: float
    """ The mean of the tensor. """

    max: torch.types.Number
    """ The maximum element of the tensor. """

    min: torch.types.Number
    """ The minimum element of the tensor. """

    @classmethod
    def from_tensor(cls, tensor: torch.Tensor) -> "TensorStatistics":
        """Creates a TensorStatistics object from a tensor."""
        flattened = torch.flatten(tensor)
        return cls(
            shape=tensor.shape,
            numel=tensor.numel(),
            median=torch.quantile(flattened, q=0.5).item(),
            mean=flattened.mean().item(),
            max=flattened.max().item(),
            min=flattened.min().item(),
        )


@dataclass
class ErrorStatistics:
    """Contains statistics derived from the difference of two tensors."""

    reference_stats: TensorStatistics
    """ Statistics for the reference tensor. """

    actual_stats: TensorStatistics
    """ Statistics for the actual tensor. """

    error_l2_norm: float | None
    """ The L2 norm of the error between the actual and reference tensor. """

    error_mae: float | None
    """ The mean absolute error between the actual and reference tensor. """

    error_max: float | None
    """ The maximum absolute elementwise error between the actual and reference tensor. """

    error_msd: float | None
    """ The mean signed deviation between the actual and reference tensor. """

    sqnr: float | None
    """ The signal-to-quantization-noise ratio between the actual and reference tensor. """

    @classmethod
    def from_tensors(
        cls, actual: torch.Tensor, reference: torch.Tensor
    ) -> "ErrorStatistics":
        """Creates an ErrorStatistics object from two tensors."""
        actual = actual.to(torch.float64)
        reference = reference.to(torch.float64)

        if actual.shape != reference.shape:
            return cls(
                reference_stats=TensorStatistics.from_tensor(reference),
                actual_stats=TensorStatistics.from_tensor(actual),
                error_l2_norm=None,
                error_mae=None,
                error_max=None,
                error_msd=None,
                sqnr=None,
            )

        error = actual - reference
        flat_error = torch.flatten(error)

        return cls(
            reference_stats=TensorStatistics.from_tensor(reference),
            actual_stats=TensorStatistics.from_tensor(actual),
            error_l2_norm=torch.linalg.norm(flat_error).item(),
            error_mae=torch.mean(torch.abs(flat_error)).item(),
            error_max=torch.max(torch.abs(flat_error)).item(),
            error_msd=torch.mean(flat_error).item(),
            # Torch sqnr implementation requires float32 due to decorator logic
            sqnr=compute_sqnr(actual.to(torch.float), reference.to(torch.float)).item(),
        )
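
For reference, a minimal usage sketch of the new ErrorStatistics helper, standalone and outside the test harness. The import path assumes the ExecuTorch repo layout above; the tensors and the small-noise setup are purely illustrative:

import torch

from executorch.backends.test.harness.error_statistics import ErrorStatistics

# Synthetic data: pretend "actual" is a backend output with small quantization noise.
reference = torch.randn(2, 8)
actual = reference + 0.01 * torch.randn(2, 8)

stats = ErrorStatistics.from_tensors(actual=actual, reference=reference)
print(f"MAE={stats.error_mae:.6f}  L2={stats.error_l2_norm:.6f}  SQNR={stats.sqnr:.2f} dB")
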
33 changes: 24 additions & 9 deletions backends/test/harness/tester.py
@@ -4,6 +4,7 @@
 
 import torch
 
+from executorch.backends.test.harness.error_statistics import ErrorStatistics
 from executorch.backends.test.harness.stages import (
     Export,
     Partition,
@@ -302,20 +303,15 @@ def run_method_and_compare_outputs(
         atol=1e-03,
         rtol=1e-03,
         qtol=0,
+        statistics_callback: Callable[[ErrorStatistics], None] | None = None,
     ):
         number_of_runs = 1 if inputs is not None else num_runs
         reference_stage = self.stages[StageType.EXPORT]
 
         stage = stage or self.cur
 
-        print(f"Comparing Stage {stage} with Stage {reference_stage}")
-        for run_iteration in range(number_of_runs):
+        for _ in range(number_of_runs):
             inputs_to_run = inputs if inputs else next(self.generate_random_inputs())
-            input_shapes = [
-                generated_input.shape if hasattr(generated_input, "shape") else None
-                for generated_input in inputs_to_run
-            ]
-            print(f"Run {run_iteration} with input shapes: {input_shapes}")
 
             # Reference output (and quantization scale)
             (
@@ -328,13 +324,25 @@
             # Output from running artifact at stage
             stage_output = self.stages[stage].run_artifact(inputs_to_run)
             self._compare_outputs(
-                reference_output, stage_output, quantization_scale, atol, rtol, qtol
+                reference_output,
+                stage_output,
+                quantization_scale,
+                atol,
+                rtol,
+                qtol,
+                statistics_callback,
             )
 
         return self
 
     @staticmethod
-    def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
+    def _assert_outputs_equal(
+        model_output,
+        ref_output,
+        atol=1e-03,
+        rtol=1e-03,
+        statistics_callback: Callable[[ErrorStatistics], None] | None = None,
+    ):
         """
         Helper testing function that asserts that the model output and the reference output
         are equal with some tolerance. Due to numerical differences between eager mode and
@@ -349,6 +357,11 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
         for i in range(len(model_output)):
             model = model_output[i]
             ref = ref_output[i]
+
+            error_stats = ErrorStatistics.from_tensors(model, ref)
+            if statistics_callback is not None:
+                statistics_callback(error_stats)
+
             assert (
                 ref.shape == model.shape
             ), f"Output {i} shape {model.shape} does not match reference output shape {ref.shape}"
@@ -386,6 +399,7 @@ def _compare_outputs(
         atol=1e-03,
         rtol=1e-03,
         qtol=0,
+        statistics_callback: Callable[[ErrorStatistics], None] | None = None,
     ):
         """
         Compares the original of the original nn module with the output of the generated artifact.
@@ -408,6 +422,7 @@
             reference_output,
             atol=atol,
             rtol=rtol,
+            statistics_callback=statistics_callback,
         )
 
     @staticmethod
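
A sketch of how a caller might consume the new statistics_callback hook. The tester variable is assumed to be an already-constructed harness Tester that has been run through its usual export/partition/serialize stages; only the callback wiring is shown:

from executorch.backends.test.harness.error_statistics import ErrorStatistics

collected: list[ErrorStatistics] = []

def record_stats(stats: ErrorStatistics) -> None:
    # Invoked once per compared output tensor inside _assert_outputs_equal.
    collected.append(stats)

# Assumed usage once a tester instance exists:
# tester.run_method_and_compare_outputs(statistics_callback=record_stats)
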
65 changes: 65 additions & 0 deletions backends/test/harness/tests/test_error_statistics.py
@@ -0,0 +1,65 @@
import unittest

import torch
from executorch.backends.test.harness.error_statistics import ErrorStatistics


class ErrorStatisticsTests(unittest.TestCase):
    def test_error_stats_simple(self):
        tensor1 = torch.tensor([1, 2, 3, 4])
        tensor2 = torch.tensor([2, 2, 2, 5])

        error_stats = ErrorStatistics.from_tensors(tensor1, tensor2)

        # Check actual tensor statistics
        self.assertEqual(error_stats.actual_stats.shape, torch.Size([4]))
        self.assertEqual(error_stats.actual_stats.numel, 4)
        self.assertEqual(error_stats.actual_stats.median, 2.5)
        self.assertEqual(error_stats.actual_stats.mean, 2.5)
        self.assertEqual(error_stats.actual_stats.max, 4)
        self.assertEqual(error_stats.actual_stats.min, 1)

        # Check reference tensor statistics
        self.assertEqual(error_stats.reference_stats.shape, torch.Size([4]))
        self.assertEqual(error_stats.reference_stats.numel, 4)
        self.assertEqual(error_stats.reference_stats.median, 2.0)
        self.assertEqual(error_stats.reference_stats.mean, 2.75)
        self.assertEqual(error_stats.reference_stats.max, 5)
        self.assertEqual(error_stats.reference_stats.min, 2)

        # Check error statistics
        self.assertAlmostEqual(error_stats.error_l2_norm, 1.732, places=3)
        self.assertEqual(error_stats.error_mae, 0.75)
        self.assertEqual(error_stats.error_max, 1.0)
        self.assertEqual(error_stats.error_msd, -0.25)
        self.assertAlmostEqual(error_stats.sqnr, 10.0, places=3)

    def test_error_stats_different_shapes(self):
        # Create tensors with different shapes
        tensor1 = torch.tensor([1, 2, 3, 4])
        tensor2 = torch.tensor([[2, 3], [4, 5]])

        error_stats = ErrorStatistics.from_tensors(tensor1, tensor2)

        # Check actual tensor statistics
        self.assertEqual(error_stats.actual_stats.shape, torch.Size([4]))
        self.assertEqual(error_stats.actual_stats.numel, 4)
        self.assertEqual(error_stats.actual_stats.median, 2.5)
        self.assertEqual(error_stats.actual_stats.mean, 2.5)
        self.assertEqual(error_stats.actual_stats.max, 4)
        self.assertEqual(error_stats.actual_stats.min, 1)

        # Check reference tensor statistics
        self.assertEqual(error_stats.reference_stats.shape, torch.Size([2, 2]))
        self.assertEqual(error_stats.reference_stats.numel, 4)
        self.assertEqual(error_stats.reference_stats.median, 3.5)
        self.assertEqual(error_stats.reference_stats.mean, 3.5)
        self.assertEqual(error_stats.reference_stats.max, 5)
        self.assertEqual(error_stats.reference_stats.min, 2)

        # Check that all error values are None when shapes differ
        self.assertIsNone(error_stats.error_l2_norm)
        self.assertIsNone(error_stats.error_mae)
        self.assertIsNone(error_stats.error_max)
        self.assertIsNone(error_stats.error_msd)
        self.assertIsNone(error_stats.sqnr)
46 changes: 46 additions & 0 deletions backends/test/suite/reporting.py
@@ -1,10 +1,13 @@
 import csv
 from collections import Counter
 from dataclasses import dataclass
+from datetime import timedelta
 from enum import IntEnum
 from functools import reduce
 from typing import TextIO
 
+from executorch.backends.test.harness.error_statistics import ErrorStatistics
+
 
 class TestResult(IntEnum):
     """Represents the result of a test case run, indicating success or a specific failure reason."""
@@ -100,6 +103,18 @@ class TestCaseSummary:
     error: Exception | None
     """ The Python exception object, if any. """
 
+    tensor_error_statistics: list[ErrorStatistics]
+    """
+    Statistics about the error between the backend and reference outputs. Each element of this list corresponds to
+    a single output tensor.
+    """
+
+    quantize_time: timedelta | None = None
+    """ The total runtime of the quantization stage, or None, if the test did not run the quantize stage. """
+
+    lower_time: timedelta | None = None
+    """ The total runtime of the to_edge_transform_and_lower stage, or None, if the test did not run the lowering stage. """
+
 
 class TestSessionState:
     test_case_summaries: list[TestCaseSummary]
@@ -182,6 +197,8 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
         "Backend",
         "Flow",
         "Result",
+        "Quantize Time (s)",
+        "Lowering Time (s)",
     ]
 
     # Tests can have custom parameters. We'll want to report them here, so we need
@@ -197,6 +214,21 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
     )
     field_names += (s.capitalize() for s in param_names)
 
+    # Add tensor error statistic field names for each output index.
+    max_outputs = max(
+        len(s.tensor_error_statistics) for s in summary.test_case_summaries
+    )
+    for i in range(max_outputs):
+        field_names.extend(
+            [
+                f"Output {i} Error Max",
+                f"Output {i} Error MAE",
+                f"Output {i} Error MSD",
+                f"Output {i} Error L2",
+                f"Output {i} SQNR",
+            ]
+        )
+
     writer = csv.DictWriter(output, field_names)
     writer.writeheader()
 
@@ -207,7 +239,21 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
             "Backend": record.backend,
             "Flow": record.flow,
             "Result": record.result.display_name(),
+            "Quantize Time (s)": (
+                record.quantize_time.total_seconds() if record.quantize_time else None
+            ),
+            "Lowering Time (s)": (
+                record.lower_time.total_seconds() if record.lower_time else None
+            ),
         }
         if record.params is not None:
             row.update({k.capitalize(): v for k, v in record.params.items()})
+
+        for output_idx, error_stats in enumerate(record.tensor_error_statistics):
+            row[f"Output {output_idx} Error Max"] = error_stats.error_max
+            row[f"Output {output_idx} Error MAE"] = error_stats.error_mae
+            row[f"Output {output_idx} Error MSD"] = error_stats.error_msd
+            row[f"Output {output_idx} Error L2"] = error_stats.error_l2_norm
+            row[f"Output {output_idx} SQNR"] = error_stats.sqnr
+
         writer.writerow(row)
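
The quantize_time and lower_time fields are timedeltas, and generate_csv_report writes them out in seconds. A hedged sketch of how a runner could populate them; run_quantize_stage and summary are hypothetical names used only for illustration, not part of this PR:

import time
from datetime import timedelta

def timed(fn, *args, **kwargs):
    # Measure wall-clock runtime of a stage and return it as a timedelta,
    # matching the units reported in the "Quantize Time (s)" CSV column.
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, timedelta(seconds=time.perf_counter() - start)

# Assumed usage:
# quantized, quantize_time = timed(run_quantize_stage, model, example_inputs)
# summary.quantize_time = quantize_time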