Skip to content

Commit 010b800

Browse files
authored
[Backend Tester] Add tensor error statistic reporting (#12809)
Report various error statistics for the test outputs, including SQNR, mean absolute error (MAE), and L2 norm. These are saved in the detail report per test case. As an example, here is the output from Core ML running MobileNet V2 (roughly formatted from csv -> sheets -> markdown): ``` Output 0 Error Max Output 0 Error MAE Output 0 Error MSD Output 0 Error L2 Output 0 SQNR 0.0005887411535 0.0001199183663 2.32E-06 0.004750485188 41.28595734 ```
1 parent 71c9a4f commit 010b800

File tree

5 files changed

+225
-10
lines changed

5 files changed

+225
-10
lines changed
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from dataclasses import dataclass
2+
3+
import torch
4+
from torch.ao.ns.fx.utils import compute_sqnr
5+
6+
7+
@dataclass
8+
class TensorStatistics:
9+
"""Contains summary statistics for a tensor."""
10+
11+
shape: torch.Size
12+
""" The shape of the tensor. """
13+
14+
numel: int
15+
""" The number of elements in the tensor. """
16+
17+
median: float
18+
""" The median of the tensor. """
19+
20+
mean: float
21+
""" The mean of the tensor. """
22+
23+
max: torch.types.Number
24+
""" The maximum element of the tensor. """
25+
26+
min: torch.types.Number
27+
""" The minimum element of the tensor. """
28+
29+
@classmethod
30+
def from_tensor(cls, tensor: torch.Tensor) -> "TensorStatistics":
31+
"""Creates a TensorStatistics object from a tensor."""
32+
flattened = torch.flatten(tensor)
33+
return cls(
34+
shape=tensor.shape,
35+
numel=tensor.numel(),
36+
median=torch.quantile(flattened, q=0.5).item(),
37+
mean=flattened.mean().item(),
38+
max=flattened.max().item(),
39+
min=flattened.min().item(),
40+
)
41+
42+
43+
@dataclass
44+
class ErrorStatistics:
45+
"""Contains statistics derived from the difference of two tensors."""
46+
47+
reference_stats: TensorStatistics
48+
""" Statistics for the reference tensor. """
49+
50+
actual_stats: TensorStatistics
51+
""" Statistics for the actual tensor. """
52+
53+
error_l2_norm: float | None
54+
""" The L2 norm of the error between the actual and reference tensor. """
55+
56+
error_mae: float | None
57+
""" The mean absolute error between the actual and reference tensor. """
58+
59+
error_max: float | None
60+
""" The maximum absolute elementwise error between the actual and reference tensor. """
61+
62+
error_msd: float | None
63+
""" The mean signed deviation between the actual and reference tensor. """
64+
65+
sqnr: float | None
66+
""" The signal-to-quantization-noise ratio between the actual and reference tensor. """
67+
68+
@classmethod
69+
def from_tensors(
70+
cls, actual: torch.Tensor, reference: torch.Tensor
71+
) -> "ErrorStatistics":
72+
"""Creates an ErrorStatistics object from two tensors."""
73+
actual = actual.to(torch.float64)
74+
reference = reference.to(torch.float64)
75+
76+
if actual.shape != reference.shape:
77+
return cls(
78+
reference_stats=TensorStatistics.from_tensor(reference),
79+
actual_stats=TensorStatistics.from_tensor(actual),
80+
error_l2_norm=None,
81+
error_mae=None,
82+
error_max=None,
83+
error_msd=None,
84+
sqnr=None,
85+
)
86+
87+
error = actual - reference
88+
flat_error = torch.flatten(error)
89+
90+
return cls(
91+
reference_stats=TensorStatistics.from_tensor(reference),
92+
actual_stats=TensorStatistics.from_tensor(actual),
93+
error_l2_norm=torch.linalg.norm(flat_error).item(),
94+
error_mae=torch.mean(torch.abs(flat_error)).item(),
95+
error_max=torch.max(torch.abs(flat_error)).item(),
96+
error_msd=torch.mean(flat_error).item(),
97+
# Torch sqnr implementation requires float32 due to decorator logic
98+
sqnr=compute_sqnr(actual.to(torch.float), reference.to(torch.float)).item(),
99+
)

backends/test/harness/tester.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import torch
66

7+
from executorch.backends.test.harness.error_statistics import ErrorStatistics
78
from executorch.backends.test.harness.stages import (
89
Export,
910
Partition,
@@ -302,20 +303,15 @@ def run_method_and_compare_outputs(
302303
atol=1e-03,
303304
rtol=1e-03,
304305
qtol=0,
306+
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
305307
):
306308
number_of_runs = 1 if inputs is not None else num_runs
307309
reference_stage = self.stages[StageType.EXPORT]
308310

309311
stage = stage or self.cur
310312

311-
print(f"Comparing Stage {stage} with Stage {reference_stage}")
312-
for run_iteration in range(number_of_runs):
313+
for _ in range(number_of_runs):
313314
inputs_to_run = inputs if inputs else next(self.generate_random_inputs())
314-
input_shapes = [
315-
generated_input.shape if hasattr(generated_input, "shape") else None
316-
for generated_input in inputs_to_run
317-
]
318-
print(f"Run {run_iteration} with input shapes: {input_shapes}")
319315

320316
# Reference output (and quantization scale)
321317
(
@@ -328,13 +324,25 @@ def run_method_and_compare_outputs(
328324
# Output from running artifact at stage
329325
stage_output = self.stages[stage].run_artifact(inputs_to_run)
330326
self._compare_outputs(
331-
reference_output, stage_output, quantization_scale, atol, rtol, qtol
327+
reference_output,
328+
stage_output,
329+
quantization_scale,
330+
atol,
331+
rtol,
332+
qtol,
333+
statistics_callback,
332334
)
333335

334336
return self
335337

336338
@staticmethod
337-
def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
339+
def _assert_outputs_equal(
340+
model_output,
341+
ref_output,
342+
atol=1e-03,
343+
rtol=1e-03,
344+
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
345+
):
338346
"""
339347
Helper testing function that asserts that the model output and the reference output
340348
are equal with some tolerance. Due to numerical differences between eager mode and
@@ -349,6 +357,11 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
349357
for i in range(len(model_output)):
350358
model = model_output[i]
351359
ref = ref_output[i]
360+
361+
error_stats = ErrorStatistics.from_tensors(model, ref)
362+
if statistics_callback is not None:
363+
statistics_callback(error_stats)
364+
352365
assert (
353366
ref.shape == model.shape
354367
), f"Output {i} shape {model.shape} does not match reference output shape {ref.shape}"
@@ -386,6 +399,7 @@ def _compare_outputs(
386399
atol=1e-03,
387400
rtol=1e-03,
388401
qtol=0,
402+
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
389403
):
390404
"""
391405
Compares the original of the original nn module with the output of the generated artifact.
@@ -408,6 +422,7 @@ def _compare_outputs(
408422
reference_output,
409423
atol=atol,
410424
rtol=rtol,
425+
statistics_callback=statistics_callback,
411426
)
412427

413428
@staticmethod
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import unittest
2+
3+
import torch
4+
from executorch.backends.test.harness.error_statistics import ErrorStatistics
5+
6+
7+
class ErrorStatisticsTests(unittest.TestCase):
8+
def test_error_stats_simple(self):
9+
tensor1 = torch.tensor([1, 2, 3, 4])
10+
tensor2 = torch.tensor([2, 2, 2, 5])
11+
12+
error_stats = ErrorStatistics.from_tensors(tensor1, tensor2)
13+
14+
# Check actual tensor statistics
15+
self.assertEqual(error_stats.actual_stats.shape, torch.Size([4]))
16+
self.assertEqual(error_stats.actual_stats.numel, 4)
17+
self.assertEqual(error_stats.actual_stats.median, 2.5)
18+
self.assertEqual(error_stats.actual_stats.mean, 2.5)
19+
self.assertEqual(error_stats.actual_stats.max, 4)
20+
self.assertEqual(error_stats.actual_stats.min, 1)
21+
22+
# Check reference tensor statistics
23+
self.assertEqual(error_stats.reference_stats.shape, torch.Size([4]))
24+
self.assertEqual(error_stats.reference_stats.numel, 4)
25+
self.assertEqual(error_stats.reference_stats.median, 2.0)
26+
self.assertEqual(error_stats.reference_stats.mean, 2.75)
27+
self.assertEqual(error_stats.reference_stats.max, 5)
28+
self.assertEqual(error_stats.reference_stats.min, 2)
29+
30+
# Check error statistics
31+
self.assertAlmostEqual(error_stats.error_l2_norm, 1.732, places=3)
32+
self.assertEqual(error_stats.error_mae, 0.75)
33+
self.assertEqual(error_stats.error_max, 1.0)
34+
self.assertEqual(error_stats.error_msd, -0.25)
35+
self.assertAlmostEqual(error_stats.sqnr, 10.0, places=3)
36+
37+
def test_error_stats_different_shapes(self):
38+
# Create tensors with different shapes
39+
tensor1 = torch.tensor([1, 2, 3, 4])
40+
tensor2 = torch.tensor([[2, 3], [4, 5]])
41+
42+
error_stats = ErrorStatistics.from_tensors(tensor1, tensor2)
43+
44+
# Check actual tensor statistics
45+
self.assertEqual(error_stats.actual_stats.shape, torch.Size([4]))
46+
self.assertEqual(error_stats.actual_stats.numel, 4)
47+
self.assertEqual(error_stats.actual_stats.median, 2.5)
48+
self.assertEqual(error_stats.actual_stats.mean, 2.5)
49+
self.assertEqual(error_stats.actual_stats.max, 4)
50+
self.assertEqual(error_stats.actual_stats.min, 1)
51+
52+
# Check reference tensor statistics
53+
self.assertEqual(error_stats.reference_stats.shape, torch.Size([2, 2]))
54+
self.assertEqual(error_stats.reference_stats.numel, 4)
55+
self.assertEqual(error_stats.reference_stats.median, 3.5)
56+
self.assertEqual(error_stats.reference_stats.mean, 3.5)
57+
self.assertEqual(error_stats.reference_stats.max, 5)
58+
self.assertEqual(error_stats.reference_stats.min, 2)
59+
60+
# Check that all error values are None when shapes differ
61+
self.assertIsNone(error_stats.error_l2_norm)
62+
self.assertIsNone(error_stats.error_mae)
63+
self.assertIsNone(error_stats.error_max)
64+
self.assertIsNone(error_stats.error_msd)
65+
self.assertIsNone(error_stats.sqnr)

backends/test/suite/reporting.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from functools import reduce
66
from typing import TextIO
77

8+
from executorch.backends.test.harness.error_statistics import ErrorStatistics
9+
810

911
class TestResult(IntEnum):
1012
"""Represents the result of a test case run, indicating success or a specific failure reason."""
@@ -100,6 +102,12 @@ class TestCaseSummary:
100102
error: Exception | None
101103
""" The Python exception object, if any. """
102104

105+
tensor_error_statistics: list[ErrorStatistics]
106+
"""
107+
Statistics about the error between the backend and reference outputs. Each element of this list corresponds to
108+
a single output tensor.
109+
"""
110+
103111

104112
class TestSessionState:
105113
test_case_summaries: list[TestCaseSummary]
@@ -197,6 +205,21 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
197205
)
198206
field_names += (s.capitalize() for s in param_names)
199207

208+
# Add tensor error statistic field names for each output index.
209+
max_outputs = max(
210+
len(s.tensor_error_statistics) for s in summary.test_case_summaries
211+
)
212+
for i in range(max_outputs):
213+
field_names.extend(
214+
[
215+
f"Output {i} Error Max",
216+
f"Output {i} Error MAE",
217+
f"Output {i} Error MSD",
218+
f"Output {i} Error L2",
219+
f"Output {i} SQNR",
220+
]
221+
)
222+
200223
writer = csv.DictWriter(output, field_names)
201224
writer.writeheader()
202225

@@ -210,4 +233,12 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
210233
}
211234
if record.params is not None:
212235
row.update({k.capitalize(): v for k, v in record.params.items()})
236+
237+
for output_idx, error_stats in enumerate(record.tensor_error_statistics):
238+
row[f"Output {output_idx} Error Max"] = error_stats.error_max
239+
row[f"Output {output_idx} Error MAE"] = error_stats.error_mae
240+
row[f"Output {output_idx} Error MSD"] = error_stats.error_msd
241+
row[f"Output {output_idx} Error L2"] = error_stats.error_l2_norm
242+
row[f"Output {output_idx} SQNR"] = error_stats.sqnr
243+
213244
writer.writerow(row)

backends/test/suite/runner.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import torch
99

10+
from executorch.backends.test.harness.error_statistics import ErrorStatistics
1011
from executorch.backends.test.harness.stages import StageType
1112
from executorch.backends.test.suite.discovery import discover_tests, TestFilter
1213
from executorch.backends.test.suite.flow import TestFlow
@@ -42,6 +43,8 @@ def run_test( # noqa: C901
4243
and reporting.
4344
"""
4445

46+
error_statistics: list[ErrorStatistics] = []
47+
4548
# Helper method to construct the summary.
4649
def build_result(
4750
result: TestResult, error: Exception | None = None
@@ -54,6 +57,7 @@ def build_result(
5457
params=params,
5558
result=result,
5659
error=error,
60+
tensor_error_statistics=error_statistics,
5761
)
5862

5963
# Ensure the model can run in eager mode.
@@ -108,7 +112,8 @@ def build_result(
108112
# AssertionErrors to catch output mismatches, but this might catch more than that.
109113
try:
110114
tester.run_method_and_compare_outputs(
111-
inputs=None if generate_random_test_inputs else inputs
115+
inputs=None if generate_random_test_inputs else inputs,
116+
statistics_callback=lambda stats: error_statistics.append(stats),
112117
)
113118
except AssertionError as e:
114119
return build_result(TestResult.OUTPUT_MISMATCH_FAIL, e)

0 commit comments

Comments
 (0)