Skip to content

Commit 451eae3

Browse files
committed
[Backend Tester] Add tensor error statistic reporting
ghstack-source-id: 63819cb ghstack-comment-id: 3112003831 Pull-Request: pytorch#12809
1 parent a006ab2 commit 451eae3

File tree

5 files changed

+227
-4
lines changed

5 files changed

+227
-4
lines changed
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from dataclasses import dataclass
2+
3+
import torch
4+
from torch.ao.ns.fx.utils import compute_sqnr
5+
6+
7+
@dataclass
8+
class TensorStatistics:
9+
"""Contains summary statistics for a tensor."""
10+
11+
shape: torch.Size
12+
""" The shape of the tensor. """
13+
14+
numel: int
15+
""" The number of elements in the tensor. """
16+
17+
median: float
18+
""" The median of the tensor. """
19+
20+
mean: float
21+
""" The mean of the tensor. """
22+
23+
max: torch.types.Number
24+
""" The maximum element of the tensor. """
25+
26+
min: torch.types.Number
27+
""" The minimum element of the tensor. """
28+
29+
@classmethod
30+
def from_tensor(cls, tensor: torch.Tensor) -> "TensorStatistics":
31+
"""Creates a TensorStatistics object from a tensor."""
32+
flattened = torch.flatten(tensor)
33+
return cls(
34+
shape=tensor.shape,
35+
numel=tensor.numel(),
36+
median=torch.quantile(flattened, q=0.5).item(),
37+
mean=flattened.mean().item(),
38+
max=flattened.max().item(),
39+
min=flattened.min().item(),
40+
)
41+
42+
43+
@dataclass
44+
class ErrorStatistics:
45+
"""Contains statistics derived from the difference of two tensors."""
46+
47+
reference_stats: TensorStatistics
48+
""" Statistics for the reference tensor. """
49+
50+
actual_stats: TensorStatistics
51+
""" Statistics for the actual tensor. """
52+
53+
error_l2_norm: float | None
54+
""" The L2 norm of the error between the actual and reference tensor. """
55+
56+
error_mae: float | None
57+
""" The mean absolute error between the actual and reference tensor. """
58+
59+
error_max: float | None
60+
""" The maximum absolute elementwise error between the actual and reference tensor. """
61+
62+
error_msd: float | None
63+
""" The mean signed deviation between the actual and reference tensor. """
64+
65+
sqnr: float | None
66+
""" The signal-to-quantization-noise ratio between the actual and reference tensor. """
67+
68+
@classmethod
69+
def from_tensors(
70+
cls, actual: torch.Tensor, reference: torch.Tensor
71+
) -> "ErrorStatistics":
72+
"""Creates an ErrorStatistics object from two tensors."""
73+
actual = actual.to(torch.float64)
74+
reference = reference.to(torch.float64)
75+
76+
if actual.shape != reference.shape:
77+
return cls(
78+
reference_stats=TensorStatistics.from_tensor(reference),
79+
actual_stats=TensorStatistics.from_tensor(actual),
80+
error_l2_norm=None,
81+
error_mae=None,
82+
error_max=None,
83+
error_msd=None,
84+
sqnr=None,
85+
)
86+
87+
error = actual - reference
88+
flat_error = torch.flatten(error)
89+
90+
return cls(
91+
reference_stats=TensorStatistics.from_tensor(reference),
92+
actual_stats=TensorStatistics.from_tensor(actual),
93+
error_l2_norm=torch.linalg.norm(flat_error).item(),
94+
error_mae=torch.mean(torch.abs(flat_error)).item(),
95+
error_max=torch.max(torch.abs(flat_error)).item(),
96+
error_msd=torch.mean(flat_error).item(),
97+
# Torch sqnr implementation requires float32 due to decorator logic
98+
sqnr=compute_sqnr(actual.to(torch.float), reference.to(torch.float)).item(),
99+
)

backends/test/harness/tester.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import torch
66

