[Backend Tester] Add tensor error statistic reporting #12809

Open · wants to merge 5 commits into base: gh/GregoryComer/88/head
99 changes: 99 additions & 0 deletions backends/test/harness/error_statistics.py
@@ -0,0 +1,99 @@
from dataclasses import dataclass

import torch
from torch.ao.ns.fx.utils import compute_sqnr


@dataclass
class TensorStatistics:
"""Contains summary statistics for a tensor."""

shape: torch.Size
""" The shape of the tensor. """

numel: int
""" The number of elements in the tensor. """

median: float
""" The median of the tensor. """

mean: float
""" The mean of the tensor. """

max: torch.types.Number
""" The maximum element of the tensor. """

min: torch.types.Number
""" The minimum element of the tensor. """

@classmethod
def from_tensor(cls, tensor: torch.Tensor) -> "TensorStatistics":
"""Creates a TensorStatistics object from a tensor."""
flattened = torch.flatten(tensor)
return cls(
shape=tensor.shape,
numel=tensor.numel(),
median=torch.quantile(flattened, q=0.5).item(),
mean=flattened.mean().item(),
max=flattened.max().item(),
min=flattened.min().item(),
)


@dataclass
class ErrorStatistics:
"""Contains statistics derived from the difference of two tensors."""

reference_stats: TensorStatistics
""" Statistics for the reference tensor. """

actual_stats: TensorStatistics
""" Statistics for the actual tensor. """

error_l2_norm: float | None
""" The L2 norm of the error between the actual and reference tensor. """

error_mae: float | None
""" The mean absolute error between the actual and reference tensor. """

error_max: float | None
""" The maximum absolute elementwise error between the actual and reference tensor. """

error_msd: float | None
""" The mean signed deviation between the actual and reference tensor. """

sqnr: float | None
""" The signal-to-quantization-noise ratio between the actual and reference tensor. """

@classmethod
def from_tensors(
cls, actual: torch.Tensor, reference: torch.Tensor
) -> "ErrorStatistics":
"""Creates an ErrorStatistics object from two tensors."""
actual = actual.to(torch.float64)
reference = reference.to(torch.float64)

if actual.shape != reference.shape:
return cls(
reference_stats=TensorStatistics.from_tensor(reference),
actual_stats=TensorStatistics.from_tensor(actual),
error_l2_norm=None,
error_mae=None,
error_max=None,
error_msd=None,
sqnr=None,
)

error = actual - reference
flat_error = torch.flatten(error)

return cls(
reference_stats=TensorStatistics.from_tensor(reference),
actual_stats=TensorStatistics.from_tensor(actual),
error_l2_norm=torch.linalg.norm(flat_error).item(),
error_mae=torch.mean(torch.abs(flat_error)).item(),
error_max=torch.max(torch.abs(flat_error)).item(),
error_msd=torch.mean(flat_error).item(),
# Torch sqnr implementation requires float32 due to decorator logic
sqnr=compute_sqnr(actual.to(torch.float), reference.to(torch.float)).item(),
)
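
For readers skimming the diff, a minimal standalone usage sketch of the new helpers; the example tensors are arbitrary illustration values, not taken from the test suite.

```python
import torch

from executorch.backends.test.harness.error_statistics import ErrorStatistics

# Arbitrary example tensors: a reference output and a slightly perturbed
# "backend" output of the same shape.
reference = torch.tensor([1.0, 2.0, 3.0, 4.0])
actual = torch.tensor([1.01, 1.98, 3.02, 4.00])

stats = ErrorStatistics.from_tensors(actual, reference)
print(stats.actual_stats.mean, stats.reference_stats.mean)
print(stats.error_mae, stats.error_max, stats.error_l2_norm, stats.sqnr)

# Mismatched shapes do not raise; the per-element error fields come back as None.
mismatched = ErrorStatistics.from_tensors(torch.zeros(2, 2), torch.zeros(4))
assert mismatched.error_mae is None
```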
33 changes: 24 additions & 9 deletions backends/test/harness/tester.py
@@ -4,6 +4,7 @@

import torch

from executorch.backends.test.harness.error_statistics import ErrorStatistics
from executorch.backends.test.harness.stages import (
Export,
Partition,
@@ -302,20 +303,15 @@ def run_method_and_compare_outputs(
atol=1e-03,
rtol=1e-03,
qtol=0,
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
Member Author: I'm not completely happy with the callback approach for exposing this, but I don't really have a better idea, since the tester relies on a builder-style pattern where it returns self to allow chaining. I'm open to suggestions.

Contributor: Just update the tester method?

Member Author: Can you clarify what you're thinking? Do you mean updating the tester's run_method_and_compare_outputs to directly return the error stats, and then updating all of the callers to not use it in a chained fashion? Or something else?

Contributor: Modify the existing method, keeping the outside behavior, but also add a get_comparison_stats(self) method on that stage or something? Or, if you want to pass a callback for flexibility, that's also fine.

):
number_of_runs = 1 if inputs is not None else num_runs
reference_stage = self.stages[StageType.EXPORT]

stage = stage or self.cur

print(f"Comparing Stage {stage} with Stage {reference_stage}")
for run_iteration in range(number_of_runs):
for _ in range(number_of_runs):
inputs_to_run = inputs if inputs else next(self.generate_random_inputs())
input_shapes = [
generated_input.shape if hasattr(generated_input, "shape") else None
for generated_input in inputs_to_run
]
print(f"Run {run_iteration} with input shapes: {input_shapes}")

# Reference output (and quantization scale)
(
@@ -328,13 +324,25 @@
# Output from running artifact at stage
stage_output = self.stages[stage].run_artifact(inputs_to_run)
self._compare_outputs(
reference_output, stage_output, quantization_scale, atol, rtol, qtol
reference_output,
stage_output,
quantization_scale,
atol,
rtol,
qtol,
statistics_callback,
)

return self

@staticmethod
def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
def _assert_outputs_equal(
model_output,
ref_output,
atol=1e-03,
rtol=1e-03,
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
):
"""
Helper testing function that asserts that the model output and the reference output
are equal with some tolerance. Due to numerical differences between eager mode and
@@ -349,6 +357,11 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
for i in range(len(model_output)):
model = model_output[i]
ref = ref_output[i]

error_stats = ErrorStatistics.from_tensors(model, ref)
if statistics_callback is not None:
statistics_callback(error_stats)

assert (
ref.shape == model.shape
), f"Output {i} shape {model.shape} does not match reference output shape {ref.shape}"
@@ -386,6 +399,7 @@ def _compare_outputs(
atol=1e-03,
rtol=1e-03,
qtol=0,
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
):
"""
Compares the output of the original nn module with the output of the generated artifact.
@@ -408,6 +422,7 @@
reference_output,
atol=atol,
rtol=rtol,
statistics_callback=statistics_callback,
)

@staticmethod
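
Relating this to the callback discussion above: a caller can keep the builder-style chaining and still collect per-output statistics by passing a callback that appends to a local list, which is essentially what runner.py does further down. A minimal sketch, assuming tester is an already-constructed Tester instance; the get_comparison_stats() accessor floated in the thread is hypothetical and not part of this PR.

```python
from executorch.backends.test.harness.error_statistics import ErrorStatistics

collected: list[ErrorStatistics] = []

# `tester` is assumed to exist; the chain is unchanged because the method still
# returns self, and the callback fires once per compared output tensor.
tester.run_method_and_compare_outputs(statistics_callback=collected.append)

for idx, stats in enumerate(collected):
    print(f"Output {idx}: MAE={stats.error_mae}, max={stats.error_max}, SQNR={stats.sqnr}")
```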
65 changes: 65 additions & 0 deletions backends/test/harness/tests/test_error_statistics.py
@@ -0,0 +1,65 @@
import unittest

import torch
from executorch.backends.test.harness.error_statistics import ErrorStatistics


class ErrorStatisticsTests(unittest.TestCase):
def test_error_stats_simple(self):
tensor1 = torch.tensor([1, 2, 3, 4])
tensor2 = torch.tensor([2, 2, 2, 5])

error_stats = ErrorStatistics.from_tensors(tensor1, tensor2)

# Check actual tensor statistics
self.assertEqual(error_stats.actual_stats.shape, torch.Size([4]))
self.assertEqual(error_stats.actual_stats.numel, 4)
self.assertEqual(error_stats.actual_stats.median, 2.5)
self.assertEqual(error_stats.actual_stats.mean, 2.5)
self.assertEqual(error_stats.actual_stats.max, 4)
self.assertEqual(error_stats.actual_stats.min, 1)

# Check reference tensor statistics
self.assertEqual(error_stats.reference_stats.shape, torch.Size([4]))
self.assertEqual(error_stats.reference_stats.numel, 4)
self.assertEqual(error_stats.reference_stats.median, 2.0)
self.assertEqual(error_stats.reference_stats.mean, 2.75)
self.assertEqual(error_stats.reference_stats.max, 5)
self.assertEqual(error_stats.reference_stats.min, 2)

# Check error statistics
self.assertAlmostEqual(error_stats.error_l2_norm, 1.732, places=3)
self.assertEqual(error_stats.error_mae, 0.75)
self.assertEqual(error_stats.error_max, 1.0)
self.assertEqual(error_stats.error_msd, -0.25)
self.assertAlmostEqual(error_stats.sqnr, 10.0, places=3)

def test_error_stats_different_shapes(self):
# Create tensors with different shapes
tensor1 = torch.tensor([1, 2, 3, 4])
tensor2 = torch.tensor([[2, 3], [4, 5]])

error_stats = ErrorStatistics.from_tensors(tensor1, tensor2)

# Check actual tensor statistics
self.assertEqual(error_stats.actual_stats.shape, torch.Size([4]))
self.assertEqual(error_stats.actual_stats.numel, 4)
self.assertEqual(error_stats.actual_stats.median, 2.5)
self.assertEqual(error_stats.actual_stats.mean, 2.5)
self.assertEqual(error_stats.actual_stats.max, 4)
self.assertEqual(error_stats.actual_stats.min, 1)

# Check reference tensor statistics
self.assertEqual(error_stats.reference_stats.shape, torch.Size([2, 2]))
self.assertEqual(error_stats.reference_stats.numel, 4)
self.assertEqual(error_stats.reference_stats.median, 3.5)
self.assertEqual(error_stats.reference_stats.mean, 3.5)
self.assertEqual(error_stats.reference_stats.max, 5)
self.assertEqual(error_stats.reference_stats.min, 2)

# Check that all error values are None when shapes differ
self.assertIsNone(error_stats.error_l2_norm)
self.assertIsNone(error_stats.error_mae)
self.assertIsNone(error_stats.error_max)
self.assertIsNone(error_stats.error_msd)
self.assertIsNone(error_stats.sqnr)
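
As a sanity check on the expected SQNR of 10.0 dB in test_error_stats_simple: assuming compute_sqnr follows the usual 20·log10(‖signal‖ / ‖noise‖) definition with the actual tensor as the signal, the value can be reproduced by hand.

```python
import math

import torch

actual = torch.tensor([1.0, 2.0, 3.0, 4.0])
reference = torch.tensor([2.0, 2.0, 2.0, 5.0])

signal_norm = torch.linalg.norm(actual).item()             # sqrt(30)
noise_norm = torch.linalg.norm(actual - reference).item()  # sqrt(3), also the L2 error above
sqnr_db = 20 * math.log10(signal_norm / noise_norm)        # 20 * log10(sqrt(10)) = 10.0
print(sqnr_db)
```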
31 changes: 31 additions & 0 deletions backends/test/suite/reporting.py
@@ -5,6 +5,8 @@
from functools import reduce
from typing import TextIO

from executorch.backends.test.harness.error_statistics import ErrorStatistics


class TestResult(IntEnum):
"""Represents the result of a test case run, indicating success or a specific failure reason."""
@@ -100,6 +102,12 @@ class TestCaseSummary:
error: Exception | None
""" The Python exception object, if any. """

tensor_error_statistics: list[ErrorStatistics]
"""
Statistics about the error between the backend and reference outputs. Each element of this list corresponds to
a single output tensor.
"""


class TestSessionState:
test_case_summaries: list[TestCaseSummary]
@@ -197,6 +205,21 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
)
field_names += (s.capitalize() for s in param_names)

# Add tensor error statistic field names for each output index.
max_outputs = max(
len(s.tensor_error_statistics) for s in summary.test_case_summaries
)
for i in range(max_outputs):
field_names.extend(
[
f"Output {i} Error Max",
f"Output {i} Error MAE",
f"Output {i} Error MSD",
f"Output {i} Error L2",
f"Output {i} SQNR",
]
)

writer = csv.DictWriter(output, field_names)
writer.writeheader()

@@ -210,4 +233,12 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
}
if record.params is not None:
row.update({k.capitalize(): v for k, v in record.params.items()})

for output_idx, error_stats in enumerate(record.tensor_error_statistics):
row[f"Output {output_idx} Error Max"] = error_stats.error_max
row[f"Output {output_idx} Error MAE"] = error_stats.error_mae
row[f"Output {output_idx} Error MSD"] = error_stats.error_msd
row[f"Output {output_idx} Error L2"] = error_stats.error_l2_norm
row[f"Output {output_idx} SQNR"] = error_stats.sqnr

writer.writerow(row)
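
To make the resulting report layout concrete, each output index contributes five columns whose values map one-to-one onto ErrorStatistics fields. The helper below is illustrative only (error_stat_columns is not part of the PR); the column names are taken from the writer loop above.

```python
from executorch.backends.test.harness.error_statistics import ErrorStatistics


def error_stat_columns(output_idx: int, stats: ErrorStatistics) -> dict[str, float | None]:
    """Column-name-to-value mapping mirroring the CSV writer loop above."""
    return {
        f"Output {output_idx} Error Max": stats.error_max,
        f"Output {output_idx} Error MAE": stats.error_mae,
        f"Output {output_idx} Error MSD": stats.error_msd,
        f"Output {output_idx} Error L2": stats.error_l2_norm,
        f"Output {output_idx} SQNR": stats.sqnr,
    }
```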
7 changes: 6 additions & 1 deletion backends/test/suite/runner.py
@@ -7,6 +7,7 @@

import torch

from executorch.backends.test.harness.error_statistics import ErrorStatistics
from executorch.backends.test.harness.stages import StageType
from executorch.backends.test.suite.discovery import discover_tests, TestFilter
from executorch.backends.test.suite.flow import TestFlow
@@ -42,6 +43,8 @@ def run_test( # noqa: C901
and reporting.
"""

error_statistics: list[ErrorStatistics] = []

# Helper method to construct the summary.
def build_result(
result: TestResult, error: Exception | None = None
@@ -54,6 +57,7 @@ def build_result(
params=params,
result=result,
error=error,
tensor_error_statistics=error_statistics,
)

# Ensure the model can run in eager mode.
@@ -108,7 +112,8 @@ def build_result(
# AssertionErrors to catch output mismatches, but this might catch more than that.
try:
tester.run_method_and_compare_outputs(
inputs=None if generate_random_test_inputs else inputs
inputs=None if generate_random_test_inputs else inputs,
statistics_callback=lambda stats: error_statistics.append(stats),
)
except AssertionError as e:
return build_result(TestResult.OUTPUT_MISMATCH_FAIL, e)