Skip to content

Commit d37995f

Browse files
committed
[ET-VK][ez] Misc fixes related to extension support checking
Pull Request resolved: #7601 ## Context Follow up from #7576. Apply two "fixes" that were missed in the first diff: 1. Check device capability for Prepacking nodes as well 2. Remove conditional skips during generated operator correctness tests; rely on the device capability check to determine if a skip is needed. Differential Revision: [D68035430](https://our.internmc.facebook.com/intern/diff/D68035430/) ghstack-source-id: 260961960
1 parent c62f78c commit d37995f

File tree

3 files changed

+2
-8
lines changed

3 files changed

+2
-8
lines changed

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -478,13 +478,6 @@ vTensor::vTensor(
478478
if (storage_type != utils::kBuffer) {
479479
set_logical_limits(storage_.image_extents_);
480480
}
481-
482-
if (dtype == vkapi::kHalf) {
483-
VK_CHECK_COND(
484-
api::context()->adapter_ptr()->supports_16bit_storage_buffers(),
485-
"Half dtype is only available if the physical device supports float16 "
486-
"storage buffers!");
487-
}
488481
}
489482

490483
// NOLINTNEXTLINE

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
6868
void PrepackNode::encode(ComputeGraph* graph) {
6969
api::Context* const context = graph->context();
7070

71+
context->check_device_capabilities(shader_);
72+
7173
vTensorPtr packed = graph->get_tensor(packed_);
7274
api::StagingBuffer staging = create_staging_buffer(graph);
7375

backends/vulkan/test/op_tests/utils/gen_computegraph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -667,7 +667,6 @@ def gen_op_check_fn(self) -> str:
667667
op_check_fn = self.gen_decl(f"prepacked_check_{op_name}") + " {\n"
668668

669669
op_check_fn_body = ""
670-
op_check_fn_body += self.gen_conditional_skips()
671670
op_check_fn_body += self.gen_graph_build_code()
672671
op_check_fn_body += self.gen_graph_exec_code()
673672

0 commit comments

Comments
 (0)