Skip to content

Commit 32f54b0

Browse files
committed
Update
[ghstack-poisoned]
2 parents de21ac2 + 710ea49 commit 32f54b0

File tree

10 files changed

+57
-35
lines changed

10 files changed

+57
-35
lines changed

backends/qualcomm/tests/tester.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager
1414
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
1515
from executorch.backends.qualcomm.utils.utils import (
16-
generate_qnn_executorch_compiler_spec,
1716
generate_htp_compiler_spec,
17+
generate_qnn_executorch_compiler_spec,
1818
get_soc_to_chipset_map,
1919
)
2020
from executorch.backends.test.harness import Tester as TesterBase
@@ -36,7 +36,7 @@ def __init__(
3636
self,
3737
partitioners: Optional[List[Partitioner]] = None,
3838
edge_compile_config: Optional[EdgeCompileConfig] = None,
39-
soc_model: str = "SM8650"
39+
soc_model: str = "SM8650",
4040
):
4141
backend_options = generate_htp_compiler_spec(use_fp16=True)
4242
self.chipset = get_soc_to_chipset_map()[soc_model]
@@ -47,7 +47,8 @@ def __init__(
4747

4848
super().__init__(
4949
partitioners=partitioners or [QnnPartitioner(self.compiler_specs)],
50-
edge_compile_config=edge_compile_config or EdgeCompileConfig(_check_ir_validity=False),
50+
edge_compile_config=edge_compile_config
51+
or EdgeCompileConfig(_check_ir_validity=False),
5152
default_partitioner_cls=QnnPartitioner,
5253
)
5354

@@ -69,7 +70,7 @@ def __init__(
6970
module: torch.nn.Module,
7071
example_inputs: Tuple[torch.Tensor],
7172
dynamic_shapes: Optional[Tuple[Any]] = None,
72-
):
73+
):
7374
# Specialize for Qualcomm
7475
stage_classes = (
7576
executorch.backends.test.harness.Tester.default_stage_classes()

backends/test/suite/context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# Test run context management. This is used to determine the test context for reporting
22
# purposes.
33
class TestContext:
4-
def __init__(self, test_name: str, test_base_name: str, flow_name: str, params: dict | None):
4+
def __init__(
5+
self, test_name: str, test_base_name: str, flow_name: str, params: dict | None
6+
):
57
self.test_name = test_name
68
self.test_base_name = test_base_name
79
self.flow_name = flow_name

backends/test/suite/flow.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def all_flows() -> dict[str, TestFlow]:
6464

6565
try:
6666
from executorch.backends.test.suite.flows.vulkan import VULKAN_TEST_FLOW
67+
6768
flows += [
6869
VULKAN_TEST_FLOW,
6970
]
@@ -72,6 +73,7 @@ def all_flows() -> dict[str, TestFlow]:
7273

7374
try:
7475
from executorch.backends.test.suite.flows.qualcomm import QUALCOMM_TEST_FLOW
76+
7577
flows += [
7678
QUALCOMM_TEST_FLOW,
7779
]
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from executorch.backends.qualcomm.tests.tester import QualcommTester
22
from executorch.backends.test.suite.flow import TestFlow
33

4+
45
def _create_qualcomm_flow(
5-
name: str,
6-
quantize: bool = False,
6+
name: str,
7+
quantize: bool = False,
78
) -> TestFlow:
89
return TestFlow(
910
name,
@@ -12,4 +13,5 @@ def _create_qualcomm_flow(
1213
quantize=quantize,
1314
)
1415

16+
1517
QUALCOMM_TEST_FLOW = _create_qualcomm_flow("qualcomm")
Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from executorch.backends.vulkan.test.tester import VulkanTester
21
from executorch.backends.test.suite.flow import TestFlow
2+
from executorch.backends.vulkan.test.tester import VulkanTester
3+
34

45
def _create_vulkan_flow(
5-
name: str,
6-
quantize: bool = False,
6+
name: str,
7+
quantize: bool = False,
78
) -> TestFlow:
89
return TestFlow(
910
name,
@@ -12,4 +13,5 @@ def _create_vulkan_flow(
1213
quantize=quantize,
1314
)
1415

16+
1517
VULKAN_TEST_FLOW = _create_vulkan_flow("vulkan")

backends/test/suite/operators/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ def _create_test_for_backend(
117117

118118
if test_type == TestType.STANDARD:
119119
test_name = f"{test_func.__name__}_{flow.name}"
120-
wrapped_test = _make_wrapped_test(test_func, test_name, test_func.__name__, flow)
120+
wrapped_test = _make_wrapped_test(
121+
test_func, test_name, test_func.__name__, flow
122+
)
121123
setattr(cls, test_name, wrapped_test)
122124
elif test_type == TestType.DTYPE:
123125
for dtype in DTYPES:

backends/test/suite/reporting.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1+
import csv
12
from collections import Counter
23
from dataclasses import dataclass
34
from enum import IntEnum
45
from functools import reduce
5-
from re import A
66
from typing import TextIO
77

8-
import csv
98

109
class TestResult(IntEnum):
1110
"""Represents the result of a test case run, indicating success or a specific failure reason."""
@@ -79,13 +78,13 @@ class TestCaseSummary:
7978
"""
8079
Contains summary results for the execution of a single test case.
8180
"""
82-
81+
8382
backend: str
8483
""" The name of the target backend. """
8584

8685
base_name: str
8786
""" The base name of the test, not including flow or parameter suffixes. """
88-
87+
8988
flow: str
9089
""" The backend-specific flow name. Corresponds to flows registered in backends/test/suite/__init__.py. """
9190

@@ -173,8 +172,9 @@ def complete_test_session() -> RunSummary:
173172

174173
return summary
175174

175+
176176
def generate_csv_report(summary: RunSummary, output: TextIO):
177-
""" Write a run summary report to a file in CSV format. """
177+
"""Write a run summary report to a file in CSV format."""
178178

179179
field_names = [
180180
"Test ID",
@@ -183,19 +183,23 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
183183
"Flow",
184184
"Result",
185185
]
186-
186+
187187
# Tests can have custom parameters. We'll want to report them here, so we need
188188
# a list of all unique parameter names.
189189
param_names = reduce(
190190
lambda a, b: a.union(b),
191-
(set(s.params.keys()) for s in summary.test_case_summaries if s.params is not None),
192-
set()
191+
(
192+
set(s.params.keys())
193+
for s in summary.test_case_summaries
194+
if s.params is not None
195+
),
196+
set(),
193197
)
194198
field_names += (s.capitalize() for s in param_names)
195199

196200
writer = csv.DictWriter(output, field_names)
197201
writer.writeheader()
198-
202+
199203
for record in summary.test_case_summaries:
200204
row = {
201205
"Test ID": record.name,
@@ -205,7 +209,5 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
205209
"Result": record.result.display_name(),
206210
}
207211
if record.params is not None:
208-
row.update({
209-
k.capitalize(): v for k, v in record.params.items()
210-
})
212+
row.update({k.capitalize(): v for k, v in record.params.items()})
211213
writer.writerow(row)

backends/test/suite/runner.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,10 @@ def parse_args():
173173
"-f", "--filter", nargs="?", help="A regular expression filter for test names."
174174
)
175175
parser.add_argument(
176-
"-r", "--report", nargs="?", help="A file to write the test report to, in CSV format."
176+
"-r",
177+
"--report",
178+
nargs="?",
179+
help="A file to write the test report to, in CSV format.",
177180
)
178181
return parser.parse_args()
179182

@@ -202,7 +205,7 @@ def runner_main():
202205

203206
summary = complete_test_session()
204207
print_summary(summary)
205-
208+
206209
if args.report is not None:
207210
with open(args.report, "w") as f:
208211
print(f"Writing CSV report to {args.report}.")

backends/test/suite/tests/test_reporting.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
1-
import torch
21
import unittest
32

43
from csv import DictReader
5-
from ..reporting import TestResult, TestCaseSummary, RunSummary, TestSessionState, generate_csv_report
64
from io import StringIO
75

6+
import torch
7+
8+
from ..reporting import (
9+
generate_csv_report,
10+
RunSummary,
11+
TestCaseSummary,
12+
TestResult,
13+
TestSessionState,
14+
)
15+
816
# Test data for simulated test results.
917
TEST_CASE_SUMMARIES = [
1018
TestCaseSummary(
@@ -45,16 +53,17 @@
4553
),
4654
]
4755

56+
4857
class Reporting(unittest.TestCase):
4958
def test_csv_report_simple(self):
5059
# Verify the format of a simple CSV run report.
5160
session_state = TestSessionState()
5261
session_state.test_case_summaries.extend(TEST_CASE_SUMMARIES)
5362
run_summary = RunSummary.from_session(session_state)
54-
63+
5564
strio = StringIO()
5665
generate_csv_report(run_summary, strio)
57-
66+
5867
# Attempt to deserialize and validate the CSV report.
5968
report = DictReader(StringIO(strio.getvalue()))
6069
records = list(report)
@@ -95,7 +104,3 @@ def test_csv_report_simple(self):
95104
self.assertEqual(records[3]["Result"], "Fail (Export)")
96105
self.assertEqual(records[3]["Dtype"], "")
97106
self.assertEqual(records[3]["Use_dynamic_shapes"], "True")
98-
99-
100-
101-

backends/vulkan/test/tester.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
import executorch.backends.test.harness.stages as BaseStages
1111

1212
import torch
13-
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
1413
from executorch.backends.test.harness import Tester as TesterBase
1514
from executorch.backends.test.harness.stages import StageType
15+
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
1616
from executorch.exir import EdgeCompileConfig
1717
from executorch.exir.backend.partitioner import Partitioner
1818

@@ -33,7 +33,8 @@ def __init__(
3333
super().__init__(
3434
default_partitioner_cls=VulkanPartitioner,
3535
partitioners=partitioners,
36-
edge_compile_config=edge_compile_config or EdgeCompileConfig(_check_ir_validity=False),
36+
edge_compile_config=edge_compile_config
37+
or EdgeCompileConfig(_check_ir_validity=False),
3738
)
3839

3940

0 commit comments

Comments
 (0)