
Commit 26be400

[ET-VK][ez] Misc fixes related to extension support checking
## Context

Follow-up from #7576. Apply two "fixes" that were missed in the first diff:

1. Check device capability for Prepacking nodes as well.
2. Remove conditional skips during generated operator correctness tests; rely on the device capability check to determine whether a skip is needed.

Differential Revision: [D68035430](https://our.internmc.facebook.com/intern/diff/D68035430/)

[ghstack-poisoned]
Parent: be8d304

File tree: 3 files changed, +2 −32 lines


backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 0 additions & 7 deletions
```diff
@@ -478,13 +478,6 @@ vTensor::vTensor(
   if (storage_type != utils::kBuffer) {
     set_logical_limits(storage_.image_extents_);
   }
-
-  if (dtype == vkapi::kHalf) {
-    VK_CHECK_COND(
-        api::context()->adapter_ptr()->supports_16bit_storage_buffers(),
-        "Half dtype is only available if the physical device supports float16 "
-        "storage buffers!");
-  }
 }
 
 // NOLINTNEXTLINE
```
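The removed guard delegated to the adapter's `supports_16bit_storage_buffers()`, which presumably reflects Vulkan's 16-bit storage feature bit. For reference, a minimal self-contained sketch of querying that capability with the plain Vulkan API (this is standard Vulkan, not ExecuTorch code):

```cpp
#include <vulkan/vulkan.h>

// Sketch: query whether a physical device allows float16/int16 members in
// storage buffers (SSBOs) -- the capability the removed VK_CHECK_COND above
// guarded on, via the adapter wrapper.
bool supports_16bit_storage(VkPhysicalDevice physical_device) {
  VkPhysicalDevice16BitStorageFeatures storage_16bit{};
  storage_16bit.sType =
      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES;

  VkPhysicalDeviceFeatures2 features2{};
  features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
  features2.pNext = &storage_16bit;

  // Requires Vulkan 1.1, or the VK_KHR_get_physical_device_properties2
  // extension on Vulkan 1.0.
  vkGetPhysicalDeviceFeatures2(physical_device, &features2);
  return storage_16bit.storageBuffer16BitAccess == VK_TRUE;
}
```

On devices where this returns false, dispatching a float16 storage-buffer shader would be invalid, which is why the check now happens per-shader at encode time rather than per-tensor at construction time.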

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 2 additions & 0 deletions
```diff
@@ -68,6 +68,8 @@ api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
 void PrepackNode::encode(ComputeGraph* graph) {
   api::Context* const context = graph->context();
 
+  context->check_device_capabilities(shader_);
+
   vTensorPtr packed = graph->get_tensor(packed_);
   api::StagingBuffer staging = create_staging_buffer(graph);
```
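This mirrors the capability check that #7576 introduced for execute nodes, so prepack shaders are now validated the same way before dispatch. The real `check_device_capabilities` implementation is not part of this diff; the following is a hypothetical sketch of the pattern, with all types and fields invented for illustration:

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-ins for the real ExecuTorch Vulkan runtime types.
struct ShaderInfo {
  std::string name;
  bool requires_16bit_storage = false;
  bool requires_8bit_storage = false;
};

struct DeviceCaps {
  bool has_16bit_storage = false;
  bool has_8bit_storage = false;
};

// Centralized capability check: every node (prepack or execute) validates
// its shader's declared requirements against the device before dispatch,
// instead of each call site hard-coding per-dtype checks.
void check_device_capabilities(const DeviceCaps& caps, const ShaderInfo& shader) {
  if (shader.requires_16bit_storage && !caps.has_16bit_storage) {
    throw std::runtime_error(
        shader.name + " requires 16-bit storage buffers, which this device lacks");
  }
  if (shader.requires_8bit_storage && !caps.has_8bit_storage) {
    throw std::runtime_error(
        shader.name + " requires 8-bit storage buffers, which this device lacks");
  }
}
```

The design point is that the shader, not the tensor, declares what it needs, so one check at encode time covers every dtype a dispatch touches.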

backends/vulkan/test/op_tests/utils/gen_computegraph.py

Lines changed: 0 additions & 25 deletions
```diff
@@ -633,30 +633,6 @@ def gen_graph_exec_code(self, check_output=True) -> str:
 
         return graph_exec
 
-    def gen_conditional_skips(self, skip_str: str = "GTEST_SKIP();") -> str:
-        fp16_skip = f"if (!{self.graph}{self.dot}context()->adapter_ptr()->has_full_float16_buffers_support()) {{\n"
-        fp16_skip += f" {skip_str}\n"
-        fp16_skip += "}"
-        fp16_skip = re.sub(r"^", " ", fp16_skip, flags=re.M) + "\n"
-
-        int8_skip = f"if (!{self.graph}{self.dot}context()->adapter_ptr()->has_full_int8_buffers_support()) {{\n"
-        int8_skip += f" {skip_str};\n"
-        int8_skip += "}\n"
-
-        skips = ""
-
-        skips += "if (test_dtype == at::kHalf) {\n"
-        skips += fp16_skip
-        skips += "}\n"
-
-        for _, dtype in self.suite_def.arg_dtype.items():
-            if dtype == "at::kChar" or dtype == "at::kQInt8":
-                skips += int8_skip
-                continue
-
-        skips += "\n"
-        return skips
-
     def gen_op_check_fn(self) -> str:
         op_name = self.f.func.name.unambiguous_name()
         if self.suite_def.test_name_suffix is not None:
@@ -667,7 +643,6 @@ def gen_op_check_fn(self) -> str:
         op_check_fn = self.gen_decl(f"prepacked_check_{op_name}") + " {\n"
 
         op_check_fn_body = ""
-        op_check_fn_body += self.gen_conditional_skips()
         op_check_fn_body += self.gen_graph_build_code()
         op_check_fn_body += self.gen_graph_exec_code()
 
```
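For context, the removed `gen_conditional_skips` helper spliced C++ guards into each generated GTest case. Reconstructed from the string templates above (assuming the generator's `graph` variable is a pointer, so `{self.dot}` renders as `->`; this is a reconstruction, not captured generator output), the emitted guards looked roughly like:

```cpp
// Guard emitted when the test suite used half-precision inputs:
if (test_dtype == at::kHalf) {
  if (!graph->context()->adapter_ptr()->has_full_float16_buffers_support()) {
    GTEST_SKIP();
  }
}
// ...and, for each at::kChar / at::kQInt8 argument in the suite:
if (!graph->context()->adapter_ptr()->has_full_int8_buffers_support()) {
  GTEST_SKIP();
}
```

With this commit those guards disappear from generated tests; per the commit message, the `check_device_capabilities` call made while encoding each node becomes the single source of truth for whether a skip is needed.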