Skip to content

Commit 931337a

Browse files
committed
[Backend Tester] Add Vulkan tester and register test flow
ghstack-source-id: 6858b2d ghstack-comment-id: 3105185112 Pull-Request: #12738
1 parent 0b8d99f commit 931337a

File tree

4 files changed

+97
-0
lines changed

4 files changed

+97
-0
lines changed

backends/test/suite/flow.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,13 @@ def all_flows() -> dict[str, TestFlow]:
6262
except Exception as e:
6363
logger.info(f"Skipping Core ML flow registration: {e}")
6464

65+
try:
66+
from executorch.backends.test.suite.flows.vulkan import VULKAN_TEST_FLOW
67+
68+
flows += [
69+
VULKAN_TEST_FLOW,
70+
]
71+
except Exception as e:
72+
logger.info(f"Skipping Vulkan flow registration: {e}")
73+
6574
return {f.name: f for f in flows if f is not None}

backends/test/suite/flows/vulkan.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from executorch.backends.test.suite.flow import TestFlow
2+
from executorch.backends.vulkan.test.tester import VulkanTester
3+
4+
5+
def _create_vulkan_flow(
6+
name: str,
7+
quantize: bool = False,
8+
) -> TestFlow:
9+
return TestFlow(
10+
name,
11+
backend="vulkan",
12+
tester_factory=VulkanTester,
13+
quantize=quantize,
14+
)
15+
16+
17+
VULKAN_TEST_FLOW = _create_vulkan_flow("vulkan")

backends/vulkan/test/TARGETS

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
2+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
23

34
oncall("executorch")
45

@@ -57,3 +58,12 @@ python_unittest(
5758
"//executorch/backends/vulkan:vulkan_preprocess",
5859
],
5960
)
61+
62+
runtime.python_library(
63+
name = "tester",
64+
srcs = ["tester.py"],
65+
deps = [
66+
"//executorch/backends/vulkan/partitioner:vulkan_partitioner",
67+
"//executorch/backends/vulkan:vulkan_preprocess",
68+
]
69+
)

backends/vulkan/test/tester.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Any, List, Optional, Tuple
8+
9+
import executorch
10+
import executorch.backends.test.harness.stages as BaseStages
11+
12+
import torch
13+
from executorch.backends.test.harness import Tester as TesterBase
14+
from executorch.backends.test.harness.stages import StageType
15+
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
16+
from executorch.exir import EdgeCompileConfig
17+
from executorch.exir.backend.partitioner import Partitioner
18+
19+
20+
class Partition(BaseStages.Partition):
21+
def __init__(self, partitioner: Optional[Partitioner] = None):
22+
super().__init__(
23+
partitioner=partitioner or VulkanPartitioner(),
24+
)
25+
26+
27+
class ToEdgeTransformAndLower(BaseStages.ToEdgeTransformAndLower):
28+
def __init__(
29+
self,
30+
partitioners: Optional[List[Partitioner]] = None,
31+
edge_compile_config: Optional[EdgeCompileConfig] = None,
32+
):
33+
super().__init__(
34+
default_partitioner_cls=VulkanPartitioner,
35+
partitioners=partitioners,
36+
edge_compile_config=edge_compile_config
37+
or EdgeCompileConfig(_check_ir_validity=False),
38+
)
39+
40+
41+
class VulkanTester(TesterBase):
42+
def __init__(
43+
self,
44+
module: torch.nn.Module,
45+
example_inputs: Tuple[torch.Tensor],
46+
dynamic_shapes: Optional[Tuple[Any]] = None,
47+
):
48+
stage_classes = (
49+
executorch.backends.test.harness.Tester.default_stage_classes()
50+
| {
51+
StageType.PARTITION: Partition,
52+
StageType.TO_EDGE_TRANSFORM_AND_LOWER: ToEdgeTransformAndLower,
53+
}
54+
)
55+
56+
super().__init__(
57+
module=module,
58+
stage_classes=stage_classes,
59+
example_inputs=example_inputs,
60+
dynamic_shapes=dynamic_shapes,
61+
)

0 commit comments

Comments
 (0)