diff --git a/backends/test/harness/error_statistics.py b/backends/test/harness/error_statistics.py
new file mode 100644
index 00000000000..db0ab7e3dd0
--- /dev/null
+++ b/backends/test/harness/error_statistics.py
@@ -0,0 +1,99 @@
+from dataclasses import dataclass
+
+import torch
+from torch.ao.ns.fx.utils import compute_sqnr
+
+
+@dataclass
+class TensorStatistics:
+    """Contains summary statistics for a tensor."""
+
+    shape: torch.Size
+    """ The shape of the tensor. """
+
+    numel: int
+    """ The number of elements in the tensor. """
+
+    median: float
+    """ The median of the tensor. """
+
+    mean: float
+    """ The mean of the tensor. """
+
+    max: torch.types.Number
+    """ The maximum element of the tensor. """
+
+    min: torch.types.Number
+    """ The minimum element of the tensor. """
+
+    @classmethod
+    def from_tensor(cls, tensor: torch.Tensor) -> "TensorStatistics":
+        """Creates a TensorStatistics object from a tensor."""
+        flattened = torch.flatten(tensor)
+        return cls(
+            shape=tensor.shape,
+            numel=tensor.numel(),
+            median=torch.quantile(flattened, q=0.5).item(),
+            mean=flattened.mean().item(),
+            max=flattened.max().item(),
+            min=flattened.min().item(),
+        )
+
+
+@dataclass
+class ErrorStatistics:
+    """Contains statistics derived from the difference of two tensors."""
+
+    reference_stats: TensorStatistics
+    """ Statistics for the reference tensor. """
+
+    actual_stats: TensorStatistics
+    """ Statistics for the actual tensor. """
+
+    error_l2_norm: float | None
+    """ The L2 norm of the error between the actual and reference tensor. """
+
+    error_mae: float | None
+    """ The mean absolute error between the actual and reference tensor. """
+
+    error_max: float | None
+    """ The maximum absolute elementwise error between the actual and reference tensor. """
+
+    error_msd: float | None
+    """ The mean signed deviation between the actual and reference tensor. """
+
+    sqnr: float | None
+    """ The signal-to-quantization-noise ratio between the actual and reference tensor. """
+
+    @classmethod
+    def from_tensors(
+        cls, actual: torch.Tensor, reference: torch.Tensor
+    ) -> "ErrorStatistics":
+        """Creates an ErrorStatistics object from two tensors."""
+        actual = actual.to(torch.float64)
+        reference = reference.to(torch.float64)
+
+        if actual.shape != reference.shape:
+            return cls(
+                reference_stats=TensorStatistics.from_tensor(reference),
+                actual_stats=TensorStatistics.from_tensor(actual),
+                error_l2_norm=None,
+                error_mae=None,
+                error_max=None,
+                error_msd=None,
+                sqnr=None,
+            )
+
+        error = actual - reference
+        flat_error = torch.flatten(error)
+
+        return cls(
+            reference_stats=TensorStatistics.from_tensor(reference),
+            actual_stats=TensorStatistics.from_tensor(actual),
+            error_l2_norm=torch.linalg.norm(flat_error).item(),
+            error_mae=torch.mean(torch.abs(flat_error)).item(),
+            error_max=torch.max(torch.abs(flat_error)).item(),
+            error_msd=torch.mean(flat_error).item(),
+            # Torch sqnr implementation requires float32 due to decorator logic
+            sqnr=compute_sqnr(actual.to(torch.float), reference.to(torch.float)).item(),
+        )
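As a rough usage sketch of the new helper (illustrative only; the tensor values are invented and `actual`/`reference` stand in for a backend output and an eager-mode reference), the API can be exercised like this:

    import torch

    from executorch.backends.test.harness.error_statistics import ErrorStatistics

    reference = torch.tensor([0.10, 0.25, 0.70])  # hypothetical eager-mode output
    actual = torch.tensor([0.11, 0.24, 0.69])     # hypothetical backend output

    stats = ErrorStatistics.from_tensors(actual=actual, reference=reference)
    print(stats.error_max, stats.error_mae, stats.sqnr)

    # Mismatched shapes do not raise; every error field is simply None.
    assert ErrorStatistics.from_tensors(torch.zeros(2), torch.zeros(3)).error_mae is None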
""" + + @classmethod + def from_tensors( + cls, actual: torch.Tensor, reference: torch.Tensor + ) -> "ErrorStatistics": + """Creates an ErrorStatistics object from two tensors.""" + actual = actual.to(torch.float64) + reference = reference.to(torch.float64) + + if actual.shape != reference.shape: + return cls( + reference_stats=TensorStatistics.from_tensor(reference), + actual_stats=TensorStatistics.from_tensor(actual), + error_l2_norm=None, + error_mae=None, + error_max=None, + error_msd=None, + sqnr=None, + ) + + error = actual - reference + flat_error = torch.flatten(error) + + return cls( + reference_stats=TensorStatistics.from_tensor(reference), + actual_stats=TensorStatistics.from_tensor(actual), + error_l2_norm=torch.linalg.norm(flat_error).item(), + error_mae=torch.mean(torch.abs(flat_error)).item(), + error_max=torch.max(torch.abs(flat_error)).item(), + error_msd=torch.mean(flat_error).item(), + # Torch sqnr implementation requires float32 due to decorator logic + sqnr=compute_sqnr(actual.to(torch.float), reference.to(torch.float)).item(), + ) diff --git a/backends/test/harness/tester.py b/backends/test/harness/tester.py index 7019b734290..2782fc7bb29 100644 --- a/backends/test/harness/tester.py +++ b/backends/test/harness/tester.py @@ -4,6 +4,7 @@ import torch +from executorch.backends.test.harness.error_statistics import ErrorStatistics from executorch.backends.test.harness.stages import ( Export, Partition, @@ -302,20 +303,15 @@ def run_method_and_compare_outputs( atol=1e-03, rtol=1e-03, qtol=0, + statistics_callback: Callable[[ErrorStatistics], None] | None = None, ): number_of_runs = 1 if inputs is not None else num_runs reference_stage = self.stages[StageType.EXPORT] stage = stage or self.cur - print(f"Comparing Stage {stage} with Stage {reference_stage}") - for run_iteration in range(number_of_runs): + for _ in range(number_of_runs): inputs_to_run = inputs if inputs else next(self.generate_random_inputs()) - input_shapes = [ - generated_input.shape if hasattr(generated_input, "shape") else None - for generated_input in inputs_to_run - ] - print(f"Run {run_iteration} with input shapes: {input_shapes}") # Reference output (and quantization scale) ( @@ -328,13 +324,25 @@ def run_method_and_compare_outputs( # Output from running artifact at stage stage_output = self.stages[stage].run_artifact(inputs_to_run) self._compare_outputs( - reference_output, stage_output, quantization_scale, atol, rtol, qtol + reference_output, + stage_output, + quantization_scale, + atol, + rtol, + qtol, + statistics_callback, ) return self @staticmethod - def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03): + def _assert_outputs_equal( + model_output, + ref_output, + atol=1e-03, + rtol=1e-03, + statistics_callback: Callable[[ErrorStatistics], None] | None = None, + ): """ Helper testing function that asserts that the model output and the reference output are equal with some tolerance. 
diff --git a/backends/test/harness/tests/test_error_statistics.py b/backends/test/harness/tests/test_error_statistics.py
new file mode 100644
index 00000000000..fdff9c75b00
--- /dev/null
+++ b/backends/test/harness/tests/test_error_statistics.py
@@ -0,0 +1,65 @@
+import unittest
+
+import torch
+from executorch.backends.test.harness.error_statistics import ErrorStatistics
+
+
+class ErrorStatisticsTests(unittest.TestCase):
+    def test_error_stats_simple(self):
+        tensor1 = torch.tensor([1, 2, 3, 4])
+        tensor2 = torch.tensor([2, 2, 2, 5])
+
+        error_stats = ErrorStatistics.from_tensors(tensor1, tensor2)
+
+        # Check actual tensor statistics
+        self.assertEqual(error_stats.actual_stats.shape, torch.Size([4]))
+        self.assertEqual(error_stats.actual_stats.numel, 4)
+        self.assertEqual(error_stats.actual_stats.median, 2.5)
+        self.assertEqual(error_stats.actual_stats.mean, 2.5)
+        self.assertEqual(error_stats.actual_stats.max, 4)
+        self.assertEqual(error_stats.actual_stats.min, 1)
+
+        # Check reference tensor statistics
+        self.assertEqual(error_stats.reference_stats.shape, torch.Size([4]))
+        self.assertEqual(error_stats.reference_stats.numel, 4)
+        self.assertEqual(error_stats.reference_stats.median, 2.0)
+        self.assertEqual(error_stats.reference_stats.mean, 2.75)
+        self.assertEqual(error_stats.reference_stats.max, 5)
+        self.assertEqual(error_stats.reference_stats.min, 2)
+
+        # Check error statistics
+        self.assertAlmostEqual(error_stats.error_l2_norm, 1.732, places=3)
+        self.assertEqual(error_stats.error_mae, 0.75)
+        self.assertEqual(error_stats.error_max, 1.0)
+        self.assertEqual(error_stats.error_msd, -0.25)
+        self.assertAlmostEqual(error_stats.sqnr, 10.0, places=3)
+
+    def test_error_stats_different_shapes(self):
+        # Create tensors with different shapes
+        tensor1 = torch.tensor([1, 2, 3, 4])
+        tensor2 = torch.tensor([[2, 3], [4, 5]])
+
+        error_stats = ErrorStatistics.from_tensors(tensor1, tensor2)
+
+        # Check actual tensor statistics
+        self.assertEqual(error_stats.actual_stats.shape, torch.Size([4]))
+        self.assertEqual(error_stats.actual_stats.numel, 4)
+        self.assertEqual(error_stats.actual_stats.median, 2.5)
+        self.assertEqual(error_stats.actual_stats.mean, 2.5)
+        self.assertEqual(error_stats.actual_stats.max, 4)
+        self.assertEqual(error_stats.actual_stats.min, 1)
+
+        # Check reference tensor statistics
+        self.assertEqual(error_stats.reference_stats.shape, torch.Size([2, 2]))
+        self.assertEqual(error_stats.reference_stats.numel, 4)
+        self.assertEqual(error_stats.reference_stats.median, 3.5)
+        self.assertEqual(error_stats.reference_stats.mean, 3.5)
+        self.assertEqual(error_stats.reference_stats.max, 5)
+        self.assertEqual(error_stats.reference_stats.min, 2)
+
+        # Check that all error values are None when shapes differ
+        self.assertIsNone(error_stats.error_l2_norm)
+        self.assertIsNone(error_stats.error_mae)
+        self.assertIsNone(error_stats.error_max)
+        self.assertIsNone(error_stats.error_msd)
+        self.assertIsNone(error_stats.sqnr)
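For reference, the expected constants in `test_error_stats_simple` follow directly from the definitions above (a worked check, assuming `compute_sqnr(x, y)` computes `20 * log10(||x|| / ||x - y||)`):

    # error = actual - reference = [1, 2, 3, 4] - [2, 2, 2, 5] = [-1, 0, 1, -1]
    # l2    = sqrt(1 + 0 + 1 + 1) = sqrt(3) ~= 1.732
    # mae   = (1 + 0 + 1 + 1) / 4 = 0.75    max = 1.0    msd = (-1 + 0 + 1 - 1) / 4 = -0.25
    # sqnr  = 20 * log10(sqrt(30) / sqrt(3)) = 20 * log10(sqrt(10)) = 10.0 dB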
diff --git a/backends/test/suite/reporting.py b/backends/test/suite/reporting.py
index 06c8ea952db..15c19bf7c8e 100644
--- a/backends/test/suite/reporting.py
+++ b/backends/test/suite/reporting.py
@@ -5,6 +5,8 @@
 from functools import reduce
 from typing import TextIO
 
+from executorch.backends.test.harness.error_statistics import ErrorStatistics
+
 
 class TestResult(IntEnum):
     """Represents the result of a test case run, indicating success or a specific failure reason."""
@@ -100,6 +102,12 @@ class TestCaseSummary:
     error: Exception | None
     """ The Python exception object, if any. """
 
+    tensor_error_statistics: list[ErrorStatistics]
+    """
+    Statistics about the error between the backend and reference outputs. Each element of this list corresponds to
+    a single output tensor.
+    """
+
 
 class TestSessionState:
     test_case_summaries: list[TestCaseSummary]
@@ -197,6 +205,21 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
         )
         field_names += (s.capitalize() for s in param_names)
 
+    # Add tensor error statistic field names for each output index.
+    max_outputs = max(
+        len(s.tensor_error_statistics) for s in summary.test_case_summaries
+    )
+    for i in range(max_outputs):
+        field_names.extend(
+            [
+                f"Output {i} Error Max",
+                f"Output {i} Error MAE",
+                f"Output {i} Error MSD",
+                f"Output {i} Error L2",
+                f"Output {i} SQNR",
+            ]
+        )
+
     writer = csv.DictWriter(output, field_names)
     writer.writeheader()
 
@@ -210,4 +233,12 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
         }
         if record.params is not None:
             row.update({k.capitalize(): v for k, v in record.params.items()})
+
+        for output_idx, error_stats in enumerate(record.tensor_error_statistics):
+            row[f"Output {output_idx} Error Max"] = error_stats.error_max
+            row[f"Output {output_idx} Error MAE"] = error_stats.error_mae
+            row[f"Output {output_idx} Error MSD"] = error_stats.error_msd
+            row[f"Output {output_idx} Error L2"] = error_stats.error_l2_norm
+            row[f"Output {output_idx} SQNR"] = error_stats.sqnr
+
         writer.writerow(row)
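With the reporting change, each output tensor contributes five additional CSV columns. A hypothetical single-output slice of a report row (column names from the code above, values invented) would look roughly like:

    Output 0 Error Max,Output 0 Error MAE,Output 0 Error MSD,Output 0 Error L2,Output 0 SQNR
    0.0021,0.0004,-0.0001,0.0133,38.2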
diff --git a/backends/test/suite/runner.py b/backends/test/suite/runner.py
index 59c4c4a33a4..6655cf9653b 100644
--- a/backends/test/suite/runner.py
+++ b/backends/test/suite/runner.py
@@ -7,6 +7,7 @@
 
 import torch
 
+from executorch.backends.test.harness.error_statistics import ErrorStatistics
 from executorch.backends.test.harness.stages import StageType
 from executorch.backends.test.suite.discovery import discover_tests, TestFilter
 from executorch.backends.test.suite.flow import TestFlow
@@ -42,6 +43,8 @@ def run_test(  # noqa: C901
     and reporting.
     """
 
+    error_statistics: list[ErrorStatistics] = []
+
     # Helper method to construct the summary.
     def build_result(
         result: TestResult, error: Exception | None = None
@@ -54,6 +57,7 @@ def build_result(
             params=params,
             result=result,
             error=error,
+            tensor_error_statistics=error_statistics,
         )
 
     # Ensure the model can run in eager mode.
@@ -108,7 +112,8 @@ def build_result(
     # AssertionErrors to catch output mismatches, but this might catch more than that.
     try:
         tester.run_method_and_compare_outputs(
-            inputs=None if generate_random_test_inputs else inputs
+            inputs=None if generate_random_test_inputs else inputs,
+            statistics_callback=lambda stats: error_statistics.append(stats),
         )
     except AssertionError as e:
         return build_result(TestResult.OUTPUT_MISMATCH_FAIL, e)