diff --git a/backends/vulkan/test/op_tests/generate_op_benchmarks.py b/backends/vulkan/test/op_tests/generate_op_benchmarks.py
new file mode 100644
index 00000000000..2eeb209b9d2
--- /dev/null
+++ b/backends/vulkan/test/op_tests/generate_op_benchmarks.py
@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import os
+
+from typing import Dict
+
+from executorch.backends.vulkan.test.op_tests.cases import test_suites
+
+from executorch.backends.vulkan.test.op_tests.utils.gen_benchmark_vk import (
+    VkBenchmarkFileGen,
+)
+from executorch.backends.vulkan.test.op_tests.utils.gen_computegraph import (
+    ComputeGraphGen,
+)
+from executorch.backends.vulkan.test.op_tests.utils.test_suite import TestSuite
+from torchgen import local
+
+from torchgen.gen import parse_native_yaml, ParsedYaml
+from torchgen.model import DispatchKey, NativeFunction
+
+
+def registry_name(f: NativeFunction) -> str:
+    name = str(f.namespace) + "." + str(f.func.name)
+    if len(f.func.name.overload_name) == 0:
+        name += ".default"
+    return name
+
+
+def construct_f_map(parsed_yaml: ParsedYaml) -> Dict[str, NativeFunction]:
+    f_map: Dict[str, NativeFunction] = {}
+    for f in parsed_yaml.native_functions:
+        f_map[registry_name(f)] = f
+    return f_map
+
+
+def process_test_suites(
+    cpp_generator: VkBenchmarkFileGen,
+    f_map: Dict[str, NativeFunction],
+    test_suites: Dict[str, TestSuite],
+) -> None:
+    for registry_name, op_test_suite in test_suites.items():
+        f = f_map[registry_name]
+        cpp_generator.add_suite(registry_name, f, op_test_suite)
+
+
+@local.parametrize(
+    use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
+)
+def generate_cpp(
+    native_functions_yaml_path: str, tags_path: str, output_dir: str
+) -> None:
+    output_file = os.path.join(output_dir, "op_benchmarks.cpp")
+    cpp_generator = VkBenchmarkFileGen(output_file)
+
+    parsed_yaml = parse_native_yaml(native_functions_yaml_path, tags_path)
+    f_map = construct_f_map(parsed_yaml)
+
+    ComputeGraphGen.backend_key = parsed_yaml.backend_indices[DispatchKey.CPU]
+
+    process_test_suites(cpp_generator, f_map, test_suites)
+
+    with open(output_file, "w") as file:
+        file.write(cpp_generator.generate_cpp())
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--aten-yaml-path",
+        help="Path to the native_functions.yaml file.",
+    )
+    parser.add_argument(
+        "--tags-path",
+        help="Path to tags.yaml. Required by the torchgen native YAML parser.",
+    )
+
+    parser.add_argument("-o", "--output", help="Output directory", required=True)
+    args = parser.parse_args()
+    generate_cpp(args.aten_yaml_path, args.tags_path, args.output)
diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl
index 8d7d80e6743..9b6ea61de21 100644
--- a/backends/vulkan/test/op_tests/targets.bzl
+++ b/backends/vulkan/test/op_tests/targets.bzl
@@ -20,6 +20,19 @@ def define_common_targets(is_fbcode = False):
         external_deps = ["torchgen"],
     )
 
+    runtime.python_library(
+        name = "generate_op_benchmarks_lib",
+        srcs = native.glob(["utils/*.py"]) + [
+            "generate_op_benchmarks.py",
+            "cases.py",
+        ],
+        base_module = "executorch.backends.vulkan.test.op_tests",
+        deps = [
+            "fbsource//third-party/pypi/expecttest:expecttest",
+        ],
+        external_deps = ["torchgen"],
+    )
+
     runtime.python_binary(
         name = "generate_op_correctness_tests",
         main_module = "executorch.backends.vulkan.test.op_tests.generate_op_correctness_tests",
@@ -28,6 +41,14 @@ def define_common_targets(is_fbcode = False):
         ],
     )
 
+    runtime.python_binary(
+        name = "generate_op_benchmarks",
+        main_module = "executorch.backends.vulkan.test.op_tests.generate_op_benchmarks",
+        deps = [
+            ":generate_op_benchmarks_lib",
+        ],
+    )
+
     aten_src_path = runtime.external_dep_location("aten-src-path")
     genrule_cmd = [
         "$(exe :generate_op_correctness_tests)",
@@ -45,6 +66,22 @@ def define_common_targets(is_fbcode = False):
         default_outs = ["."],
     )
 
+    benchmarks_genrule_cmd = [
+        "$(exe :generate_op_benchmarks)",
+        "--tags-path $(location {})/aten/src/ATen/native/tags.yaml".format(aten_src_path),
+        "--aten-yaml-path $(location {})/aten/src/ATen/native/native_functions.yaml".format(aten_src_path),
+        "-o $OUT",
+    ]
+
+    runtime.genrule(
+        name = "generated_op_benchmarks_cpp",
+        outs = {
+            "op_benchmarks.cpp": ["op_benchmarks.cpp"],
+        },
+        cmd = " ".join(benchmarks_genrule_cmd),
+        default_outs = ["."],
+    )
+
     pt_operator_library(
         name = "all_aten_ops",
         check_decl = False,
@@ -76,6 +113,22 @@ def define_common_targets(is_fbcode = False):
         ],
     )
 
+    runtime.cxx_binary(
+        name = "compute_graph_op_benchmarks_bin",
+        srcs = [
+            ":generated_op_benchmarks_cpp[op_benchmarks.cpp]",
+        ],
+        compiler_flags = [
+            "-Wno-unused-variable",
+        ],
+        define_static_target = False,
+        deps = [
+            "//third-party/benchmark:benchmark",
+            "//executorch/backends/vulkan:vulkan_graph_runtime",
+            ":all_aten_ops_lib",
+        ],
+    )
+
     runtime.cxx_test(
         name = "compute_graph_op_tests",
         srcs = [
diff --git a/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py b/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py
new file mode 100644
index 00000000000..fb42d982f67
--- /dev/null
+++ b/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py
@@ -0,0 +1,335 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import re
+
+from executorch.backends.vulkan.test.op_tests.utils.gen_computegraph import (
+    ComputeGraphGen,
+)
+from executorch.backends.vulkan.test.op_tests.utils.gen_correctness_base import (
+    CorrectnessTestGen,
+)
+from executorch.backends.vulkan.test.op_tests.utils.test_suite import TestSuite
+
+from torchgen.model import NativeFunction
+
+###########################
+## Test Suite Generation ##
+###########################
+
+benchmark_fixture_template = """
+class GeneratedOpBenchmark_{op_name} : public ::benchmark::Fixture {{
+ protected:
+  ComputeGraph* graph;
+  at::ScalarType test_dtype = at::kFloat;
+  float rtol = {rtol};
+  float atol = {atol};
+
+  {arg_valuerefs}
+
+  void SetUp(::benchmark::State& state) override {{
+    GraphConfig config;
+    config.descriptor_pool_safety_factor = 2.0;
+    test_dtype = at::ScalarType(state.range(0));
+    const utils::StorageType storage_type = utils::StorageType(state.range(1));
+    const utils::GPUMemoryLayout memory_layout = utils::GPUMemoryLayout(state.range(2));
+    config.set_storage_type_override(storage_type);
+    config.set_memory_layout_override(memory_layout);
+    config.enable_querypool = true;
+    graph = new ComputeGraph(config);
+  }}
+
+  void TearDown(::benchmark::State& state) override {{
+    delete graph;
+    graph = nullptr;
+  }}
+
+  {build_graph_fn}
+  {benchmark_fn}
+}};
+"""
+
+benchmark_template = """
+BENCHMARK_DEFINE_F(GeneratedOpBenchmark_{op_name}, {case_name})(benchmark::State& state) {{
+  {skips}
+  {create_ref_data}
+  {call_build_graph}
+  ShaderTimes shader_times;
+  for (auto _ : state) {{
+    {call_benchmark}
+    graph->context()->querypool().extract_results();
+    QueryPoolResults results = graph->context()->querypool().get_shader_timestamp_data();
+    process_querypool_results(results, shader_times);
+  }}
+  register_shader_time_counters(state, shader_times);
+}}
+
+BENCHMARK_REGISTER_F(GeneratedOpBenchmark_{op_name}, {case_name})->Threads(1)->ArgsProduct({combos});
+"""
+
+
+class VkBenchmarkGen(CorrectnessTestGen):
+    def __init__(self, op_reg_name: str, f: NativeFunction, inputs: TestSuite):
+        super().__init__(f, inputs)
+        self.op_reg_name = op_reg_name
+        self.generator = ComputeGraphGen(self.op_reg_name, self.f, self.suite_def)
+
+    def gen_call_benchmark(self, prepack: bool = False) -> str:
+        test_str = f"benchmark_{self.op_name}("
+        if prepack:
+            test_str = f"prepacked_benchmark_{self.op_name}("
+        for binding in self.f_sig.arguments():
+            arg = binding.argument
+            test_str += f"{arg.name}, "
+        test_str = test_str[:-2] + ");"
+        test_str = re.sub(r"^", "  ", test_str, flags=re.M)
+        return test_str
+
+    def gen_call_build_graph(self, prepack: bool = False) -> str:
+        test_str = f"build_graph_{self.op_name}("
+        if prepack:
+            test_str = f"prepacked_build_graph_{self.op_name}("
+        for binding in self.f_sig.arguments():
+            arg = binding.argument
+            test_str += f"{arg.name}, "
+        test_str = test_str[:-2] + ");"
+        test_str = re.sub(r"^", "  ", test_str, flags=re.M)
+        return test_str
+
+    def gen_combos(self, inputs) -> str:
+        dtypes_list = ", ".join(f"int({dtype})" for dtype in self.suite_def.dtypes)
+        storage_types_list = ", ".join(
+            f"int({storage_type})" for storage_type in self.suite_def.storage_types
+        )
+        layouts_list = ", ".join(f"int({layout})" for layout in self.suite_def.layouts)
+        return f"{{ {{ {dtypes_list} }}, {{ {storage_types_list} }}, {{ {layouts_list} }} }}"
+
+    def generate_benchmark_case(self, inputs, prepack: bool = False) -> str:
+        return benchmark_template.format(
+            op_name=self.op_name,
+            case_name=self.gen_case_name(inputs, prepack),
+            skips=self.generator.gen_conditional_skips(
+                'state.SkipWithError("unsupported type"); return;'
+            ),
+            create_ref_data=self.gen_create_ref_data(inputs),
+            call_build_graph=self.gen_call_build_graph(prepack),
+            call_benchmark=self.gen_call_benchmark(prepack),
+            combos=self.gen_combos(inputs),
+        )
+
+    def generate_benchmark(self) -> str:
+        benchmarks_cpp = ""
+        for inputs in self.suite_def.input_cases:
+            if not self.suite_def.requires_prepack:
+                benchmarks_cpp += self.generate_benchmark_case(inputs)
+            if self.suite_def.supports_prepack():
+                benchmarks_cpp += self.generate_benchmark_case(inputs, prepack=True)
+        return benchmarks_cpp
+
+    def generate_benchmark_fixture(self) -> str:
+        build_graph_fn = ""
+        benchmark_fn = ""
+        if not self.suite_def.requires_prepack:
+            build_graph_fn = self.generator.gen_build_graph_fn()
+            benchmark_fn = self.generator.gen_op_exec_graph_fn()
+
+        prepacked_build_graph_fn = ""
+        prepacked_benchmark_fn = ""
+        if self.suite_def.supports_prepack():
+            self.generator.should_prepack = True
+            prepacked_build_graph_fn = self.generator.gen_build_graph_fn()
+            build_graph_fn += "\n\n  "
+            build_graph_fn += prepacked_build_graph_fn
+            prepacked_benchmark_fn = self.generator.gen_op_exec_graph_fn()
+            benchmark_fn += "\n\n  "
+            benchmark_fn += prepacked_benchmark_fn
+
+        return benchmark_fixture_template.format(
+            op_name=self.op_name,
+            build_graph_fn=build_graph_fn,
+            benchmark_fn=benchmark_fn,
+            rtol=self.suite_def.rtol,
+            atol=self.suite_def.atol,
+            arg_valuerefs=self.generator.gen_arg_valueref_decls(),
+        )
+
+
+##########################
+## Test File Generation ##
+##########################
+
+cpp_test_template = """
+#include <benchmark/benchmark.h>
+
+#include <ATen/ATen.h>
+
+#include <executorch/backends/vulkan/runtime/api/api.h>
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
+
+using namespace vkcompute;
+using TensorOptions = at::TensorOptions;
+
+vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {{
+  switch (at_scalartype) {{
+    case c10::kFloat:
+      return vkapi::kFloat;
+    case c10::kHalf:
+      return vkapi::kHalf;
+    case c10::kInt:
+      return vkapi::kInt;
+    case c10::kLong:
+      return vkapi::kInt;
+    case c10::kChar:
+      return vkapi::kChar;
+    default:
+      VK_THROW("Unsupported at::ScalarType!");
+  }}
+}}
+
+at::Tensor make_rand_tensor(
+    std::vector<int64_t> sizes,
+    at::ScalarType dtype = at::kFloat,
+    float low = 0.0,
+    float high = 1.0) {{
+  if (high == 1.0 && low == 0.0)
+    return at::rand(sizes, at::device(at::kCPU).dtype(dtype));
+
+  if (dtype == at::kChar)
+    return at::randint(high, sizes, at::device(at::kCPU).dtype(dtype));
+
+  return at::rand(sizes, at::device(at::kCPU).dtype(dtype)) * (high - low) + low;
+}}
+
+at::Tensor make_seq_tensor(
+    std::vector<int64_t> sizes,
+    at::ScalarType dtype = at::kFloat,
+    float low = 0.0,
+    float high = 1.0) {{
+  (void)low;
+  (void)high;
+
+  int64_t n = 1;
+  for (auto size: sizes) {{
+    n *= size;
+  }}
+
+  std::vector<float> values(n);
+  for (int i=0;i<n;i++) {{
+    values[i] = (float) i;
+  }}
+
+  // Clone as original data will be deallocated upon return.
+  return at::from_blob(values.data(), sizes, at::kFloat).toType(dtype).detach().clone();
+}}
+
+at::Tensor make_index_tensor(std::vector<int32_t> indices) {{
+  at::ScalarType dtype = at::kInt;
+  std::vector<int64_t> sizes = {{static_cast<int64_t>(indices.size())}};
+
+  // Clone as original data will be deallocated upon return.
+  return at::from_blob(indices.data(), sizes, dtype).detach().clone();
+}}
+
+at::Tensor make_index_tensor(std::vector<std::vector<int32_t>> indices) {{
+  at::ScalarType dtype = at::kInt;
+  std::vector<int64_t> sizes = {{
+      static_cast<int64_t>(indices.size()),
+      static_cast<int64_t>(indices[0].size())}};
+
+  // Flatten indices as from_blob reads garbage otherwise.
+  std::vector<int32_t> acc;
+  for (auto& vec: indices) {{
+    acc.insert(acc.end(), vec.begin(), vec.end());
+  }}
+
+  // Clone as original data will be deallocated upon return.
+  return at::from_blob(acc.data(), sizes, dtype).detach().clone();
+}}
+
+at::Tensor make_index_tensor(std::vector<std::vector<std::vector<int32_t>>> indices) {{
+  at::ScalarType dtype = at::kInt;
+  std::vector<int64_t> sizes = {{
+      static_cast<int64_t>(indices.size()),
+      static_cast<int64_t>(indices[0].size()),
+      static_cast<int64_t>(indices[0][0].size())}};
+
+  // Flatten indices as from_blob reads garbage otherwise.
+  std::vector<int32_t> acc;
+  for (auto& v: indices) {{
+    for (auto& vv: v) {{
+      acc.insert(acc.end(), vv.begin(), vv.end());
+    }}
+  }}
+
+  // Clone as original data will be deallocated upon return.
+  return at::from_blob(acc.data(), sizes, dtype).detach().clone();
+}}
+
+// Entries are (kernel name, dispatch index, start ns, end ns).
+using ShaderEntry = std::tuple<std::string, uint32_t, uint64_t, uint64_t>;
+using QueryPoolResults = std::vector<ShaderEntry>;
+using ShaderTimes = std::unordered_map<std::string, std::vector<uint64_t>>;
+
+void process_querypool_results(
+    QueryPoolResults& results,
+    ShaderTimes& shader_times) {{
+  for (const ShaderEntry& entry : results) {{
+    std::string kernel_name = std::get<0>(entry);
+    std::uint64_t start_ns = std::get<2>(entry);
+    std::uint64_t end_ns = std::get<3>(entry);
+    std::uint64_t duration_ns = end_ns - start_ns;
+    if (shader_times.find(kernel_name) == shader_times.end()) {{
+      shader_times[kernel_name] = std::vector<uint64_t>();
+    }}
+    shader_times[kernel_name].emplace_back(duration_ns);
+  }}
+}}
+
+void register_shader_time_counters(
+    benchmark::State& state,
+    ShaderTimes& shader_times) {{
+  for (auto& times_list : shader_times) {{
+    // Filter out the to_nchw and nchw_to staging shaders.
+    if (times_list.first.find("to_nchw") != std::string::npos) {{
+      continue;
+    }}
+    if (times_list.first.find("nchw_to") != std::string::npos) {{
+      continue;
+    }}
+
+    std::sort(times_list.second.begin(), times_list.second.end());
+    uint64_t median_time = times_list.second[times_list.second.size() / 2];
+    state.counters[times_list.first + " median ns"] = median_time;
+  }}
+}}
+
+{benchmark_fixtures}
+
+{def_benchmarks}
+"""
+
+
+class VkBenchmarkFileGen:
+    def __init__(self, out_path):
+        self.out_path = out_path
+        self.suites_gens = []
+
+    def add_suite(self, op_reg_name: str, f: NativeFunction, all_input_cases) -> None:
+        suites_gen = VkBenchmarkGen(op_reg_name, f, all_input_cases)
+        self.suites_gens.append(suites_gen)
+
+    def generate_benchmarks_cpp(self) -> str:
+        return "\n".join([h.generate_benchmark() for h in self.suites_gens])
+
+    def generate_benchmark_fixtures(self) -> str:
+        return "\n".join([h.generate_benchmark_fixture() for h in self.suites_gens])
+
+    def generate_cpp(self) -> str:
+        return cpp_test_template.format(
+            benchmark_fixtures=self.generate_benchmark_fixtures(),
+            def_benchmarks=self.generate_benchmarks_cpp(),
+        )
diff --git a/backends/vulkan/test/op_tests/utils/gen_computegraph.py b/backends/vulkan/test/op_tests/utils/gen_computegraph.py
index c0583bbec4d..0bc0a93dbcb 100644
--- a/backends/vulkan/test/op_tests/utils/gen_computegraph.py
+++ b/backends/vulkan/test/op_tests/utils/gen_computegraph.py
@@ -228,11 +228,12 @@ def create_aten_method_call(self) -> str:
         func_call = f"ATEN_FN({self.f_sig.name()})({exprs});"
         return func_call
 
-    def create_out_src(self) -> str:
+    def create_out_src(self, include_declarations: bool = True) -> str:
+        cpp_type = self.out.cpp_type if include_declarations else ""
         if Variant.function in self.f.variants:
-            return f"{self.out.cpp_type} out = " + self.create_aten_fn_call() + "\n"
+            return f"{cpp_type} out = " + self.create_aten_fn_call() + "\n"
         else:
-            return f"{self.out.cpp_type} out = " + self.create_aten_method_call() + "\n"
+            return f"{cpp_type} out = " + self.create_aten_method_call() + "\n"
 
     ## Graph code generation utils
 
@@ -242,7 +243,28 @@ def prepack_ref(self, ref: ValueRef) -> bool:
         else:
             return ref.supports_prepack and self.should_prepack
 
-    def create_value_for(self, ref: ValueRefList) -> str:  # noqa: C901
+    def create_value_decl_for(self, ref: ValueRefList) -> str:  # noqa: C901
+        if isinstance(ref, list):
+            ret_str = ""
+            for r in ref:
+                ret_str += self.create_value_decl_for(r)
+            return ret_str
+
+        cpp_type = "IOValueRef" if (ref.is_in or ref.requires_prepack) else "ValueRef"
+        if ref.src_cpp_type == AT_TENSOR_LIST:
+            ret_str = f"std::vector<IOValueRef> {ref.name}_io_value_refs;\n"
+            ret_str += f"std::vector<ValueRef> {ref.name}_value_refs;\n"
+            return ret_str
+        elif ref.src_cpp_type == TENSOR_VECTOR:
+            ret_str = f"std::vector<IOValueRef> {ref.io_value_list_name};\n"
+            ret_str += f"std::vector<ValueRef> {ref.value_list_name};\n"
+            return ret_str
+        else:
+            return f"{cpp_type} {ref.name};\n"
+
+    def create_value_for(  # noqa: C901
+        self, ref: ValueRefList, include_declarations: bool = True
+    ) -> str:
         if isinstance(ref, list):
             ret_str = ""
             for r in ref:
@@ -252,9 +274,16 @@ def create_value_for(self, ref: ValueRefList) -> str:  # noqa: C901
 
         prepack = self.prepack_ref(ref)
 
         cpp_type = "IOValueRef" if (ref.is_in and not prepack) else "ValueRef"
+        if not include_declarations:
+            cpp_type = ""
 
         if ref.src_cpp_type == OPT_AT_TENSOR:
             ret_str = f"{cpp_type} {ref.name} = "
+            if prepack:
+                ret_str = ""
+                if include_declarations:
+                    ret_str += f"IOValueRef {ref.name};\n"
+                ret_str += f"{ref.name}.value = "
             ret_str += f"!{ref.src_cpp_name}.has_value() ? "
             ret_str += f"{self.graph}{self.dot}add_none() : "
             if not prepack:
@@ -291,11 +320,13 @@ def create_value_for(self, ref: ValueRefList) -> str:  # noqa: C901
             # each tensor, to facilitate staging. On the other hand, we will
             # use the .value tensor to create a ValueList, which will be passed
             # to the corresponding ops.
-            ret_str = f"std::vector<IOValueRef> {ref.name}_io_value_refs;\n"
-            ret_str += f"std::vector<ValueRef> {ref.name}_value_refs;\n"
+            ret_str = ""
+            if include_declarations:
+                ret_str += f"std::vector<IOValueRef> {ref.name}_io_value_refs;\n"
+                ret_str += f"std::vector<ValueRef> {ref.name}_value_refs;\n"
             ret_str += f"for (int i=0; i < {ref.src_cpp_name}.size(); i++) {{\n"
             ret_str += (
-                f"  {cpp_type} io_value_ref = {self.graph}{self.dot}add_input_tensor(\n"
+                f"  IOValueRef io_value_ref = {self.graph}{self.dot}add_input_tensor(\n"
             )
             ret_str += f"    {ref.src_cpp_name}[i].sizes().vec(),\n"
             ret_str += (
@@ -307,9 +338,11 @@ def create_value_for(self, ref: ValueRefList) -> str:  # noqa: C901
             ret_str += f"ValueRef {ref.name} = {self.graph}{self.dot}add_value_list(std::move({ref.name}_value_refs));\n"
             return ret_str
         elif ref.src_cpp_type == TENSOR_VECTOR:
-            ret_str = f"""
-std::vector<IOValueRef> {ref.io_value_list_name};
-std::vector<ValueRef> {ref.value_list_name};
+            ret_str = ""
+            if include_declarations:
+                ret_str += f"std::vector<IOValueRef> {ref.io_value_list_name};\n"
+                ret_str += f"std::vector<ValueRef> {ref.value_list_name};\n"
+            ret_str += f"""
 for (int i=0; i<out.size(); i++) {{
   const at::Tensor& cur = out[i];
   IOValueRef io_value_ref;
@@ -345,6 +378,12 @@ def create_value_for(self, ref: ValueRefList) -> str:  # noqa: C901
             return ret_str
 
         ret_str = f"{cpp_type} {ref.name} = {self.graph}{self.dot}"
+        if prepack:
+            ret_str = ""
+            if include_declarations:
+                ret_str = f"IOValueRef {ref.name};\n"
+            ret_str += f"{ref.name}.value = {self.graph}{self.dot}"
+
         if ref.src_cpp_type == AT_TENSOR and not prepack:
             ret_str += "add_input_tensor(" if ref.is_in else "add_tensor("
             ret_str += f"{ref.src_cpp_name}.sizes().vec(), "
@@ -374,14 +413,29 @@ def create_op_call(self) -> str:
         else:
             op_create_code += (
                 f"{ref.name}.value, "
-                if (ref.is_in and not self.prepack_ref(ref)) or ref.is_out
+                if ref.is_in or ref.requires_prepack or ref.is_out
                 else f"{ref.name}, "
             )
+        # op_create_code += f"{ref.name}, "
         op_create_code += "out_ref});\n"
         return op_create_code
 
-    def set_output(self, ref: ValueRefList) -> str:
+    def gen_output_staging_valueref_decl(self, ref: ValueRefList) -> str:
+        if isinstance(ref, list):
+            ret_str = ""
+            for r in ref[:-1]:
+                ret_str += self.gen_output_staging_valueref_decl(r)
+            return ret_str
+        elif ref.src_cpp_type == TENSOR_VECTOR:
+            assert ref.is_out
+            ret_str = ""
+            return ret_str
+
+        assert ref.src_cpp_type == AT_TENSOR and ref.is_out
+        return f"ValueRef {ref.name}_staging;\n"
+
+    def set_output(self, ref: ValueRefList, include_declarations: bool = True) -> str:
         if isinstance(ref, list):
             ret_str = ""
             for r in ref[:-1]:
@@ -398,7 +452,8 @@ def set_output(self, ref: ValueRefList) -> str:
             return ret_str
 
         assert ref.src_cpp_type == AT_TENSOR and ref.is_out
-        ret_str = f"ValueRef {ref.name}_staging = {self.graph}{self.dot}"
+        cpptype = "ValueRef" if include_declarations else ""
+        ret_str = f"{cpptype} {ref.name}_staging = {self.graph}{self.dot}"
         ret_str += f"set_output_tensor({ref.name});\n"
         return ret_str
 
@@ -516,15 +571,28 @@ def check_graph_out(self, ref: ValueRefList) -> str:
 
     ## Top level code generation
 
-    def gen_graph_build_code(self) -> str:
-        graph_build = self.create_out_src()
+    def gen_arg_valueref_decls(self) -> str:
+        ret_str = ""
         for aten_arg in self.args:
-            graph_build += self.create_value_for(self.refs[aten_arg.name])
+            ref = self.refs[aten_arg.name]
+            ret_str += self.create_value_decl_for(ref)
+
+        ret_str += self.create_value_decl_for(self.refs["out"])
+        ret_str += f"{self.out.cpp_type} out;\n"
+        ret_str += self.gen_output_staging_valueref_decl(self.refs["out"])
+        return ret_str
+
+    def gen_graph_build_code(self, include_declarations: bool = True) -> str:
+        graph_build = self.create_out_src(include_declarations)
+
+        for aten_arg in self.args:
+            graph_build += self.create_value_for(
+                self.refs[aten_arg.name], include_declarations
+            )
 
-        graph_build += self.create_value_for(self.refs["out"])
+        graph_build += self.create_value_for(self.refs["out"], include_declarations)
         graph_build += self.create_op_call()
-        graph_build += self.set_output(self.refs["out"])
+        graph_build += self.set_output(self.refs["out"], include_declarations)
 
         graph_build += f"{self.graph}{self.dot}prepare();\n"
         graph_build += f"{self.graph}{self.dot}encode_prepack();\n"
@@ -534,7 +602,7 @@ def gen_graph_build_code(self) -> str:
         graph_build += "\n"
         return graph_build
 
-    def gen_graph_exec_code(self) -> str:
+    def gen_graph_exec_code(self, check_output: bool = True) -> str:
         graph_exec = ""
         for aten_arg in self.args:
             ref = self.refs[aten_arg.name]
@@ -547,26 +615,27 @@ def gen_graph_exec_code(self) -> str:
         graph_exec += self.declare_vk_out_for(self.refs["out"])
 
         graph_exec += self.copy_from_staging(self.refs["out"])
-        graph_exec += self.check_graph_out(self.refs["out"])
+        if check_output:
+            graph_exec += self.check_graph_out(self.refs["out"])
 
         graph_exec = re.sub(r"^", "  ", graph_exec, flags=re.M)
         graph_exec = "{\n" + graph_exec + "\n}"
         return graph_exec
 
-    def gen_conditional_skips(self) -> str:
+    def gen_conditional_skips(self, skip_str: str = "GTEST_SKIP();") -> str:
         fp16_skip = f"if (!{self.graph}{self.dot}context()->adapter_ptr()->has_full_float16_buffers_support()) {{\n"
-        fp16_skip += "  GTEST_SKIP();\n"
+        fp16_skip += f"  {skip_str}\n"
        fp16_skip += "}"
         fp16_skip = re.sub(r"^", "  ", fp16_skip, flags=re.M) + "\n"
 
         int8_skip = f"if (!{self.graph}{self.dot}context()->adapter_ptr()->has_full_int8_buffers_support()) {{\n"
-        int8_skip += "  GTEST_SKIP();\n"
+        int8_skip += f"  {skip_str}\n"
         int8_skip += "}\n"
 
         skips = ""
 
-        skips = "if (test_dtype == at::kHalf) {\n"
+        skips += "if (test_dtype == at::kHalf) {\n"
         skips += fp16_skip
         skips += "}\n"
 
@@ -595,3 +664,33 @@ def gen_op_check_fn(self) -> str:
         op_check_fn += "\n  }"
 
         return op_check_fn
+
+    def gen_build_graph_fn(self, include_declarations: bool = False) -> str:
+        op_name = self.f.func.name.unambiguous_name()
+        op_build_graph_fn = self.gen_decl(f"build_graph_{op_name}") + " {\n"
+        if self.should_prepack:
+            op_build_graph_fn = (
+                self.gen_decl(f"prepacked_build_graph_{op_name}") + " {\n"
+            )
+
+        op_build_graph_fn_body = ""
+        op_build_graph_fn_body += self.gen_graph_build_code(include_declarations)
+
+        op_build_graph_fn += op_build_graph_fn_body
+        op_build_graph_fn += "\n  }"
+        return op_build_graph_fn
+
+    def gen_op_exec_graph_fn(self) -> str:
+        op_name = self.f.func.name.unambiguous_name()
+        op_benchmark_fn = self.gen_decl(f"benchmark_{op_name}") + " {\n"
+        if self.should_prepack:
+            op_benchmark_fn = self.gen_decl(f"prepacked_benchmark_{op_name}") + " {\n"
+
+        op_benchmark_fn_body = ""
+        op_benchmark_fn_body += self.gen_graph_exec_code(False)
+
+        op_benchmark_fn_body = re.sub(r"^", "  ", op_benchmark_fn_body, flags=re.M)
+
+        op_benchmark_fn += op_benchmark_fn_body
+        op_benchmark_fn += "\n  }"
+        return op_benchmark_fn
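---

Reviewer note (illustrative, not part of the patch): the two templates in gen_benchmark_vk.py expand into a Google Benchmark fixture plus one registered case per input combination, parameterized over (dtype, storage type, memory layout) via ArgsProduct. Below is a minimal, self-contained sketch of that shape. The ComputeGraph is replaced with a placeholder struct so the snippet builds against Google Benchmark alone; the op name, case name, and argument values are illustrative rather than actual generator output.

#include <cstdint>
#include <vector>

#include <benchmark/benchmark.h>

// Stands in for vkcompute::ComputeGraph so this sketch is self-contained.
struct PlaceholderGraph {
  std::vector<float> data;
};

class GeneratedOpBenchmark_add : public ::benchmark::Fixture {
 protected:
  PlaceholderGraph* graph = nullptr;
  int64_t test_dtype = 0;

  void SetUp(::benchmark::State& state) override {
    // state.range(0..2) carry the dtype / storage type / memory layout
    // combination selected by ArgsProduct, as in the generated SetUp.
    test_dtype = state.range(0);
    graph = new PlaceholderGraph();
  }

  void TearDown(::benchmark::State&) override {
    delete graph;
    graph = nullptr;
  }
};

BENCHMARK_DEFINE_F(GeneratedOpBenchmark_add, some_case)(benchmark::State& state) {
  for (auto _ : state) {
    // The generated body executes the graph here, then drains the
    // Vulkan query pool to collect per-shader timings.
    benchmark::DoNotOptimize(graph->data.data());
  }
}

// One benchmark instance per (dtype, storage, layout) combination;
// 6 and 5 are the c10::ScalarType codes for Float and Half.
BENCHMARK_REGISTER_F(GeneratedOpBenchmark_add, some_case)
    ->Threads(1)
    ->ArgsProduct({{6, 5}, {0, 1}, {0, 1}});

BENCHMARK_MAIN();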
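A second note on the timing pipeline: each iteration drains the query pool, buckets shader durations by kernel name, and reports the median per kernel as a user counter, skipping the nchw_to/to_nchw staging shaders. The standalone sketch below mirrors that reduction, assuming the four-field (kernel name, dispatch index, start ns, end ns) entry layout of the ShaderEntry alias above; std::nth_element stands in for the full sort used by the generated code.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

using ShaderEntry = std::tuple<std::string, uint32_t, uint64_t, uint64_t>;

int main() {
  std::vector<ShaderEntry> results = {
      {"add_texture3d", 0, 100, 160},
      {"add_texture3d", 1, 200, 250},
      {"nchw_to_image", 0, 10, 90},  // staging shader: filtered out below
  };

  // Bucket per-dispatch durations (end - start) by kernel name.
  std::unordered_map<std::string, std::vector<uint64_t>> shader_times;
  for (const ShaderEntry& e : results) {
    shader_times[std::get<0>(e)].push_back(std::get<3>(e) - std::get<2>(e));
  }

  for (auto& [name, times] : shader_times) {
    if (name.find("nchw_to") != std::string::npos ||
        name.find("to_nchw") != std::string::npos) {
      continue;  // skip staging copies, as the generated code does
    }
    // Same upper-median convention as sorted[size() / 2] above.
    auto mid = times.begin() + times.size() / 2;
    std::nth_element(times.begin(), mid, times.end());
    std::cout << name << " median ns: " << *mid << "\n";
  }
  return 0;
}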