Skip to content

Commit bd04bcd

Browse files
committed
[Backend Tester] Add tensor error statistic reporting
ghstack-source-id: 63819cb ghstack-comment-id: 3112003831 Pull-Request: #12809
1 parent d9f3ca0 commit bd04bcd

File tree

5 files changed

+226
-7
lines changed

5 files changed

+226
-7
lines changed
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from dataclasses import dataclass
2+
3+
import torch
4+
from torch.ao.ns.fx.utils import compute_sqnr
5+
6+
7+
@dataclass
8+
class TensorStatistics:
9+
"""Contains summary statistics for a tensor."""
10+
11+
shape: torch.Size
12+
""" The shape of the tensor. """
13+
14+
numel: int
15+
""" The number of elements in the tensor. """
16+
17+
median: float
18+
""" The median of the tensor. """
19+
20+
mean: float
21+
""" The mean of the tensor. """
22+
23+
max: torch.types.Number
24+
""" The maximum element of the tensor. """
25+
26+
min: torch.types.Number
27+
""" The minimum element of the tensor. """
28+
29+
@classmethod
30+
def from_tensor(cls, tensor: torch.Tensor) -> "TensorStatistics":
31+
"""Creates a TensorStatistics object from a tensor."""
32+
flattened = torch.flatten(tensor)
33+
return cls(
34+
shape=tensor.shape,
35+
numel=tensor.numel(),
36+
median=torch.quantile(flattened, q=0.5).item(),
37+
mean=flattened.mean().item(),
38+
max=flattened.max().item(),
39+
min=flattened.min().item(),
40+
)
41+
42+
43+
@dataclass
44+
class ErrorStatistics:
45+
"""Contains statistics derived from the difference of two tensors."""
46+
47+
reference_stats: TensorStatistics
48+
""" Statistics for the reference tensor. """
49+
50+
actual_stats: TensorStatistics
51+
""" Statistics for the actual tensor. """
52+
53+
error_l2_norm: float | None
54+
""" The L2 norm of the error between the actual and reference tensor. """
55+
56+
error_mae: float | None
57+
""" The mean absolute error between the actual and reference tensor. """
58+
59+
error_max: float | None
60+
""" The maximum absolute elementwise error between the actual and reference tensor. """
61+
62+
error_msd: float | None
63+
""" The mean signed deviation between the actual and reference tensor. """
64+
65+
sqnr: float | None
66+
""" The signal-to-quantization-noise ratio between the actual and reference tensor. """
67+
68+
@classmethod
69+
def from_tensors(
70+
cls, actual: torch.Tensor, reference: torch.Tensor
71+
) -> "ErrorStatistics":
72+
"""Creates an ErrorStatistics object from two tensors."""
73+
actual = actual.to(torch.float64)
74+
reference = reference.to(torch.float64)
75+
76+
if actual.shape != reference.shape:
77+
return cls(
78+
reference_stats=TensorStatistics.from_tensor(reference),
79+
actual_stats=TensorStatistics.from_tensor(actual),
80+
error_l2_norm=None,
81+
error_mae=None,
82+
error_max=None,
83+
error_msd=None,
84+
sqnr=None,
85+
)
86+
87+
error = actual - reference
88+
flat_error = torch.flatten(error)
89+
90+
return cls(
91+
reference_stats=TensorStatistics.from_tensor(reference),
92+
actual_stats=TensorStatistics.from_tensor(actual),
93+
error_l2_norm=torch.linalg.norm(flat_error).item(),
94+
error_mae=torch.mean(torch.abs(flat_error)).item(),
95+
error_max=torch.max(torch.abs(flat_error)).item(),
96+
error_msd=torch.mean(flat_error).item(),
97+
# Torch sqnr implementation requires float32 due to decorator logic
98+
sqnr=compute_sqnr(actual.to(torch.float), reference.to(torch.float)).item(),
99+
)

backends/test/harness/tester.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import torch
66

