Skip to content

Commit a27d18c

Browse files
committed
Update
[ghstack-poisoned]
2 parents fd26fc7 + 32f54b0 commit a27d18c

File tree

12 files changed

+106
-69
lines changed

12 files changed

+106
-69
lines changed

backends/qualcomm/tests/tester.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager
1414
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
1515
from executorch.backends.qualcomm.utils.utils import (
16-
generate_qnn_executorch_compiler_spec,
1716
generate_htp_compiler_spec,
17+
generate_qnn_executorch_compiler_spec,
1818
get_soc_to_chipset_map,
1919
)
2020
from executorch.backends.test.harness import Tester as TesterBase
@@ -36,7 +36,7 @@ def __init__(
3636
self,
3737
partitioners: Optional[List[Partitioner]] = None,
3838
edge_compile_config: Optional[EdgeCompileConfig] = None,
39-
soc_model: str = "SM8650"
39+
soc_model: str = "SM8650",
4040
):
4141
backend_options = generate_htp_compiler_spec(use_fp16=True)
4242
self.chipset = get_soc_to_chipset_map()[soc_model]
@@ -47,7 +47,8 @@ def __init__(
4747

4848
super().__init__(
4949
partitioners=partitioners or [QnnPartitioner(self.compiler_specs)],
50-
edge_compile_config=edge_compile_config or EdgeCompileConfig(_check_ir_validity=False),
50+
edge_compile_config=edge_compile_config
51+
or EdgeCompileConfig(_check_ir_validity=False),
5152
default_partitioner_cls=QnnPartitioner,
5253
)
5354

@@ -69,7 +70,7 @@ def __init__(
6970
module: torch.nn.Module,
7071
example_inputs: Tuple[torch.Tensor],
7172
dynamic_shapes: Optional[Tuple[Any]] = None,
72-
):
73+
):
7374
# Specialize for Qualcomm
7475
stage_classes = (
7576
executorch.backends.test.harness.Tester.default_stage_classes()

backends/test/harness/error_statistics.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
from dataclasses import dataclass
2-
from torch.ao.ns.fx.utils import compute_sqnr
32

43
import torch
4+
from torch.ao.ns.fx.utils import compute_sqnr
5+
56

67
@dataclass
78
class TensorStatistics:
8-
""" Contains summary statistics for a tensor. """
9+
"""Contains summary statistics for a tensor."""
910

1011
shape: torch.Size
1112
""" The shape of the tensor. """
12-
13+
1314
numel: int
1415
""" The number of elements in the tensor. """
1516

@@ -24,10 +25,10 @@ class TensorStatistics:
2425

2526
min: torch.types.Number
2627
""" The minimum element of the tensor. """
27-
28+
2829
@classmethod
2930
def from_tensor(cls, tensor: torch.Tensor) -> "TensorStatistics":
30-
""" Creates a TensorStatistics object from a tensor. """
31+
"""Creates a TensorStatistics object from a tensor."""
3132
flattened = torch.flatten(tensor)
3233
return cls(
3334
shape=tensor.shape,
@@ -38,41 +39,44 @@ def from_tensor(cls, tensor: torch.Tensor) -> "TensorStatistics":
3839
min=flattened.min().item(),
3940
)
4041

42+
4143
@dataclass
4244
class ErrorStatistics:
43-
""" Contains statistics derived from the difference of two tensors. """
45+
"""Contains statistics derived from the difference of two tensors."""
4446

45-
reference_stats: TensorStatistics
47+
reference_stats: TensorStatistics
4648
""" Statistics for the reference tensor. """
4749

4850
actual_stats: TensorStatistics
4951
""" Statistics for the actual tensor. """
50-
52+
5153
error_l2_norm: float | None
5254
""" The L2 norm of the error between the actual and reference tensor. """
53-
55+
5456
error_mae: float | None
5557
""" The mean absolute error between the actual and reference tensor. """
56-
58+
5759
error_max: float | None
5860
""" The maximum absolute elementwise error between the actual and reference tensor. """
59-
61+
6062
error_msd: float | None
6163
""" The mean signed deviation between the actual and reference tensor. """
62-
64+
6365
sqnr: float | None
6466
""" The signal-to-quantization-noise ratio between the actual and reference tensor. """
6567

6668
@classmethod
67-
def from_tensors(cls, actual: torch.Tensor, reference: torch.Tensor) -> "ErrorStatistics":
68-
""" Creates an ErrorStatistics object from two tensors. """
69+
def from_tensors(
70+
cls, actual: torch.Tensor, reference: torch.Tensor
71+
) -> "ErrorStatistics":
72+
"""Creates an ErrorStatistics object from two tensors."""
6973
if actual.shape != reference.shape:
7074
return cls(
7175
reference_stats=TensorStatistics.from_tensor(reference),
7276
actual_stats=TensorStatistics.from_tensor(actual),
7377
error_l2_norm=None,
7478
error_mae=None,
75-
error_max = None,
79+
error_max=None,
7680
error_msd=None,
7781
sqnr=None,
7882
)
@@ -87,5 +91,5 @@ def from_tensors(cls, actual: torch.Tensor, reference: torch.Tensor) -> "ErrorSt
8791
error_mae=torch.mean(torch.abs(flat_error)).item(),
8892
error_max=torch.max(torch.abs(flat_error)).item(),
8993
error_msd=torch.mean(flat_error).item(),
90-
sqnr=compute_sqnr(actual, reference).item()
94+
sqnr=compute_sqnr(actual, reference).item(),
9195
)

