From 4a796683cd43c97c226157ac45ef6b943763541d Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Fri, 9 May 2025 08:47:45 -0700 Subject: [PATCH] [ET-VK] Moving device capabilities check to DispatchNode and PrepackNode ctor. The changes in this diff move the device capabilities check from the encode method to the constructor of DispatchNode and PrepackNode classes. Differential Revision: [D74481839](https://our.internmc.facebook.com/intern/diff/D74481839/) [ghstack-poisoned] --- backends/vulkan/runtime/graph/ops/DispatchNode.cpp | 3 +-- backends/vulkan/runtime/graph/ops/PrepackNode.cpp | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp index 51ff0c122b0..30b68859cff 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp @@ -33,6 +33,7 @@ DispatchNode::DispatchNode( spec_vars_(spec_vars), push_constants_(push_constants) { graph.update_descriptor_counts(shader, /*execute = */ true); + graph.context()->check_device_capabilities(shader_); } void DispatchNode::encode(ComputeGraph* graph) { @@ -42,8 +43,6 @@ void DispatchNode::encode(ComputeGraph* graph) { api::Context* const context = graph->context(); vkapi::PipelineBarrier pipeline_barrier{}; - context->check_device_capabilities(shader_); - std::unique_lock cmd_lock = context->dispatch_lock(); std::array push_constants_data; diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp index d84d893540c..96262f0b3e0 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp @@ -45,6 +45,7 @@ PrepackNode::PrepackNode( push_constants_(push_constants) { graph.update_descriptor_counts(shader, /*execute = */ false); graph.update_descriptor_counts(noop_shader_, /*execute = */ false); + graph.context()->check_device_capabilities(shader_); } api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) { @@ -70,8 +71,6 @@ api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) { void PrepackNode::encode(ComputeGraph* graph) { api::Context* const context = graph->context(); - context->check_device_capabilities(shader_); - vTensorPtr packed = graph->get_tensor(packed_); api::StagingBuffer staging = create_staging_buffer(graph);