// Patch overview (leading text before the first "diff --git" header; not part of any hunk).
//
// Files touched: VulkanBackend.cpp, ComputeGraph.cpp/.h, DispatchNode.cpp/.h,
// DynamicDispatchNode.cpp/.h, ExecuteNode.h, vulkan_compute_api_test.cpp.
//
// What the hunks do:
//  * ComputeGraph::execute() drops its const qualifier and increments a new
//    execute_count_ member after the fence is returned; execute_count() exposes it.
//  * ComputeGraph::propagate_resize() now calls encode_execute() after calling
//    trigger_resize() on every execute node.
//  * DispatchNode keeps the push constant buffer and byte offset as members
//    (push_constants_data_, push_constants_offset_) and fills them through the new
//    write_push_constant_data() helper instead of locals inside encode().
//  * DynamicDispatchNode passes the selected shader to the global workgroup picker,
//    and the shader plus the global workgroup size to the local workgroup picker.
//    A second constructor accepts a fixed shader and stores a null pick_shader_fn_;
//    encode() only invokes a picker when it is non-null.
//  * ExecuteNode::trigger_resize() becomes virtual.
//  * The shared-objects resize test stops resizing tensors c and e by hand and calls
//    graph.propagate_resize() instead; the test picker helpers take the new parameters.
//
// NOTE(review): this copy has its line breaks collapsed, and template argument lists
//   and #include targets appear to be missing (e.g. "std::vector&", bare "#include");
//   presumably extraction damage - recover them from the original patch before applying.
// NOTE(review): the fixed-shader DynamicDispatchNode constructor calls pick_global_wg_fn
//   twice in its initializer list (once for the global size, once as the argument to
//   pick_local_wg_fn). The other constructor passes placeholder sizes and computes them
//   once in its body; consider doing the same here - confirm the picker is pure/cheap.
// NOTE(review): propagate_resize() now re-encodes on every call; whether encode_execute()
//   clears previously recorded commands is not visible in this patch - confirm.
// NOTE(review): pick_dynamic_dispatch_local_wg_size casts graph, shader and
//   global_workgroup_size to void but not args/resize_args; may warn with
//   -Wunused-parameter - confirm against the build flags.
diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index b32f4eb4308..02df85c33e8 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -499,6 +499,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { compute_graph->encode_prepack(); compute_graph->prepack(); + // TODO(ssjia): remove this once we can batch compile compute pipelines + // during prepare(). compute_graph->encode_execute(); return Error::Ok; @@ -567,9 +569,14 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { } } + // propagate_resize() will re-encode the command buffer so that push + // constants are updated and DynamicDispatchNode can update the compute + // shader, global workgroup size, and local workgroup size to perform the + // model inference. if (should_propagate_resize) { compute_graph->propagate_resize(); } + compute_graph->execute(); for (size_t i = 0; i < compute_graph->outputs().size(); i++) { diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index a4a6abdd63f..be9eae352ec 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -678,11 +678,12 @@ void ComputeGraph::encode_execute() { } } -void ComputeGraph::execute() const { +void ComputeGraph::execute() { vkapi::VulkanFence fence = context_->fences().get_fence(); context_->submit_cmd_to_gpu(fence.get_submit_handle()); fence.wait(); context_->fences().return_fence(fence); + execute_count_++; } void ComputeGraph::resize_input( @@ -696,6 +697,7 @@ void ComputeGraph::propagate_resize() { for (std::unique_ptr& node : execute_nodes_) { node->trigger_resize(this); } + encode_execute(); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 9f4bab3ac04..9f56941b184 100644 --- 
a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -187,6 +187,7 @@ class ComputeGraph final { protected: size_t values_in_use_ = 0; + size_t execute_count_ = 0; public: // @@ -745,7 +746,7 @@ class ComputeGraph final { // void encode_execute(); - void execute() const; + void execute(); // // Dynamic Shape support @@ -762,6 +763,10 @@ class ComputeGraph final { return context_->adapter_ptr()->supports_int16_shader_types(); } + inline size_t execute_count() const { + return execute_count_; + } + /* * Check whether the GPU supports 8 bit buffers. */ diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp index 51ff0c122b0..a0d3a4c2e5c 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp @@ -46,15 +46,7 @@ void DispatchNode::encode(ComputeGraph* graph) { std::unique_lock cmd_lock = context->dispatch_lock(); - std::array push_constants_data; - uint32_t push_constants_offset = 0; - - for (const auto& push_constant : push_constants_) { - push_constants_offset += push_constant.write( - push_constants_data.data(), - push_constants_offset, - kMaxPushConstantSize); - } + write_push_constant_data(); context->report_shader_dispatch_start( shader_.kernel_name, @@ -63,7 +55,7 @@ void DispatchNode::encode(ComputeGraph* graph) { node_id_); vkapi::DescriptorSet descriptor_set = context->get_descriptor_set( - shader_, local_workgroup_size_, spec_vars_, push_constants_offset); + shader_, local_workgroup_size_, spec_vars_, push_constants_offset_); uint32_t idx = 0; idx = bind_values_to_descriptor_set( @@ -76,10 +68,20 @@ void DispatchNode::encode(ComputeGraph* graph) { pipeline_barrier, shader_, global_workgroup_size_, - push_constants_data.data(), - push_constants_offset); + push_constants_data_.data(), + push_constants_offset_); context->report_shader_dispatch_end(); } +void 
DispatchNode::write_push_constant_data() { + push_constants_offset_ = 0; + for (const auto& push_constant : push_constants_) { + push_constants_offset_ += push_constant.write( + push_constants_data_.data(), + push_constants_offset_, + kMaxPushConstantSize); + } +} + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.h b/backends/vulkan/runtime/graph/ops/DispatchNode.h index c45f0a741fd..db95adfee39 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.h +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.h @@ -50,6 +50,12 @@ class DispatchNode : public ExecuteNode { const vkapi::SpecVarList spec_vars_; const std::vector push_constants_; + // For push constants + std::array push_constants_data_{}; + uint32_t push_constants_offset_ = 0; + + void write_push_constant_data(); + public: operator bool() const { return shader_; diff --git a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp index ac84916c6fa..a8d2fe2e99d 100644 --- a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp +++ b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp @@ -25,9 +25,9 @@ DynamicDispatchNode::DynamicDispatchNode( const ResizeFunction& resize_fn) : DispatchNode( graph, - pick_shader_fn(&graph, args, resize_args), - pick_global_wg_fn(&graph, args, resize_args), - pick_local_wg_fn(&graph, args, resize_args), + vkapi::ShaderInfo(), + {1u, 1u, 1u}, + {1u, 1u, 1u}, args, params, push_constants, @@ -36,13 +36,57 @@ DynamicDispatchNode::DynamicDispatchNode( resize_fn), pick_shader_fn_(pick_shader_fn), pick_global_wg_fn_(pick_global_wg_fn), + pick_local_wg_fn_(pick_local_wg_fn) { + shader_ = pick_shader_fn(&graph, args, resize_args); + global_workgroup_size_ = + pick_global_wg_fn(&graph, shader_, args, resize_args); + local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn( + &graph, shader_, global_workgroup_size_, args, resize_args)); +} + 
+DynamicDispatchNode::DynamicDispatchNode( + ComputeGraph& graph, + const vkapi::ShaderInfo& shader, + const PickGlobalFn& pick_global_wg_fn, + const PickLocalFn& pick_local_wg_fn, + const std::vector& args, + const vkapi::ParamsBindList& params, + const std::vector& push_constants, + const vkapi::SpecVarList& spec_vars, + const std::vector& resize_args, + const ResizeFunction& resize_fn) + : DispatchNode( + graph, + shader, + pick_global_wg_fn(&graph, shader, args, resize_args), + pick_local_wg_fn( + &graph, + shader, + pick_global_wg_fn(&graph, shader, args, resize_args), + args, + resize_args), + args, + params, + push_constants, + spec_vars, + resize_args, + resize_fn), + pick_shader_fn_{nullptr}, + pick_global_wg_fn_(pick_global_wg_fn), pick_local_wg_fn_(pick_local_wg_fn) {} void DynamicDispatchNode::encode(ComputeGraph* graph) { - shader_ = pick_shader_fn_(graph, args_, resize_args_); - global_workgroup_size_ = pick_global_wg_fn_(graph, args_, resize_args_); - local_workgroup_size_ = - utils::WorkgroupSize(pick_local_wg_fn_(graph, args_, resize_args_)); + if (pick_shader_fn_) { + shader_ = pick_shader_fn_(graph, args_, resize_args_); + } + if (pick_global_wg_fn_) { + global_workgroup_size_ = + pick_global_wg_fn_(graph, shader_, args_, resize_args_); + } + if (pick_local_wg_fn_) { + local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn_( + graph, shader_, global_workgroup_size_, args_, resize_args_)); + } DispatchNode::encode(graph); } diff --git a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h index ede50941415..005151272c3 100644 --- a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h +++ b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h @@ -32,10 +32,13 @@ class DynamicDispatchNode final : public DispatchNode { const std::vector&)>; using PickGlobalFn = const std::function&, const std::vector&)>; using PickLocalFn = const std::function&, const std::vector&)>; @@ -51,6 
+54,18 @@ class DynamicDispatchNode final : public DispatchNode { const std::vector& resize_args, const ResizeFunction& resize_fn = nullptr); + explicit DynamicDispatchNode( + ComputeGraph& graph, + const vkapi::ShaderInfo& shader, + const PickGlobalFn& pick_global_wg_fn, + const PickLocalFn& pick_local_wg_fn, + const std::vector& args, + const vkapi::ParamsBindList& params, + const std::vector& push_constants, + const vkapi::SpecVarList& spec_vars, + const std::vector& resize_args, + const ResizeFunction& resize_fn = nullptr); + ~DynamicDispatchNode() override = default; void encode(ComputeGraph* graph) override; diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.h b/backends/vulkan/runtime/graph/ops/ExecuteNode.h index 7563fc63c71..0731722e13a 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.h +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.h @@ -65,7 +65,7 @@ class ExecuteNode { (void)graph; } - inline void trigger_resize(ComputeGraph* graph) { + virtual inline void trigger_resize(ComputeGraph* graph) { if (resize_fn_ != nullptr) { resize_fn_(graph, args_, resize_args_); } diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index a6475d95d07..60dfb3b8606 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -1660,9 +1661,8 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) { for (auto& new_sizes : new_sizes_list) { graph.get_tensor(a.value)->virtual_resize(new_sizes); graph.get_tensor(b.value)->virtual_resize(new_sizes); - graph.get_tensor(c)->virtual_resize(new_sizes); graph.get_tensor(d.value)->virtual_resize(new_sizes); - graph.get_tensor(e)->virtual_resize(new_sizes); + graph.propagate_resize(); float val_a = new_sizes[1] + 4.0f; float val_b = new_sizes[2] + 1.5f; @@ -3315,17 +3315,23 @@ vkapi::ShaderInfo 
pick_dynamic_dispatch_shader( utils::uvec3 pick_dynamic_dispatch_global_wg_size( ComputeGraph* graph, + const vkapi::ShaderInfo& shader, const std::vector& args, - const std::vector& additional_args) { + const std::vector& resize_args) { + (void)shader; const ValueRef out = args[0].refs[0]; - return graph->logical_limits_of(out); } utils::uvec3 pick_dynamic_dispatch_local_wg_size( ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, const std::vector& args, - const std::vector& additional_args) { + const std::vector& resize_args) { + (void)graph; + (void)shader; + (void)global_workgroup_size; return {64, 1, 1}; }