Skip to content

Commit 5cdfc1f

Browse files
committed
[Backend Tester] Add Vulkan tester and register test flow
ghstack-source-id: ae74ff8 ghstack-comment-id: 3105185112 Pull-Request: #12738
1 parent 654b0af commit 5cdfc1f

File tree

4 files changed

+93
-0
lines changed

4 files changed

+93
-0
lines changed

backends/test/suite/flow.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,12 @@ def all_flows() -> dict[str, TestFlow]:
6262
except Exception as e:
6363
logger.info(f"Skipping Core ML flow registration: {e}")
6464

65+
try:
66+
from executorch.backends.test.suite.flows.vulkan import VULKAN_TEST_FLOW
67+
flows += [
68+
VULKAN_TEST_FLOW,
69+
]
70+
except Exception as e:
71+
logger.info(f"Skipping Vulkan flow registration: {e}")
72+
6573
return {f.name: f for f in flows if f is not None}

backends/test/suite/flows/vulkan.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from executorch.backends.vulkan.test.tester import VulkanTester
2+
from executorch.backends.test.suite.flow import TestFlow
3+
4+
def _create_vulkan_flow(
5+
name: str,
6+
quantize: bool = False,
7+
) -> TestFlow:
8+
return TestFlow(
9+
name,
10+
backend="vulkan",
11+
tester_factory=VulkanTester,
12+
quantize=quantize,
13+
)
14+
15+
VULKAN_TEST_FLOW = _create_vulkan_flow("vulkan")

backends/vulkan/test/TARGETS

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
2+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
23

34
oncall("executorch")
45

@@ -57,3 +58,12 @@ python_unittest(
5758
"//executorch/backends/vulkan:vulkan_preprocess",
5859
],
5960
)
61+
62+
runtime.python_library(
63+
name = "tester",
64+
srcs = ["tester.py"],
65+
deps = [
66+
"//executorch/backends/vulkan/partitioner:vulkan_partitioner",
67+
"//executorch/backends/vulkan:vulkan_preprocess",
68+
]
69+
)

backends/vulkan/test/tester.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Any, List, Optional, Tuple
8+
9+
import executorch
10+
import executorch.backends.test.harness.stages as BaseStages
11+
12+
import torch
13+
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
14+
from executorch.backends.test.harness import Tester as TesterBase
15+
from executorch.backends.test.harness.stages import StageType
16+
from executorch.exir import EdgeCompileConfig
17+
from executorch.exir.backend.partitioner import Partitioner
18+
19+
20+
class Partition(BaseStages.Partition):
21+
def __init__(self, partitioner: Optional[Partitioner] = None):
22+
super().__init__(
23+
partitioner=partitioner or VulkanPartitioner(),
24+
)
25+
26+
27+
class ToEdgeTransformAndLower(BaseStages.ToEdgeTransformAndLower):
28+
def __init__(
29+
self,
30+
partitioners: Optional[List[Partitioner]] = None,
31+
edge_compile_config: Optional[EdgeCompileConfig] = None,
32+
):
33+
super().__init__(
34+
default_partitioner_cls=VulkanPartitioner,
35+
partitioners=partitioners,
36+
edge_compile_config=edge_compile_config or EdgeCompileConfig(_check_ir_validity=False),
37+
)
38+
39+
40+
class VulkanTester(TesterBase):
41+
def __init__(
42+
self,
43+
module: torch.nn.Module,
44+
example_inputs: Tuple[torch.Tensor],
45+
dynamic_shapes: Optional[Tuple[Any]] = None,
46+
):
47+
stage_classes = (
48+
executorch.backends.test.harness.Tester.default_stage_classes()
49+
| {
50+
StageType.PARTITION: Partition,
51+
StageType.TO_EDGE_TRANSFORM_AND_LOWER: ToEdgeTransformAndLower,
52+
}
53+
)
54+
55+
super().__init__(
56+
module=module,
57+
stage_classes=stage_classes,
58+
example_inputs=example_inputs,
59+
dynamic_shapes=dynamic_shapes,
60+
)

0 commit comments

Comments
 (0)