Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from executorch.backends.vulkan.test.op_tests.utils.gen_correctness_base import (
CorrectnessTestGen,
)
from executorch.backends.vulkan.test.op_tests.utils.test_suite import TestSuite
from executorch.backends.vulkan.test.op_tests.utils.test_suite import VkTestSuite

from torchgen.model import NativeFunction

Expand Down Expand Up @@ -72,10 +72,12 @@ class GeneratedOpBenchmark_{op_name} : public ::benchmark::Fixture {{


class VkBenchmarkGen(CorrectnessTestGen):
def __init__(self, op_reg_name: str, f: NativeFunction, inputs: TestSuite):
def __init__(self, op_reg_name: str, f: NativeFunction, inputs: VkTestSuite):
super().__init__(f, inputs)
self.op_reg_name = op_reg_name
self.generator = ComputeGraphGen(self.op_reg_name, self.f, self.suite_def)
self.generator = ComputeGraphGen(
self.op_reg_name, self.f, self.suite_def, inputs.force_io
)

def gen_call_benchmark(self, prepack=False) -> str:
test_str = f"benchmark_{self.op_name}("
Expand Down Expand Up @@ -197,7 +199,7 @@ def generate_benchmark_fixture(self) -> str:
float high = 1.0) {{
if (high == 1.0 && low == 0.0)
return at::rand(sizes, at::device(at::kCPU).dtype(dtype));

if (dtype == at::kChar)
return at::randint(high, sizes, at::device(at::kCPU).dtype(dtype));

Expand Down
36 changes: 26 additions & 10 deletions backends/vulkan/test/op_tests/utils/gen_computegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,17 @@ def vk_out(self):
class ComputeGraphGen:
backend_key = None

def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite):
def __init__(
self,
op_reg_name: str,
f: NativeFunction,
suite_def: TestSuite,
include_io: bool = True,
):
self.op_reg_name = op_reg_name
self.f = f
self.suite_def = suite_def
self.include_io = include_io

self.f_sig = CppSignatureGroup.from_native_function(
self.f, method=False, fallback_binding=self.f.manual_cpp_binding
Expand Down Expand Up @@ -275,6 +282,10 @@ def create_value_for( # noqa: C901
prepack = self.prepack_ref(ref)
ref_is_view = self.suite_def.is_view_op and ref.is_out

# If skipping IO, force is_in to be False
if not self.include_io and ref.is_in:
ref.is_in = False

cpp_type = "IOValueRef" if (ref.is_in and not prepack) else "ValueRef"
if not include_declarations:
cpp_type = ""
Expand Down Expand Up @@ -602,7 +613,8 @@ def gen_graph_build_code(self, include_declarations: bool = True) -> str:
graph_build += self.create_value_for(self.refs["out"], include_declarations)
graph_build += self.create_op_call()

graph_build += self.set_output(self.refs["out"], include_declarations)
if self.include_io:
graph_build += self.set_output(self.refs["out"], include_declarations)

graph_build += f"{self.graph}{self.dot}prepare();\n"
graph_build += f"{self.graph}{self.dot}encode_prepack();\n"
Expand All @@ -614,18 +626,22 @@ def gen_graph_build_code(self, include_declarations: bool = True) -> str:

def gen_graph_exec_code(self, check_output=True) -> str:
graph_exec = ""
for aten_arg in self.args:
ref = self.refs[aten_arg.name]
if ref.is_in:
graph_exec += self.virtual_resize(ref)
graph_exec += self.copy_into_staging(ref)
if self.include_io:
for aten_arg in self.args:
ref = self.refs[aten_arg.name]
if ref.is_in:
graph_exec += self.virtual_resize(ref)
graph_exec += self.copy_into_staging(ref)

graph_exec += f"{self.graph}{self.dot}propagate_resize();\n"

graph_exec += f"{self.graph}{self.dot}propagate_resize();\n"
graph_exec += f"{self.graph}{self.dot}execute();\n"

graph_exec += self.declare_vk_out_for(self.refs["out"])
graph_exec += self.copy_from_staging(self.refs["out"])
if check_output:
if self.include_io:
graph_exec += self.copy_from_staging(self.refs["out"])

if self.include_io and check_output:
graph_exec += self.check_graph_out(self.refs["out"])

graph_exec = re.sub(r"^", " ", graph_exec, flags=re.M)
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/test/op_tests/utils/test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,4 @@ def __init__(self, input_cases: List[Any]):
self.storage_types: List[str] = ["utils::kTexture3D"]
self.layouts: List[str] = ["utils::kChannelsPacked"]
self.data_gen: str = "make_rand_tensor"
self.force_io: bool = True
Loading