backends/test/harness/tester.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -324,16 +324,22 @@ def run_method_and_compare_outputs(
324324
# Output from running artifact at stage
325325
stage_output = self.stages[stage].run_artifact(inputs_to_run)
326326
self._compare_outputs(
327-
reference_output, stage_output, quantization_scale, atol, rtol, qtol, statistics_callback
327+
reference_output,
328+
stage_output,
329+
quantization_scale,
330+
atol,
331+
rtol,
332+
qtol,
333+
statistics_callback,
328334
)
329335

330336
return self
331337

332338
@staticmethod
333339
def _assert_outputs_equal(
334-
model_output,
335-
ref_output,
336-
atol=1e-03,
340+
model_output,
341+
ref_output,
342+
atol=1e-03,
337343
rtol=1e-03,
338344
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
339345
):
@@ -351,7 +357,7 @@ def _assert_outputs_equal(
351357
for i in range(len(model_output)):
352358
model = model_output[i]
353359
ref = ref_output[i]
354-
360+
355361
error_stats = ErrorStatistics.from_tensors(model, ref)
356362
if statistics_callback is not None:
357363
statistics_callback(error_stats)
@@ -410,7 +416,7 @@ def _compare_outputs(
410416
# atol by qtol quant units.
411417
if quantization_scale is not None:
412418
atol += quantization_scale * qtol
413-
419+
414420
Tester._assert_outputs_equal(
415421
stage_output,
416422
reference_output,

backends/test/suite/context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# Test run context management. This is used to determine the test context for reporting
22
# purposes.
33
class TestContext:
4-
def __init__(self, test_name: str, test_base_name: str, flow_name: str, params: dict | None):
4+
def __init__(
5+
self, test_name: str, test_base_name: str, flow_name: str, params: dict | None
6+
):
57
self.test_name = test_name
68
self.test_base_name = test_base_name
79
self.flow_name = flow_name

backends/test/suite/flow.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def all_flows() -> dict[str, TestFlow]:
6464

6565
try:
6666
from executorch.backends.test.suite.flows.vulkan import VULKAN_TEST_FLOW
67+
6768
flows += [
6869
VULKAN_TEST_FLOW,
6970
]
@@ -72,6 +73,7 @@ def all_flows() -> dict[str, TestFlow]:
7273

7374
try:
7475
from executorch.backends.test.suite.flows.qualcomm import QUALCOMM_TEST_FLOW
76+
7577
flows += [
7678
QUALCOMM_TEST_FLOW,
7779
]
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from executorch.backends.qualcomm.tests.tester import QualcommTester
22
from executorch.backends.test.suite.flow import TestFlow
33

4+
45
def _create_qualcomm_flow(
5-
name: str,
6-
quantize: bool = False,
6+
name: str,
7+
quantize: bool = False,
78
) -> TestFlow:
89
return TestFlow(
910
name,
@@ -12,4 +13,5 @@ def _create_qualcomm_flow(
1213
quantize=quantize,
1314
)
1415

16+
1517
QUALCOMM_TEST_FLOW = _create_qualcomm_flow("qualcomm")
Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from executorch.backends.vulkan.test.tester import VulkanTester
21
from executorch.backends.test.suite.flow import TestFlow
2+
from executorch.backends.vulkan.test.tester import VulkanTester
3+
34

45
def _create_vulkan_flow(
5-
name: str,
6-
quantize: bool = False,
6+
name: str,
7+
quantize: bool = False,
78
) -> TestFlow:
89
return TestFlow(
910
name,
@@ -12,4 +13,5 @@ def _create_vulkan_flow(
1213
quantize=quantize,
1314
)
1415

16+
1517
VULKAN_TEST_FLOW = _create_vulkan_flow("vulkan")

backends/test/suite/operators/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ def _create_test_for_backend(
117117

118118
if test_type == TestType.STANDARD:
119119
test_name = f"{test_func.__name__}_{flow.name}"
120-
wrapped_test = _make_wrapped_test(test_func, test_name, test_func.__name__, flow)
120+
wrapped_test = _make_wrapped_test(
121+
test_func, test_name, test_func.__name__, flow
122+
)
121123
setattr(cls, test_name, wrapped_test)
122124
elif test_type == TestType.DTYPE:
123125
for dtype in DTYPES:

backends/test/suite/reporting.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1+
import csv
12
from collections import Counter
23
from dataclasses import dataclass
34
from enum import IntEnum
45
from functools import reduce
56
from typing import TextIO
67

7-
import csv
8-
98
from executorch.backends.test.harness.error_statistics import ErrorStatistics
109

10+
1111
class TestResult(IntEnum):
1212
"""Represents the result of a test case run, indicating success or a specific failure reason."""
1313

@@ -80,13 +80,13 @@ class TestCaseSummary:
8080
"""
8181
Contains summary results for the execution of a single test case.
8282
"""
83-
83+
8484
backend: str
8585
""" The name of the target backend. """
8686

8787
base_name: str
8888
""" The base name of the test, not including flow or parameter suffixes. """
89-
89+
9090
flow: str
9191
""" The backend-specific flow name. Corresponds to flows registered in backends/test/suite/__init__.py. """
9292

@@ -101,7 +101,7 @@ class TestCaseSummary:
101101

102102
error: Exception | None
103103
""" The Python exception object, if any. """
104-
104+
105105
tensor_error_statistics: list[ErrorStatistics]
106106
"""
107107
Statistics about the error between the backend and reference outputs. Each element of this list corresponds to
@@ -180,8 +180,9 @@ def complete_test_session() -> RunSummary:
180180

181181
return summary
182182

183+
183184
def generate_csv_report(summary: RunSummary, output: TextIO):
184-
""" Write a run summary report to a file in CSV format. """
185+
"""Write a run summary report to a file in CSV format."""
185186

186187
field_names = [
187188
"Test ID",
@@ -190,30 +191,38 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
190191
"Flow",
191192
"Result",
192193
]
193-
194+
194195
# Tests can have custom parameters. We'll want to report them here, so we need
195196
# a list of all unique parameter names.
196197
param_names = reduce(
197198
lambda a, b: a.union(b),
198-
(set(s.params.keys()) for s in summary.test_case_summaries if s.params is not None),
199-
set()
199+
(
200+
set(s.params.keys())
201+
for s in summary.test_case_summaries
202+
if s.params is not None
203+
),
204+
set(),
200205
)
201206
field_names += (s.capitalize() for s in param_names)
202207

203208
# Add tensor error statistic field names for each output index.
204-
max_outputs = max(len(s.tensor_error_statistics) for s in summary.test_case_summaries)
209+
max_outputs = max(
210+
len(s.tensor_error_statistics) for s in summary.test_case_summaries
211+
)
205212
for i in range(max_outputs):
206-
field_names.extend([
207-
f"Output {i} Error Max",
208-
f"Output {i} Error MAE",
209-
f"Output {i} Error MSD",
210-
f"Output {i} Error L2",
211-
f"Output {i} SQNR",
212-
])
213+
field_names.extend(
214+
[
215+
f"Output {i} Error Max",
216+
f"Output {i} Error MAE",
217+
f"Output {i} Error MSD",
218+
f"Output {i} Error L2",
219+
f"Output {i} SQNR",
220+
]
221+
)
213222

214223
writer = csv.DictWriter(output, field_names)
215224
writer.writeheader()
216-
225+
217226
for record in summary.test_case_summaries:
218227
row = {
219228
"Test ID": record.name,
@@ -223,10 +232,8 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
223232
"Result": record.result.display_name(),
224233
}
225234
if record.params is not None:
226-
row.update({
227-
k.capitalize(): v for k, v in record.params.items()
228-
})
229-
235+
row.update({k.capitalize(): v for k, v in record.params.items()})
236+
230237
for output_idx, error_stats in enumerate(record.tensor_error_statistics):
231238
row[f"Output {output_idx} Error Max"] = error_stats.error_max
232239
row[f"Output {output_idx} Error MAE"] = error_stats.error_mae

0 commit comments

Comments
 (0)