7+
from executorch.backends.test.harness.error_statistics import ErrorStatistics
78
from executorch.backends.test.harness.stages import (
89
Export,
910
Partition,
@@ -302,17 +303,15 @@ def run_method_and_compare_outputs(
302303
atol=1e-03,
303304
rtol=1e-03,
304305
qtol=0,
306+
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
305307
):
306308
number_of_runs = 1 if inputs is not None else num_runs
307309
reference_stage = self.stages[StageType.EXPORT]
308310

309311
stage = stage or self.cur
310312

311-
print(f"Comparing Stage {stage} with Stage {reference_stage}")
312-
for run_iteration in range(number_of_runs):
313+
for _ in range(number_of_runs):
313314
inputs_to_run = inputs if inputs else next(self.generate_random_inputs())
314-
input_shapes = [generated_input.shape for generated_input in inputs_to_run]
315-
print(f"Run {run_iteration} with input shapes: {input_shapes}")
316315

317316
# Reference output (and quantization scale)
318317
(
@@ -325,13 +324,25 @@ def run_method_and_compare_outputs(
325324
# Output from running artifact at stage
326325
stage_output = self.stages[stage].run_artifact(inputs_to_run)
327326
self._compare_outputs(
328-
reference_output, stage_output, quantization_scale, atol, rtol, qtol
327+
reference_output,
328+
stage_output,
329+
quantization_scale,
330+
atol,
331+
rtol,
332+
qtol,
333+
statistics_callback,
329334
)
330335

331336
return self
332337

333338
@staticmethod
334-
def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
339+
def _assert_outputs_equal(
340+
model_output,
341+
ref_output,
342+
atol=1e-03,
343+
rtol=1e-03,
344+
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
345+
):
335346
"""
336347
Helper testing function that asserts that the model output and the reference output
337348
are equal with some tolerance. Due to numerical differences between eager mode and
@@ -346,6 +357,11 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
346357
for i in range(len(model_output)):
347358
model = model_output[i]
348359
ref = ref_output[i]
360+
361+
error_stats = ErrorStatistics.from_tensors(model, ref)
362+
if statistics_callback is not None:
363+
statistics_callback(error_stats)
364+
349365
assert (
350366
ref.shape == model.shape
351367
), f"Output {i} shape {model.shape} does not match reference output shape {ref.shape}"
@@ -383,6 +399,7 @@ def _compare_outputs(
383399
atol=1e-03,
384400
rtol=1e-03,
385401
qtol=0,
402+
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
386403
):
387404
"""
388405
Compares the original of the original nn module with the output of the generated artifact.
@@ -405,6 +422,7 @@ def _compare_outputs(
405422
reference_output,
406423
atol=atol,
407424
rtol=rtol,
425+
statistics_callback=statistics_callback,
408426
)
409427

410428
@staticmethod
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import unittest
2+
3+
import torch
4+
from executorch.backends.test.harness.error_statistics import ErrorStatistics
5+
6+
7+
class ErrorStatisticsTests(unittest.TestCase):
8+
def test_error_stats_simple(self):
9+
tensor1 = torch.tensor([1, 2, 3, 4])
10+
tensor2 = torch.tensor([2, 2, 2, 5])
11+
12+
error_stats = ErrorStatistics.from_tensors(tensor1, tensor2)
13+
14+
# Check actual tensor statistics
15+
self.assertEqual(error_stats.actual_stats.shape, torch.Size([4]))
16+
self.assertEqual(error_stats.actual_stats.numel, 4)
17+
self.assertEqual(error_stats.actual_stats.median, 2.5)
18+
self.assertEqual(error_stats.actual_stats.mean, 2.5)
19+
self.assertEqual(error_stats.actual_stats.max, 4)
20+
self.assertEqual(error_stats.actual_stats.min, 1)
21+
22+
# Check reference tensor statistics
23+
self.assertEqual(error_stats.reference_stats.shape, torch.Size([4]))
24+
self.assertEqual(error_stats.reference_stats.numel, 4)
25+
self.assertEqual(error_stats.reference_stats.median, 2.0)
26+
self.assertEqual(error_stats.reference_stats.mean, 2.75)
27+
self.assertEqual(error_stats.reference_stats.max, 5)
28+
self.assertEqual(error_stats.reference_stats.min, 2)
29+
30+
# Check error statistics
31+
self.assertAlmostEqual(error_stats.error_l2_norm, 1.732, places=3)
32+
self.assertEqual(error_stats.error_mae, 0.75)
33+
self.assertEqual(error_stats.error_max, 1.0)
34+
self.assertEqual(error_stats.error_msd, -0.25)
35+
self.assertAlmostEqual(error_stats.sqnr, 10.0, places=3)
36+
37+
def test_error_stats_different_shapes(self):
38+
# Create tensors with different shapes
39+
tensor1 = torch.tensor([1, 2, 3, 4])
40+
tensor2 = torch.tensor([[2, 3], [4, 5]])
41+
42+
error_stats = ErrorStatistics.from_tensors(tensor1, tensor2)
43+
44+
# Check actual tensor statistics
45+
self.assertEqual(error_stats.actual_stats.shape, torch.Size([4]))
46+
self.assertEqual(error_stats.actual_stats.numel, 4)
47+
self.assertEqual(error_stats.actual_stats.median, 2.5)
48+
self.assertEqual(error_stats.actual_stats.mean, 2.5)
49+
self.assertEqual(error_stats.actual_stats.max, 4)
50+
self.assertEqual(error_stats.actual_stats.min, 1)
51+
52+
# Check reference tensor statistics
53+
self.assertEqual(error_stats.reference_stats.shape, torch.Size([2, 2]))
54+
self.assertEqual(error_stats.reference_stats.numel, 4)
55+
self.assertEqual(error_stats.reference_stats.median, 3.5)
56+
self.assertEqual(error_stats.reference_stats.mean, 3.5)
57+
self.assertEqual(error_stats.reference_stats.max, 5)
58+
self.assertEqual(error_stats.reference_stats.min, 2)
59+
60+
# Check that all error values are None when shapes differ
61+
self.assertIsNone(error_stats.error_l2_norm)
62+
self.assertIsNone(error_stats.error_mae)
63+
self.assertIsNone(error_stats.error_max)
64+
self.assertIsNone(error_stats.error_msd)
65+
self.assertIsNone(error_stats.sqnr)

backends/test/suite/reporting.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from functools import reduce
66
from typing import TextIO
77

8+
from executorch.backends.test.harness.error_statistics import ErrorStatistics
9+
810

911
class TestResult(IntEnum):
1012
"""Represents the result of a test case run, indicating success or a specific failure reason."""
@@ -100,6 +102,12 @@ class TestCaseSummary:
100102
error: Exception | None
101103
""" The Python exception object, if any. """
102104

105+
tensor_error_statistics: list[ErrorStatistics]
106+
"""
107+
Statistics about the error between the backend and reference outputs. Each element of this list corresponds to
108+
a single output tensor.
109+
"""
110+
103111

104112
class TestSessionState:
105113
test_case_summaries: list[TestCaseSummary]
@@ -197,6 +205,21 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
197205
)
198206
field_names += (s.capitalize() for s in param_names)
199207

208+
# Add tensor error statistic field names for each output index.
209+
max_outputs = max(
210+
len(s.tensor_error_statistics) for s in summary.test_case_summaries
211+
)
212+
for i in range(max_outputs):
213+
field_names.extend(
214+
[
215+
f"Output {i} Error Max",
216+
f"Output {i} Error MAE",
217+
f"Output {i} Error MSD",
218+
f"Output {i} Error L2",
219+
f"Output {i} SQNR",
220+
]
221+
)
222+
200223
writer = csv.DictWriter(output, field_names)
201224
writer.writeheader()
202225

@@ -210,4 +233,12 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
210233
}
211234
if record.params is not None:
212235
row.update({k.capitalize(): v for k, v in record.params.items()})
236+
237+
for output_idx, error_stats in enumerate(record.tensor_error_statistics):
238+
row[f"Output {output_idx} Error Max"] = error_stats.error_max
239+
row[f"Output {output_idx} Error MAE"] = error_stats.error_mae
240+
row[f"Output {output_idx} Error MSD"] = error_stats.error_msd
241+
row[f"Output {output_idx} Error L2"] = error_stats.error_l2_norm
242+
row[f"Output {output_idx} SQNR"] = error_stats.sqnr
243+
213244
writer.writerow(row)

backends/test/suite/runner.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import torch
99

10+
from executorch.backends.test.harness.error_statistics import ErrorStatistics
1011
from executorch.backends.test.harness.stages import StageType
1112
from executorch.backends.test.suite.discovery import discover_tests, TestFilter
1213
from executorch.backends.test.suite.flow import TestFlow
@@ -41,6 +42,8 @@ def run_test( # noqa: C901
4142
and reporting.
4243
"""
4344

45+
error_statistics: list[ErrorStatistics] = []
46+
4447
# Helper method to construct the summary.
4548
def build_result(
4649
result: TestResult, error: Exception | None = None
@@ -53,6 +56,7 @@ def build_result(
5356
params=params,
5457
result=result,
5558
error=error,
59+
tensor_error_statistics=error_statistics,
5660
)
5761

5862
# Ensure the model can run in eager mode.
@@ -106,7 +110,9 @@ def build_result(
106110
# the cause of a failure in run_method_and_compare_outputs. We can look for
107111
# AssertionErrors to catch output mismatches, but this might catch more than that.
108112
try:
109-
tester.run_method_and_compare_outputs()
113+
tester.run_method_and_compare_outputs(
114+
statistics_callback=lambda stats: error_statistics.append(stats)
115+
)
110116
except AssertionError as e:
111117
return build_result(TestResult.OUTPUT_MISMATCH_FAIL, e)
112118
except Exception as e:

0 commit comments

Comments
 (0)