7+
from executorch.backends.test.harness.error_statistics import ErrorStatistics
78
from executorch.backends.test.harness.stages import (
89
Export,
910
Partition,
@@ -302,20 +303,23 @@ def run_method_and_compare_outputs(
302303
atol=1e-03,
303304
rtol=1e-03,
304305
qtol=0,
306+
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
305307
):
306308
number_of_runs = 1 if inputs is not None else num_runs
307309
reference_stage = self.stages[StageType.EXPORT]
308310

309311
stage = stage or self.cur
310312

311-
print(f"Comparing Stage {stage} with Stage {reference_stage}")
312-
for run_iteration in range(number_of_runs):
313+
for _ in range(number_of_runs):
313314
inputs_to_run = inputs if inputs else next(self.generate_random_inputs())
315+
<<<<<<< HEAD
314316
input_shapes = [
315317
generated_input.shape if hasattr(generated_input, "shape") else None
316318
for generated_input in inputs_to_run
317319
]
318320
print(f"Run {run_iteration} with input shapes: {input_shapes}")
321+
=======
322+
>>>>>>> 6e4c57717 ([Backend Tester] Add tensor error statistic reporting)
319323

320324
# Reference output (and quantization scale)
321325
(
@@ -328,13 +332,25 @@ def run_method_and_compare_outputs(
328332
# Output from running artifact at stage
329333
stage_output = self.stages[stage].run_artifact(inputs_to_run)
330334
self._compare_outputs(
331-
reference_output, stage_output, quantization_scale, atol, rtol, qtol
335+
reference_output,
336+
stage_output,
337+
quantization_scale,
338+
atol,
339+
rtol,
340+
qtol,
341+
statistics_callback,
332342
)
333343

334344
return self
335345

336346
@staticmethod
337-
def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
347+
def _assert_outputs_equal(
348+
model_output,
349+
ref_output,
350+
atol=1e-03,
351+
rtol=1e-03,
352+
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
353+
):
338354
"""
339355
Helper testing function that asserts that the model output and the reference output
340356
are equal with some tolerance. Due to numerical differences between eager mode and
@@ -349,6 +365,11 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
349365
for i in range(len(model_output)):
350366
model = model_output[i]
351367
ref = ref_output[i]
368+
369+
error_stats = ErrorStatistics.from_tensors(model, ref)
370+
if statistics_callback is not None:
371+
statistics_callback(error_stats)
372+
352373
assert (
353374
ref.shape == model.shape
354375
), f"Output {i} shape {model.shape} does not match reference output shape {ref.shape}"
@@ -386,6 +407,7 @@ def _compare_outputs(
386407
atol=1e-03,
387408
rtol=1e-03,
388409
qtol=0,
410+
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
389411
):
390412
"""
391413
Compares the original of the original nn module with the output of the generated artifact.
@@ -408,6 +430,7 @@ def _compare_outputs(
408430
reference_output,
409431
atol=atol,
410432
rtol=rtol,
433+
statistics_callback=statistics_callback,
411434
)
412435

413436
@staticmethod
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import unittest
2+
3+
import torch
4+
from executorch.backends.test.harness.error_statistics import ErrorStatistics
5+
6+
7+
class ErrorStatisticsTests(unittest.TestCase):
8+
def test_error_stats_simple(self):
9+
tensor1 = torch.tensor([1, 2, 3, 4])
10+
tensor2 = torch.tensor([2, 2, 2, 5])
11+
12+
error_stats = ErrorStatistics.from_tensors(tensor1, tensor2)
13+
14+
# Check actual tensor statistics
15+
self.assertEqual(error_stats.actual_stats.shape, torch.Size([4]))
16+
self.assertEqual(error_stats.actual_stats.numel, 4)
17+
self.assertEqual(error_stats.actual_stats.median, 2.5)
18+
self.assertEqual(error_stats.actual_stats.mean, 2.5)
19+
self.assertEqual(error_stats.actual_stats.max, 4)
20+
self.assertEqual(error_stats.actual_stats.min, 1)
21+
22+
# Check reference tensor statistics
23+
self.assertEqual(error_stats.reference_stats.shape, torch.Size([4]))
24+
self.assertEqual(error_stats.reference_stats.numel, 4)
25+
self.assertEqual(error_stats.reference_stats.median, 2.0)
26+
self.assertEqual(error_stats.reference_stats.mean, 2.75)
27+
self.assertEqual(error_stats.reference_stats.max, 5)
28+
self.assertEqual(error_stats.reference_stats.min, 2)
29+
30+
# Check error statistics
31+
self.assertAlmostEqual(error_stats.error_l2_norm, 1.732, places=3)
32+
self.assertEqual(error_stats.error_mae, 0.75)
33+
self.assertEqual(error_stats.error_max, 1.0)
34+
self.assertEqual(error_stats.error_msd, -0.25)
35+
self.assertAlmostEqual(error_stats.sqnr, 10.0, places=3)
36+
37+
def test_error_stats_different_shapes(self):
38+
# Create tensors with different shapes
39+
tensor1 = torch.tensor([1, 2, 3, 4])
40+
tensor2 = torch.tensor([[2, 3], [4, 5]])
41+
42+
error_stats = ErrorStatistics.from_tensors(tensor1, tensor2)
43+
44+
# Check actual tensor statistics
45+
self.assertEqual(error_stats.actual_stats.shape, torch.Size([4]))
46+
self.assertEqual(error_stats.actual_stats.numel, 4)
47+
self.assertEqual(error_stats.actual_stats.median, 2.5)
48+
self.assertEqual(error_stats.actual_stats.mean, 2.5)
49+
self.assertEqual(error_stats.actual_stats.max, 4)
50+
self.assertEqual(error_stats.actual_stats.min, 1)
51+
52+
# Check reference tensor statistics
53+
self.assertEqual(error_stats.reference_stats.shape, torch.Size([2, 2]))
54+
self.assertEqual(error_stats.reference_stats.numel, 4)
55+
self.assertEqual(error_stats.reference_stats.median, 3.5)
56+
self.assertEqual(error_stats.reference_stats.mean, 3.5)
57+
self.assertEqual(error_stats.reference_stats.max, 5)
58+
self.assertEqual(error_stats.reference_stats.min, 2)
59+
60+
# Check that all error values are None when shapes differ
61+
self.assertIsNone(error_stats.error_l2_norm)
62+
self.assertIsNone(error_stats.error_mae)
63+
self.assertIsNone(error_stats.error_max)
64+
self.assertIsNone(error_stats.error_msd)
65+
self.assertIsNone(error_stats.sqnr)

backends/test/suite/reporting.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from functools import reduce
66
from typing import TextIO
77

8+
from executorch.backends.test.harness.error_statistics import ErrorStatistics
9+
810

911
class TestResult(IntEnum):
1012
"""Represents the result of a test case run, indicating success or a specific failure reason."""
@@ -100,6 +102,12 @@ class TestCaseSummary:
100102
error: Exception | None
101103
""" The Python exception object, if any. """
102104

105+
tensor_error_statistics: list[ErrorStatistics]
106+
"""
107+
Statistics about the error between the backend and reference outputs. Each element of this list corresponds to
108+
a single output tensor.
109+
"""
110+
103111

104112
class TestSessionState:
105113
test_case_summaries: list[TestCaseSummary]
@@ -197,6 +205,21 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
197205
)
198206
field_names += (s.capitalize() for s in param_names)
199207

208+
# Add tensor error statistic field names for each output index.
209+
max_outputs = max(
210+
len(s.tensor_error_statistics) for s in summary.test_case_summaries
211+
)
212+
for i in range(max_outputs):
213+
field_names.extend(
214+
[
215+
f"Output {i} Error Max",
216+
f"Output {i} Error MAE",
217+
f"Output {i} Error MSD",
218+
f"Output {i} Error L2",
219+
f"Output {i} SQNR",
220+
]
221+
)
222+
200223
writer = csv.DictWriter(output, field_names)
201224
writer.writeheader()
202225

@@ -210,4 +233,12 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
210233
}
211234
if record.params is not None:
212235
row.update({k.capitalize(): v for k, v in record.params.items()})
236+
237+
for output_idx, error_stats in enumerate(record.tensor_error_statistics):
238+
row[f"Output {output_idx} Error Max"] = error_stats.error_max
239+
row[f"Output {output_idx} Error MAE"] = error_stats.error_mae
240+
row[f"Output {output_idx} Error MSD"] = error_stats.error_msd
241+
row[f"Output {output_idx} Error L2"] = error_stats.error_l2_norm
242+
row[f"Output {output_idx} SQNR"] = error_stats.sqnr
243+
213244
writer.writerow(row)

backends/test/suite/runner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import torch
99

10+
from executorch.backends.test.harness.error_statistics import ErrorStatistics
1011
from executorch.backends.test.harness.stages import StageType
1112
from executorch.backends.test.suite.discovery import discover_tests, TestFilter
1213
from executorch.backends.test.suite.flow import TestFlow
@@ -42,6 +43,8 @@ def run_test( # noqa: C901
4243
and reporting.
4344
"""
4445

46+
error_statistics: list[ErrorStatistics] = []
47+
4548
# Helper method to construct the summary.
4649
def build_result(
4750
result: TestResult, error: Exception | None = None
@@ -54,6 +57,7 @@ def build_result(
5457
params=params,
5558
result=result,
5659
error=error,
60+
tensor_error_statistics=error_statistics,
5761
)
5862

5963
# Ensure the model can run in eager mode.
@@ -109,6 +113,7 @@ def build_result(
109113
try:
110114
tester.run_method_and_compare_outputs(
111115
inputs=None if generate_random_test_inputs else inputs
116+
statistics_callback=lambda stats: error_statistics.append(stats)
112117
)
113118
except AssertionError as e:
114119
return build_result(TestResult.OUTPUT_MISMATCH_FAIL, e)

0 commit comments

Comments
 (0)