diff --git a/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py b/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py index d36c7c85a32..983d2c82bd0 100644 --- a/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py @@ -12,7 +12,7 @@ from executorch.backends.vulkan.test.op_tests.utils.gen_correctness_base import ( CorrectnessTestGen, ) -from executorch.backends.vulkan.test.op_tests.utils.test_suite import TestSuite +from executorch.backends.vulkan.test.op_tests.utils.test_suite import VkTestSuite from torchgen.model import NativeFunction @@ -72,10 +72,12 @@ class GeneratedOpBenchmark_{op_name} : public ::benchmark::Fixture {{ class VkBenchmarkGen(CorrectnessTestGen): - def __init__(self, op_reg_name: str, f: NativeFunction, inputs: TestSuite): + def __init__(self, op_reg_name: str, f: NativeFunction, inputs: VkTestSuite): super().__init__(f, inputs) self.op_reg_name = op_reg_name - self.generator = ComputeGraphGen(self.op_reg_name, self.f, self.suite_def) + self.generator = ComputeGraphGen( + self.op_reg_name, self.f, self.suite_def, inputs.force_io + ) def gen_call_benchmark(self, prepack=False) -> str: test_str = f"benchmark_{self.op_name}(" @@ -197,7 +199,7 @@ def generate_benchmark_fixture(self) -> str: float high = 1.0) {{ if (high == 1.0 && low == 0.0) return at::rand(sizes, at::device(at::kCPU).dtype(dtype)); - + if (dtype == at::kChar) return at::randint(high, sizes, at::device(at::kCPU).dtype(dtype)); diff --git a/backends/vulkan/test/op_tests/utils/gen_computegraph.py b/backends/vulkan/test/op_tests/utils/gen_computegraph.py index 6f93e662076..708da8eab85 100644 --- a/backends/vulkan/test/op_tests/utils/gen_computegraph.py +++ b/backends/vulkan/test/op_tests/utils/gen_computegraph.py @@ -90,10 +90,17 @@ def vk_out(self): class ComputeGraphGen: backend_key = None - def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite): + def __init__( + self, + op_reg_name: str, + f: NativeFunction, + suite_def: TestSuite, + include_io: bool = True, + ): self.op_reg_name = op_reg_name self.f = f self.suite_def = suite_def + self.include_io = include_io self.f_sig = CppSignatureGroup.from_native_function( self.f, method=False, fallback_binding=self.f.manual_cpp_binding @@ -275,6 +282,10 @@ def create_value_for( # noqa: C901 prepack = self.prepack_ref(ref) ref_is_view = self.suite_def.is_view_op and ref.is_out + # If skipping IO, force is_in to be False + if not self.include_io and ref.is_in: + ref.is_in = False + cpp_type = "IOValueRef" if (ref.is_in and not prepack) else "ValueRef" if not include_declarations: cpp_type = "" @@ -602,7 +613,8 @@ def gen_graph_build_code(self, include_declarations: bool = True) -> str: graph_build += self.create_value_for(self.refs["out"], include_declarations) graph_build += self.create_op_call() - graph_build += self.set_output(self.refs["out"], include_declarations) + if self.include_io: + graph_build += self.set_output(self.refs["out"], include_declarations) graph_build += f"{self.graph}{self.dot}prepare();\n" graph_build += f"{self.graph}{self.dot}encode_prepack();\n" @@ -614,18 +626,22 @@ def gen_graph_build_code(self, include_declarations: bool = True) -> str: def gen_graph_exec_code(self, check_output=True) -> str: graph_exec = "" - for aten_arg in self.args: - ref = self.refs[aten_arg.name] - if ref.is_in: - graph_exec += self.virtual_resize(ref) - graph_exec += self.copy_into_staging(ref) + if self.include_io: + for aten_arg in self.args: + ref = self.refs[aten_arg.name] + if ref.is_in: + graph_exec += self.virtual_resize(ref) + graph_exec += self.copy_into_staging(ref) + + graph_exec += f"{self.graph}{self.dot}propagate_resize();\n" - graph_exec += f"{self.graph}{self.dot}propagate_resize();\n" graph_exec += f"{self.graph}{self.dot}execute();\n" graph_exec += self.declare_vk_out_for(self.refs["out"]) - graph_exec += self.copy_from_staging(self.refs["out"]) - if check_output: + if self.include_io: + graph_exec += self.copy_from_staging(self.refs["out"]) + + if self.include_io and check_output: graph_exec += self.check_graph_out(self.refs["out"]) graph_exec = re.sub(r"^", " ", graph_exec, flags=re.M) diff --git a/backends/vulkan/test/op_tests/utils/test_suite.py b/backends/vulkan/test/op_tests/utils/test_suite.py index dd01bdde3a4..72ba457b5af 100644 --- a/backends/vulkan/test/op_tests/utils/test_suite.py +++ b/backends/vulkan/test/op_tests/utils/test_suite.py @@ -47,3 +47,4 @@ def __init__(self, input_cases: List[Any]): self.storage_types: List[str] = ["utils::kTexture3D"] self.layouts: List[str] = ["utils::kChannelsPacked"] self.data_gen: str = "make_rand_tensor" + self.force_io: bool = True