Skip to content

Commit ac768f7

Browse files
committed
[Backend Tester] Add tensor error statistic reporting
ghstack-source-id: 9eb6b0a ghstack-comment-id: 3112003831 Pull-Request: #12809
1 parent 6d110fd commit ac768f7

File tree

4 files changed

+144
-9
lines changed

4 files changed

+144
-9
lines changed
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from dataclasses import dataclass
2+
from torch.ao.ns.fx.utils import compute_sqnr
3+
4+
import torch
5+
6+
@dataclass
7+
class TensorStatistics:
8+
""" Contains summary statistics for a tensor. """
9+
10+
shape: torch.Size
11+
""" The shape of the tensor. """
12+
13+
numel: int
14+
""" The number of elements in the tensor. """
15+
16+
median: float
17+
""" The median of the tensor. """
18+
19+
mean: float
20+
""" The mean of the tensor. """
21+
22+
max: torch.types.Number
23+
""" The maximum element of the tensor. """
24+
25+
min: torch.types.Number
26+
""" The minimum element of the tensor. """
27+
28+
@classmethod
29+
def from_tensor(cls, tensor: torch.Tensor) -> "TensorStatistics":
30+
""" Creates a TensorStatistics object from a tensor. """
31+
flattened = torch.flatten(tensor)
32+
return cls(
33+
shape=tensor.shape,
34+
numel=tensor.numel(),
35+
median=flattened.median().item(),
36+
mean=flattened.mean().item(),
37+
max=flattened.max().item(),
38+
min=flattened.min().item(),
39+
)
40+
41+
@dataclass
42+
class ErrorStatistics:
43+
""" Contains statistics derived from the difference of two tensors. """
44+
45+
reference_stats: TensorStatistics
46+
""" Statistics for the reference tensor. """
47+
48+
actual_stats: TensorStatistics
49+
""" Statistics for the actual tensor. """
50+
51+
error_l2_norm: float | None
52+
""" The L2 norm of the error between the actual and reference tensor. """
53+
54+
error_mae: float | None
55+
""" The mean absolute error between the actual and reference tensor. """
56+
57+
error_max: float | None
58+
""" The maximum absolute elementwise error between the actual and reference tensor. """
59+
60+
error_msd: float | None
61+
""" The mean signed deviation between the actual and reference tensor. """
62+
63+
sqnr: float | None
64+
""" The signal-to-quantization-noise ratio between the actual and reference tensor. """
65+
66+
@classmethod
67+
def from_tensors(cls, actual: torch.Tensor, reference: torch.Tensor) -> "ErrorStatistics":
68+
""" Creates an ErrorStatistics object from two tensors. """
69+
if actual.shape != reference.shape:
70+
return cls(
71+
reference_stats=TensorStatistics.from_tensor(reference),
72+
actual_stats=TensorStatistics.from_tensor(actual),
73+
error_l2_norm=None,
74+
error_mae=None,
75+
error_max = None,
76+
error_msd=None,
77+
sqnr=None,
78+
)
79+
80+
error = actual.to(torch.float64) - reference.to(torch.float64)
81+
flat_error = torch.flatten(error)
82+
83+
return cls(
84+
reference_stats=TensorStatistics.from_tensor(reference),
85+
actual_stats=TensorStatistics.from_tensor(actual),
86+
error_l2_norm=torch.linalg.norm(flat_error).item(),
87+
error_mae=torch.mean(torch.abs(flat_error)).item(),
88+
error_max=torch.max(torch.abs(flat_error)).item(),
89+
error_msd=torch.mean(flat_error).item(),
90+
sqnr=compute_sqnr(actual, reference).item()
91+
)

backends/test/harness/tester.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import torch
66

