diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index b32f4eb4308..02df85c33e8 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -499,6 +499,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { compute_graph->encode_prepack(); compute_graph->prepack(); + // TODO(ssjia): remove this once we can batch compile compute pipelines + // during prepare(). compute_graph->encode_execute(); return Error::Ok; @@ -567,9 +569,14 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { } } + // propagate_resize() will re-encode the command buffer so that push + // constants are updated and DynamicDispatchNode can update the compute + // shader, global workgroup size, and local workgroup size to perform the + // model inference. if (should_propagate_resize) { compute_graph->propagate_resize(); } + compute_graph->execute(); for (size_t i = 0; i < compute_graph->outputs().size(); i++) { diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index a4a6abdd63f..be9eae352ec 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -678,11 +678,12 @@ void ComputeGraph::encode_execute() { } } -void ComputeGraph::execute() const { +void ComputeGraph::execute() { vkapi::VulkanFence fence = context_->fences().get_fence(); context_->submit_cmd_to_gpu(fence.get_submit_handle()); fence.wait(); context_->fences().return_fence(fence); + execute_count_++; } void ComputeGraph::resize_input( @@ -696,6 +697,7 @@ void ComputeGraph::propagate_resize() { for (std::unique_ptr& node : execute_nodes_) { node->trigger_resize(this); } + encode_execute(); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 9f4bab3ac04..9f56941b184 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -187,6 +187,7 @@ class ComputeGraph final { protected: size_t values_in_use_ = 0; + size_t execute_count_ = 0; public: // @@ -745,7 +746,7 @@ class ComputeGraph final { // void encode_execute(); - void execute() const; + void execute(); // // Dynamic Shape support @@ -762,6 +763,10 @@ class ComputeGraph final { return context_->adapter_ptr()->supports_int16_shader_types(); } + inline size_t execute_count() const { + return execute_count_; + } + /* * Check whether the GPU supports 8 bit buffers. */ diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp index 51ff0c122b0..a0d3a4c2e5c 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp @@ -46,15 +46,7 @@ void DispatchNode::encode(ComputeGraph* graph) { std::unique_lock cmd_lock = context->dispatch_lock(); - std::array push_constants_data; - uint32_t push_constants_offset = 0; - - for (const auto& push_constant : push_constants_) { - push_constants_offset += push_constant.write( - push_constants_data.data(), - push_constants_offset, - kMaxPushConstantSize); - } + write_push_constant_data(); context->report_shader_dispatch_start( shader_.kernel_name, @@ -63,7 +55,7 @@ void DispatchNode::encode(ComputeGraph* graph) { node_id_); vkapi::DescriptorSet descriptor_set = context->get_descriptor_set( - shader_, local_workgroup_size_, spec_vars_, push_constants_offset); + shader_, local_workgroup_size_, spec_vars_, push_constants_offset_); uint32_t idx = 0; idx = bind_values_to_descriptor_set( @@ -76,10 +68,20 @@ void DispatchNode::encode(ComputeGraph* graph) { pipeline_barrier, shader_, global_workgroup_size_, - push_constants_data.data(), - push_constants_offset); + push_constants_data_.data(), + push_constants_offset_); context->report_shader_dispatch_end(); } +void DispatchNode::write_push_constant_data() { + push_constants_offset_ = 0; + for (const auto& push_constant : push_constants_) { + push_constants_offset_ += push_constant.write( + push_constants_data_.data(), + push_constants_offset_, + kMaxPushConstantSize); + } +} + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.h b/backends/vulkan/runtime/graph/ops/DispatchNode.h index c45f0a741fd..db95adfee39 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.h +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.h @@ -50,6 +50,12 @@ class DispatchNode : public ExecuteNode { const vkapi::SpecVarList spec_vars_; const std::vector push_constants_; + // For push constants + std::array push_constants_data_{}; + uint32_t push_constants_offset_ = 0; + + void write_push_constant_data(); + public: operator bool() const { return shader_; diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.h b/backends/vulkan/runtime/graph/ops/ExecuteNode.h index 7563fc63c71..0731722e13a 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.h +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.h @@ -65,7 +65,7 @@ class ExecuteNode { (void)graph; } - inline void trigger_resize(ComputeGraph* graph) { + virtual inline void trigger_resize(ComputeGraph* graph) { if (resize_fn_ != nullptr) { resize_fn_(graph, args_, resize_args_); } diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index a6475d95d07..f014cc79f56 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -1660,9 +1660,8 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) { for (auto& new_sizes : new_sizes_list) { graph.get_tensor(a.value)->virtual_resize(new_sizes); graph.get_tensor(b.value)->virtual_resize(new_sizes); - graph.get_tensor(c)->virtual_resize(new_sizes); graph.get_tensor(d.value)->virtual_resize(new_sizes); - graph.get_tensor(e)->virtual_resize(new_sizes); + graph.propagate_resize(); float val_a = new_sizes[1] + 4.0f; float val_b = new_sizes[2] + 1.5f;