Skip to content

Commit e68552d

Browse files
authored
[ET-VK][benchmarking][ez] Don't perform copies when benchmarking
Differential Revision: D71570143 Pull Request resolved: #9468
1 parent 9a32875 commit e68552d

File tree

3 files changed

+33
-14
lines changed

3 files changed

+33
-14
lines changed

backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from executorch.backends.vulkan.test.op_tests.utils.gen_correctness_base import (
1313
CorrectnessTestGen,
1414
)
15-
from executorch.backends.vulkan.test.op_tests.utils.test_suite import TestSuite
15+
from executorch.backends.vulkan.test.op_tests.utils.test_suite import VkTestSuite
1616

1717
from torchgen.model import NativeFunction
1818

@@ -72,10 +72,12 @@ class GeneratedOpBenchmark_{op_name} : public ::benchmark::Fixture {{
7272

7373

7474
class VkBenchmarkGen(CorrectnessTestGen):
75-
def __init__(self, op_reg_name: str, f: NativeFunction, inputs: TestSuite):
75+
def __init__(self, op_reg_name: str, f: NativeFunction, inputs: VkTestSuite):
7676
super().__init__(f, inputs)
7777
self.op_reg_name = op_reg_name
78-
self.generator = ComputeGraphGen(self.op_reg_name, self.f, self.suite_def)
78+
self.generator = ComputeGraphGen(
79+
self.op_reg_name, self.f, self.suite_def, inputs.force_io
80+
)
7981

8082
def gen_call_benchmark(self, prepack=False) -> str:
8183
test_str = f"benchmark_{self.op_name}("
@@ -197,7 +199,7 @@ def generate_benchmark_fixture(self) -> str:
197199
float high = 1.0) {{
198200
if (high == 1.0 && low == 0.0)
199201
return at::rand(sizes, at::device(at::kCPU).dtype(dtype));
200-
202+
201203
if (dtype == at::kChar)
202204
return at::randint(high, sizes, at::device(at::kCPU).dtype(dtype));
203205

backends/vulkan/test/op_tests/utils/gen_computegraph.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,17 @@ def vk_out(self):
9090
class ComputeGraphGen:
9191
backend_key = None
9292

93-
def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite):
93+
def __init__(
94+
self,
95+
op_reg_name: str,
96+
f: NativeFunction,
97+
suite_def: TestSuite,
98+
include_io: bool = True,
99+
):
94100
self.op_reg_name = op_reg_name
95101
self.f = f
96102
self.suite_def = suite_def
103+
self.include_io = include_io
97104

98105
self.f_sig = CppSignatureGroup.from_native_function(
99106
self.f, method=False, fallback_binding=self.f.manual_cpp_binding
@@ -275,6 +282,10 @@ def create_value_for( # noqa: C901
275282
prepack = self.prepack_ref(ref)
276283
ref_is_view = self.suite_def.is_view_op and ref.is_out
277284

285+
# If skipping IO, force is_in to be False
286+
if not self.include_io and ref.is_in:
287+
ref.is_in = False
288+
278289
cpp_type = "IOValueRef" if (ref.is_in and not prepack) else "ValueRef"
279290
if not include_declarations:
280291
cpp_type = ""
@@ -602,7 +613,8 @@ def gen_graph_build_code(self, include_declarations: bool = True) -> str:
602613
graph_build += self.create_value_for(self.refs["out"], include_declarations)
603614
graph_build += self.create_op_call()
604615

605-
graph_build += self.set_output(self.refs["out"], include_declarations)
616+
if self.include_io:
617+
graph_build += self.set_output(self.refs["out"], include_declarations)
606618

607619
graph_build += f"{self.graph}{self.dot}prepare();\n"
608620
graph_build += f"{self.graph}{self.dot}encode_prepack();\n"
@@ -614,18 +626,22 @@ def gen_graph_build_code(self, include_declarations: bool = True) -> str:
614626

615627
def gen_graph_exec_code(self, check_output=True) -> str:
616628
graph_exec = ""
617-
for aten_arg in self.args:
618-
ref = self.refs[aten_arg.name]
619-
if ref.is_in:
620-
graph_exec += self.virtual_resize(ref)
621-
graph_exec += self.copy_into_staging(ref)
629+
if self.include_io:
630+
for aten_arg in self.args:
631+
ref = self.refs[aten_arg.name]
632+
if ref.is_in:
633+
graph_exec += self.virtual_resize(ref)
634+
graph_exec += self.copy_into_staging(ref)
635+
636+
graph_exec += f"{self.graph}{self.dot}propagate_resize();\n"
622637

623-
graph_exec += f"{self.graph}{self.dot}propagate_resize();\n"
624638
graph_exec += f"{self.graph}{self.dot}execute();\n"
625639

626640
graph_exec += self.declare_vk_out_for(self.refs["out"])
627-
graph_exec += self.copy_from_staging(self.refs["out"])
628-
if check_output:
641+
if self.include_io:
642+
graph_exec += self.copy_from_staging(self.refs["out"])
643+
644+
if self.include_io and check_output:
629645
graph_exec += self.check_graph_out(self.refs["out"])
630646

631647
graph_exec = re.sub(r"^", " ", graph_exec, flags=re.M)

backends/vulkan/test/op_tests/utils/test_suite.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,4 @@ def __init__(self, input_cases: List[Any]):
4747
self.storage_types: List[str] = ["utils::kTexture3D"]
4848
self.layouts: List[str] = ["utils::kChannelsPacked"]
4949
self.data_gen: str = "make_rand_tensor"
50+
self.force_io: bool = True

0 commit comments

Comments
 (0)