7+
from executorch.backends.test.harness.error_statistics import ErrorStatistics
78
from executorch.backends.test.harness.stages import (
89
Export,
910
Partition,
@@ -302,17 +303,15 @@ def run_method_and_compare_outputs(
302303
atol=1e-03,
303304
rtol=1e-03,
304305
qtol=0,
306+
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
305307
):
306308
number_of_runs = 1 if inputs is not None else num_runs
307309
reference_stage = self.stages[StageType.EXPORT]
308310

309311
stage = stage or self.cur
310312

311-
print(f"Comparing Stage {stage} with Stage {reference_stage}")
312-
for run_iteration in range(number_of_runs):
313+
for _ in range(number_of_runs):
313314
inputs_to_run = inputs if inputs else next(self.generate_random_inputs())
314-
input_shapes = [generated_input.shape for generated_input in inputs_to_run]
315-
print(f"Run {run_iteration} with input shapes: {input_shapes}")
316315

317316
# Reference output (and quantization scale)
318317
(
@@ -325,13 +324,19 @@ def run_method_and_compare_outputs(
325324
# Output from running artifact at stage
326325
stage_output = self.stages[stage].run_artifact(inputs_to_run)
327326
self._compare_outputs(
328-
reference_output, stage_output, quantization_scale, atol, rtol, qtol
327+
reference_output, stage_output, quantization_scale, atol, rtol, qtol, statistics_callback
329328
)
330329

331330
return self
332331

333332
@staticmethod
334-
def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
333+
def _assert_outputs_equal(
334+
model_output,
335+
ref_output,
336+
atol=1e-03,
337+
rtol=1e-03,
338+
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
339+
):
335340
"""
336341
Helper testing function that asserts that the model output and the reference output
337342
are equal with some tolerance. Due to numerical differences between eager mode and
@@ -346,6 +351,11 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
346351
for i in range(len(model_output)):
347352
model = model_output[i]
348353
ref = ref_output[i]
354+
355+
error_stats = ErrorStatistics.from_tensors(model, ref)
356+
if statistics_callback is not None:
357+
statistics_callback(error_stats)
358+
349359
assert (
350360
ref.shape == model.shape
351361
), f"Output {i} shape {model.shape} does not match reference output shape {ref.shape}"
@@ -383,6 +393,7 @@ def _compare_outputs(
383393
atol=1e-03,
384394
rtol=1e-03,
385395
qtol=0,
396+
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
386397
):
387398
"""
388399
Compares the original of the original nn module with the output of the generated artifact.
@@ -399,12 +410,13 @@ def _compare_outputs(
399410
# atol by qtol quant units.
400411
if quantization_scale is not None:
401412
atol += quantization_scale * qtol
402-
413+
403414
Tester._assert_outputs_equal(
404415
stage_output,
405416
reference_output,
406417
atol=atol,
407418
rtol=rtol,
419+
statistics_callback=statistics_callback,
408420
)
409421

410422
@staticmethod

backends/test/suite/reporting.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
from dataclasses import dataclass
33
from enum import IntEnum
44
from functools import reduce
5-
from re import A
65
from typing import TextIO
76

87
import csv
98

9+
from executorch.backends.test.harness.error_statistics import ErrorStatistics
10+
1011
class TestResult(IntEnum):
1112
"""Represents the result of a test case run, indicating success or a specific failure reason."""
1213

@@ -100,6 +101,12 @@ class TestCaseSummary:
100101

101102
error: Exception | None
102103
""" The Python exception object, if any. """
104+
105+
tensor_error_statistics: list[ErrorStatistics]
106+
"""
107+
Statistics about the error between the backend and reference outputs. Each element of this list corresponds to
108+
a single output tensor.
109+
"""
103110

104111

105112
class TestSessionState:
@@ -193,6 +200,17 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
193200
)
194201
field_names += (s.capitalize() for s in param_names)
195202

203+
# Add tensor error statistic field names for each output index.
204+
max_outputs = max(len(s.tensor_error_statistics) for s in summary.test_case_summaries)
205+
for i in range(max_outputs):
206+
field_names.extend([
207+
f"Output {i} Error Max",
208+
f"Output {i} Error MAE",
209+
f"Output {i} Error MSD",
210+
f"Output {i} Error L2",
211+
f"Output {i} SQNR",
212+
])
213+
196214
writer = csv.DictWriter(output, field_names)
197215
writer.writeheader()
198216

@@ -208,4 +226,12 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
208226
row.update({
209227
k.capitalize(): v for k, v in record.params.items()
210228
})
229+
230+
for output_idx, error_stats in enumerate(record.tensor_error_statistics):
231+
row[f"Output {output_idx} Error Max"] = error_stats.error_max
232+
row[f"Output {output_idx} Error MAE"] = error_stats.error_mae
233+
row[f"Output {output_idx} Error MSD"] = error_stats.error_msd
234+
row[f"Output {output_idx} Error L2"] = error_stats.error_l2_norm
235+
row[f"Output {output_idx} SQNR"] = error_stats.sqnr
236+
211237
writer.writerow(row)

backends/test/suite/runner.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import torch
99

10+
from executorch.backends.test.harness.error_statistics import ErrorStatistics
1011
from executorch.backends.test.harness.stages import StageType
1112
from executorch.backends.test.suite.discovery import discover_tests, TestFilter
1213
from executorch.backends.test.suite.flow import TestFlow
@@ -40,6 +41,8 @@ def run_test( # noqa: C901
4041
Top-level test run function for a model, input set, and tester. Handles test execution
4142
and reporting.
4243
"""
44+
45+
error_statistics: list[ErrorStatistics] = []
4346

4447
# Helper method to construct the summary.
4548
def build_result(
@@ -53,6 +56,7 @@ def build_result(
5356
params=params,
5457
result=result,
5558
error=error,
59+
tensor_error_statistics=error_statistics,
5660
)
5761

5862
# Ensure the model can run in eager mode.
@@ -106,7 +110,9 @@ def build_result(
106110
# the cause of a failure in run_method_and_compare_outputs. We can look for
107111
# AssertionErrors to catch output mismatches, but this might catch more than that.
108112
try:
109-
tester.run_method_and_compare_outputs()
113+
tester.run_method_and_compare_outputs(
114+
statistics_callback = lambda stats: error_statistics.append(stats)
115+
)
110116
except AssertionError as e:
111117
return build_result(TestResult.OUTPUT_MISMATCH_FAIL, e)
112118
except Exception as e:

0 commit comments

Comments
 (0)