[Backend Tester] Add tensor error statistic reporting #12809
New file: Qualcomm backend tester (`@@ -0,0 +1,88 @@`):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List, Optional, Tuple

import executorch
import executorch.backends.test.harness.stages as BaseStages

import torch
from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
from executorch.backends.qualcomm.utils.utils import (
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    get_soc_to_chipset_map,
)
from executorch.backends.test.harness import Tester as TesterBase
from executorch.backends.test.harness.stages import StageType
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
from executorch.exir.backend.partitioner import Partitioner
from torch.export import ExportedProgram


class Partition(BaseStages.Partition):
    def __init__(self, partitioner: Optional[Partitioner] = None):
        super().__init__(
            partitioner=partitioner or QnnPartitioner,
        )


class ToEdgeTransformAndLower(BaseStages.ToEdgeTransformAndLower):
    def __init__(
        self,
        partitioners: Optional[List[Partitioner]] = None,
        edge_compile_config: Optional[EdgeCompileConfig] = None,
        soc_model: str = "SM8650",
    ):
        backend_options = generate_htp_compiler_spec(use_fp16=True)
        self.chipset = get_soc_to_chipset_map()[soc_model]
        self.compiler_specs = generate_qnn_executorch_compiler_spec(
            soc_model=self.chipset,
            backend_options=backend_options,
        )

        super().__init__(
            partitioners=partitioners or [QnnPartitioner(self.compiler_specs)],
            edge_compile_config=edge_compile_config
            or EdgeCompileConfig(_check_ir_validity=False),
            default_partitioner_cls=QnnPartitioner,
        )

    def run(self, artifact: ExportedProgram, inputs=None) -> None:
        ep = QnnPassManager().transform_for_export_pipeline(artifact)
        transform_passes = QnnPassManager().get_to_edge_transform_passes(ep)

        self.edge_dialect_program = to_edge_transform_and_lower(
            ep,
            transform_passes=transform_passes,
            partitioner=self.partitioners,
            compile_config=self.edge_compile_conf,
        )


class QualcommTester(TesterBase):
    def __init__(
        self,
        module: torch.nn.Module,
        example_inputs: Tuple[torch.Tensor],
        dynamic_shapes: Optional[Tuple[Any]] = None,
    ):
        # Specialize for Qualcomm
        stage_classes = (
            executorch.backends.test.harness.Tester.default_stage_classes()
            | {
                StageType.PARTITION: Partition,
                StageType.TO_EDGE_TRANSFORM_AND_LOWER: ToEdgeTransformAndLower,
            }
        )

        super().__init__(
            module=module,
            stage_classes=stage_classes,
            example_inputs=example_inputs,
            dynamic_shapes=dynamic_shapes,
        )
```
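For context, a minimal sketch of how this tester might be driven. This is not part of the diff: the example module is made up, and the stage-method chain is an assumption based on the base harness `Tester`'s builder-style API (stage methods named after the stage types above):

```python
import torch


class AddOne(torch.nn.Module):
    """Trivial example module, for illustration only."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


# QualcommTester wires the QNN-specific Partition and
# ToEdgeTransformAndLower stages into the generic harness;
# the chained stage calls below are assumed from the base Tester.
tester = QualcommTester(AddOne(), (torch.randn(1, 4),))
tester.export().to_edge_transform_and_lower()
```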
New file: `executorch/backends/test/harness/error_statistics.py` (`@@ -0,0 +1,99 @@`):

```python
from dataclasses import dataclass

import torch
from torch.ao.ns.fx.utils import compute_sqnr


@dataclass
class TensorStatistics:
    """Contains summary statistics for a tensor."""

    shape: torch.Size
    """ The shape of the tensor. """

    numel: int
    """ The number of elements in the tensor. """

    median: float
    """ The median of the tensor. """

    mean: float
    """ The mean of the tensor. """

    max: torch.types.Number
    """ The maximum element of the tensor. """

    min: torch.types.Number
    """ The minimum element of the tensor. """

    @classmethod
    def from_tensor(cls, tensor: torch.Tensor) -> "TensorStatistics":
        """Creates a TensorStatistics object from a tensor."""
        flattened = torch.flatten(tensor)
        return cls(
            shape=tensor.shape,
            numel=tensor.numel(),
            median=torch.quantile(flattened, q=0.5).item(),
            mean=flattened.mean().item(),
            max=flattened.max().item(),
            min=flattened.min().item(),
        )


@dataclass
class ErrorStatistics:
    """Contains statistics derived from the difference of two tensors."""

    reference_stats: TensorStatistics
    """ Statistics for the reference tensor. """

    actual_stats: TensorStatistics
    """ Statistics for the actual tensor. """

    error_l2_norm: float | None
    """ The L2 norm of the error between the actual and reference tensor. """

    error_mae: float | None
    """ The mean absolute error between the actual and reference tensor. """

    error_max: float | None
    """ The maximum absolute elementwise error between the actual and reference tensor. """

    error_msd: float | None
    """ The mean signed deviation between the actual and reference tensor. """

    sqnr: float | None
    """ The signal-to-quantization-noise ratio between the actual and reference tensor. """

    @classmethod
    def from_tensors(
        cls, actual: torch.Tensor, reference: torch.Tensor
    ) -> "ErrorStatistics":
        """Creates an ErrorStatistics object from two tensors."""
        actual = actual.to(torch.float64)
        reference = reference.to(torch.float64)

        if actual.shape != reference.shape:
            return cls(
                reference_stats=TensorStatistics.from_tensor(reference),
                actual_stats=TensorStatistics.from_tensor(actual),
                error_l2_norm=None,
                error_mae=None,
                error_max=None,
                error_msd=None,
                sqnr=None,
            )

        error = actual - reference
        flat_error = torch.flatten(error)

        return cls(
            reference_stats=TensorStatistics.from_tensor(reference),
            actual_stats=TensorStatistics.from_tensor(actual),
            error_l2_norm=torch.linalg.norm(flat_error).item(),
            error_mae=torch.mean(torch.abs(flat_error)).item(),
            error_max=torch.max(torch.abs(flat_error)).item(),
            error_msd=torch.mean(flat_error).item(),
            # Torch sqnr implementation requires float32 due to decorator logic
            sqnr=compute_sqnr(actual.to(torch.float), reference.to(torch.float)).item(),
        )
```
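As a quick illustration of what `from_tensors` reports, here is a standalone sketch; the printed values are worked out by hand and match the unit tests further down:

```python
import torch

from executorch.backends.test.harness.error_statistics import ErrorStatistics

actual = torch.tensor([1.0, 2.0, 3.0, 4.0])
reference = torch.tensor([2.0, 2.0, 2.0, 5.0])

stats = ErrorStatistics.from_tensors(actual, reference)

# error = actual - reference = [-1, 0, 1, -1]
print(stats.error_l2_norm)  # sqrt(3) ~= 1.732
print(stats.error_mae)      # (1 + 0 + 1 + 1) / 4 = 0.75
print(stats.error_max)      # 1.0
print(stats.error_msd)      # (-1 + 0 + 1 - 1) / 4 = -0.25
print(stats.sqnr)           # 10.0 (see the worked check after the tests below)
```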
Changes to the backend test harness `Tester`:

```diff
@@ -4,6 +4,7 @@
 
 import torch
 
+from executorch.backends.test.harness.error_statistics import ErrorStatistics
 from executorch.backends.test.harness.stages import (
     Export,
     Partition,
```

```diff
@@ -302,20 +303,15 @@ def run_method_and_compare_outputs(
         atol=1e-03,
         rtol=1e-03,
         qtol=0,
+        statistics_callback: Callable[[ErrorStatistics], None] | None = None,
     ):
         number_of_runs = 1 if inputs is not None else num_runs
         reference_stage = self.stages[StageType.EXPORT]
 
         stage = stage or self.cur
 
-        print(f"Comparing Stage {stage} with Stage {reference_stage}")
-        for run_iteration in range(number_of_runs):
+        for _ in range(number_of_runs):
             inputs_to_run = inputs if inputs else next(self.generate_random_inputs())
-            input_shapes = [
-                generated_input.shape if hasattr(generated_input, "shape") else None
-                for generated_input in inputs_to_run
-            ]
-            print(f"Run {run_iteration} with input shapes: {input_shapes}")
 
             # Reference output (and quantization scale)
             (
```

Review thread on the `statistics_callback` parameter:

- I'm not completely happy with the callback approach for exposing this, but I don't really have a better idea, since the tester relies on a builder-style pattern where it returns self to allow chaining. I'm open to suggestions.
- Just update the tester method?
- Can you clarify what you're thinking? Are you meaning update the tester `run_method_and_compare_outputs` to directly return the error stats and then update all of the callers to not use it in a chained fashion? Or something else?
- Modify the existing method, keeping the outside behavior but also add
- Filed as #13337. Will stack this on top.

```diff
@@ -328,13 +324,25 @@ def run_method_and_compare_outputs(
             # Output from running artifact at stage
             stage_output = self.stages[stage].run_artifact(inputs_to_run)
             self._compare_outputs(
-                reference_output, stage_output, quantization_scale, atol, rtol, qtol
+                reference_output,
+                stage_output,
+                quantization_scale,
+                atol,
+                rtol,
+                qtol,
+                statistics_callback,
             )
 
         return self
 
     @staticmethod
-    def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
+    def _assert_outputs_equal(
+        model_output,
+        ref_output,
+        atol=1e-03,
+        rtol=1e-03,
+        statistics_callback: Callable[[ErrorStatistics], None] | None = None,
+    ):
         """
         Helper testing function that asserts that the model output and the reference output
         are equal with some tolerance. Due to numerical differences between eager mode and
```

```diff
@@ -349,6 +357,11 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
         for i in range(len(model_output)):
             model = model_output[i]
             ref = ref_output[i]
+
+            error_stats = ErrorStatistics.from_tensors(model, ref)
+            if statistics_callback is not None:
+                statistics_callback(error_stats)
+
             assert (
                 ref.shape == model.shape
             ), f"Output {i} shape {model.shape} does not match reference output shape {ref.shape}"
```

```diff
@@ -386,6 +399,7 @@ def _compare_outputs(
         atol=1e-03,
         rtol=1e-03,
         qtol=0,
+        statistics_callback: Callable[[ErrorStatistics], None] | None = None,
     ):
         """
         Compares the original of the original nn module with the output of the generated artifact.
```

```diff
@@ -408,6 +422,7 @@ def _compare_outputs(
                 reference_output,
                 atol=atol,
                 rtol=rtol,
+                statistics_callback=statistics_callback,
             )
 
     @staticmethod
```
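To illustrate the callback approach discussed in the thread above, here is a minimal usage sketch, reusing the `tester` from the earlier example. The stage-method chain is assumed from the harness's builder-style API; any callable accepting an `ErrorStatistics` works as the callback:

```python
from executorch.backends.test.harness.error_statistics import ErrorStatistics

collected: list[ErrorStatistics] = []

# list.append is a valid Callable[[ErrorStatistics], None], so the chained
# style is preserved while the statistics escape via the side channel.
(
    tester.export()
    .to_edge_transform_and_lower()
    .to_executorch()
    .run_method_and_compare_outputs(statistics_callback=collected.append)
)

for stats in collected:
    print(f"L2={stats.error_l2_norm}, MAE={stats.error_mae}, SQNR={stats.sqnr}")
```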
New file: unit tests for `ErrorStatistics` (`@@ -0,0 +1,65 @@`):

```python
import unittest

import torch
from executorch.backends.test.harness.error_statistics import ErrorStatistics


class ErrorStatisticsTests(unittest.TestCase):
    def test_error_stats_simple(self):
        tensor1 = torch.tensor([1, 2, 3, 4])
        tensor2 = torch.tensor([2, 2, 2, 5])

        error_stats = ErrorStatistics.from_tensors(tensor1, tensor2)

        # Check actual tensor statistics
        self.assertEqual(error_stats.actual_stats.shape, torch.Size([4]))
        self.assertEqual(error_stats.actual_stats.numel, 4)
        self.assertEqual(error_stats.actual_stats.median, 2.5)
        self.assertEqual(error_stats.actual_stats.mean, 2.5)
        self.assertEqual(error_stats.actual_stats.max, 4)
        self.assertEqual(error_stats.actual_stats.min, 1)

        # Check reference tensor statistics
        self.assertEqual(error_stats.reference_stats.shape, torch.Size([4]))
        self.assertEqual(error_stats.reference_stats.numel, 4)
        self.assertEqual(error_stats.reference_stats.median, 2.0)
        self.assertEqual(error_stats.reference_stats.mean, 2.75)
        self.assertEqual(error_stats.reference_stats.max, 5)
        self.assertEqual(error_stats.reference_stats.min, 2)

        # Check error statistics
        self.assertAlmostEqual(error_stats.error_l2_norm, 1.732, places=3)
        self.assertEqual(error_stats.error_mae, 0.75)
        self.assertEqual(error_stats.error_max, 1.0)
        self.assertEqual(error_stats.error_msd, -0.25)
        self.assertAlmostEqual(error_stats.sqnr, 10.0, places=3)

    def test_error_stats_different_shapes(self):
        # Create tensors with different shapes
        tensor1 = torch.tensor([1, 2, 3, 4])
        tensor2 = torch.tensor([[2, 3], [4, 5]])

        error_stats = ErrorStatistics.from_tensors(tensor1, tensor2)

        # Check actual tensor statistics
        self.assertEqual(error_stats.actual_stats.shape, torch.Size([4]))
        self.assertEqual(error_stats.actual_stats.numel, 4)
        self.assertEqual(error_stats.actual_stats.median, 2.5)
        self.assertEqual(error_stats.actual_stats.mean, 2.5)
        self.assertEqual(error_stats.actual_stats.max, 4)
        self.assertEqual(error_stats.actual_stats.min, 1)

        # Check reference tensor statistics
        self.assertEqual(error_stats.reference_stats.shape, torch.Size([2, 2]))
        self.assertEqual(error_stats.reference_stats.numel, 4)
        self.assertEqual(error_stats.reference_stats.median, 3.5)
        self.assertEqual(error_stats.reference_stats.mean, 3.5)
        self.assertEqual(error_stats.reference_stats.max, 5)
        self.assertEqual(error_stats.reference_stats.min, 2)

        # Check that all error values are None when shapes differ
        self.assertIsNone(error_stats.error_l2_norm)
        self.assertIsNone(error_stats.error_mae)
        self.assertIsNone(error_stats.error_max)
        self.assertIsNone(error_stats.error_msd)
        self.assertIsNone(error_stats.sqnr)
```
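As a worked check of the `10.0` SQNR expectation in `test_error_stats_simple` (assuming torch's `compute_sqnr(x, y)` returns `20 * log10(norm(x) / norm(x - y))`, which these numbers are consistent with): with `actual = [1, 2, 3, 4]` and `reference = [2, 2, 2, 5]`, the error is `[-1, 0, 1, -1]`, so `norm(actual) = sqrt(30)` and `norm(error) = sqrt(3)`. The ratio is `sqrt(10)`, and `20 * log10(sqrt(10)) = 10.0` exactly. The same error vector yields the asserted L2 norm `sqrt(3) ≈ 1.732`, MAE `0.75`, max `1.0`, and MSD `-0.25`.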
Review thread on the shape-mismatch case (error fields set to `None` rather than raising):

- Why is this not an error?
- There are some cases where in-place ops get functionalized and alter the graph outputs. This is an issue with the defunctionalization logic in ET (there's a separate issue I filed). It's not technically the backend's fault, but it is a real issue. I should probably just disable the affected tests for now and treat this as an error. I'll do that.
- Filed #13336 as a follow-up. Will stack this change on top.