From 26be4004cbed43735c122c01fc069e98477dd73d Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Fri, 10 Jan 2025 13:05:06 -0800 Subject: [PATCH 1/2] [ET-VK][ez] Misc fixes related to extension support checking ## Context Follow up from https://github.com/pytorch/executorch/pull/7576. Apply two "fixes" that were missed in the first diff: 1. Check device capability for Prepacking nodes as well 2. Remove conditional skips during generated operator correctness tests; rely on the device capability check to determine if a skip is needed. Differential Revision: [D68035430](https://our.internmc.facebook.com/intern/diff/D68035430/) [ghstack-poisoned] --- .../vulkan/runtime/api/containers/Tensor.cpp | 7 ------ .../vulkan/runtime/graph/ops/PrepackNode.cpp | 2 ++ .../test/op_tests/utils/gen_computegraph.py | 25 ------------------- 3 files changed, 2 insertions(+), 32 deletions(-) diff --git a/backends/vulkan/runtime/api/containers/Tensor.cpp b/backends/vulkan/runtime/api/containers/Tensor.cpp index 900854ccd75..8c76c11532b 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.cpp +++ b/backends/vulkan/runtime/api/containers/Tensor.cpp @@ -478,13 +478,6 @@ vTensor::vTensor( if (storage_type != utils::kBuffer) { set_logical_limits(storage_.image_extents_); } - - if (dtype == vkapi::kHalf) { - VK_CHECK_COND( - api::context()->adapter_ptr()->supports_16bit_storage_buffers(), - "Half dtype is only available if the physical device supports float16 " - "storage buffers!"); - } } // NOLINTNEXTLINE diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp index e27723468ab..bf501296b1b 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp @@ -68,6 +68,8 @@ api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) { void PrepackNode::encode(ComputeGraph* graph) { api::Context* const context = graph->context(); + context->check_device_capabilities(shader_); + vTensorPtr packed = graph->get_tensor(packed_); api::StagingBuffer staging = create_staging_buffer(graph); diff --git a/backends/vulkan/test/op_tests/utils/gen_computegraph.py b/backends/vulkan/test/op_tests/utils/gen_computegraph.py index 472127ffe2e..2fd1265c217 100644 --- a/backends/vulkan/test/op_tests/utils/gen_computegraph.py +++ b/backends/vulkan/test/op_tests/utils/gen_computegraph.py @@ -633,30 +633,6 @@ def gen_graph_exec_code(self, check_output=True) -> str: return graph_exec - def gen_conditional_skips(self, skip_str: str = "GTEST_SKIP();") -> str: - fp16_skip = f"if (!{self.graph}{self.dot}context()->adapter_ptr()->has_full_float16_buffers_support()) {{\n" - fp16_skip += f" {skip_str}\n" - fp16_skip += "}" - fp16_skip = re.sub(r"^", " ", fp16_skip, flags=re.M) + "\n" - - int8_skip = f"if (!{self.graph}{self.dot}context()->adapter_ptr()->has_full_int8_buffers_support()) {{\n" - int8_skip += f" {skip_str};\n" - int8_skip += "}\n" - - skips = "" - - skips += "if (test_dtype == at::kHalf) {\n" - skips += fp16_skip - skips += "}\n" - - for _, dtype in self.suite_def.arg_dtype.items(): - if dtype == "at::kChar" or dtype == "at::kQInt8": - skips += int8_skip - continue - - skips += "\n" - return skips - def gen_op_check_fn(self) -> str: op_name = self.f.func.name.unambiguous_name() if self.suite_def.test_name_suffix is not None: @@ -667,7 +643,6 @@ def gen_op_check_fn(self) -> str: op_check_fn = self.gen_decl(f"prepacked_check_{op_name}") + " {\n" op_check_fn_body = "" - op_check_fn_body += self.gen_conditional_skips() op_check_fn_body += self.gen_graph_build_code() op_check_fn_body += self.gen_graph_exec_code() From f5dc4bb6c5d050eb8bc34af2f29d81257c430109 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Fri, 10 Jan 2025 13:48:56 -0800 Subject: [PATCH 2/2] Update on "[ET-VK][ez] Misc fixes related to extension support checking" ## Context Follow up from https://github.com/pytorch/executorch/pull/7576. Apply two "fixes" that were missed in the first diff: 1. Check device capability for Prepacking nodes as well 2. Remove conditional skips during generated operator correctness tests; rely on the device capability check to determine if a skip is needed. Differential Revision: [D68035430](https://our.internmc.facebook.com/intern/diff/D68035430/) [ghstack-poisoned] --- .../test/op_tests/utils/gen_computegraph.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/backends/vulkan/test/op_tests/utils/gen_computegraph.py b/backends/vulkan/test/op_tests/utils/gen_computegraph.py index 2fd1265c217..6f93e662076 100644 --- a/backends/vulkan/test/op_tests/utils/gen_computegraph.py +++ b/backends/vulkan/test/op_tests/utils/gen_computegraph.py @@ -633,6 +633,30 @@ def gen_graph_exec_code(self, check_output=True) -> str: return graph_exec + def gen_conditional_skips(self, skip_str: str = "GTEST_SKIP();") -> str: + fp16_skip = f"if (!{self.graph}{self.dot}context()->adapter_ptr()->has_full_float16_buffers_support()) {{\n" + fp16_skip += f" {skip_str}\n" + fp16_skip += "}" + fp16_skip = re.sub(r"^", " ", fp16_skip, flags=re.M) + "\n" + + int8_skip = f"if (!{self.graph}{self.dot}context()->adapter_ptr()->has_full_int8_buffers_support()) {{\n" + int8_skip += f" {skip_str};\n" + int8_skip += "}\n" + + skips = "" + + skips += "if (test_dtype == at::kHalf) {\n" + skips += fp16_skip + skips += "}\n" + + for _, dtype in self.suite_def.arg_dtype.items(): + if dtype == "at::kChar" or dtype == "at::kQInt8": + skips += int8_skip + continue + + skips += "\n" + return skips + def gen_op_check_fn(self) -> str: op_name = self.f.func.name.unambiguous_name() if self.suite_def.test_name_suffix is not None: