From 167c063261833bbda0772d27674c1fea44ab0908 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Mon, 2 Jun 2025 15:30:49 -0700 Subject: [PATCH 1/2] [ET-VK][ez] Enable dynamic shape support when using push constants Pull Request resolved: https://github.com/pytorch/executorch/pull/11253 ## Changes * Call `encode_execute()` upon resize in `propagate_resize()` * Minor update to `DispatchNode` to store push constant data array as a persistent member of the class ## Motivation Passing in tensor metadata (i.e. sizes, strides) via push constants is typically more performant than passing them via a UBO (uniform buffer object). However, currently dynamic shapes do not work when push constants are used as I realized that the tensor metadata contained in the push constants do not get updated. It appears that that `vkCmdPushConstants` sets the push constants when encoding the command buffer, however the push constants will not be updated if the command buffer is submitted for execution multiple times. Therefore, to update push constant values **the command buffer needs to be re-encoded**. ## Performance Impact This may add a small performance overhead (i.e. re-encoding the command buffer) when executing models with dynamic shapes. Models that do not trigger tensor resizing will not be impacted. However, I measured the impact on a llama 3.2 1B model and the impact of re-encoding a command buffer appears to be negligible. In any case, re-encoding the command buffer is a "necessary evil" when working with dynamic shapes, otherwise the tensor metadata seen by shaders may never get updated. Furthermore, re-encoding the command buffer can allow an opportunity to adjust global work group sizing to match current tensor sizes, which may have a huge performance impact when maximum tensor sizes far exceeds what tensor sizes will realistically be during inference (one instance of this is for transformer models when the max sequence length is very long). ghstack-source-id: 287711101 @exported-using-ghexport Differential Revision: [D75686051](https://our.internmc.facebook.com/intern/diff/D75686051/) --- backends/vulkan/runtime/VulkanBackend.cpp | 7 +++++ .../vulkan/runtime/graph/ComputeGraph.cpp | 4 ++- backends/vulkan/runtime/graph/ComputeGraph.h | 7 ++++- .../vulkan/runtime/graph/ops/DispatchNode.cpp | 26 ++++++++++--------- .../vulkan/runtime/graph/ops/DispatchNode.h | 6 +++++ .../vulkan/runtime/graph/ops/ExecuteNode.h | 2 +- .../vulkan/test/vulkan_compute_api_test.cpp | 3 +-- 7 files changed, 38 insertions(+), 17 deletions(-) diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index b32f4eb4308..02df85c33e8 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -499,6 +499,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { compute_graph->encode_prepack(); compute_graph->prepack(); + // TODO(ssjia): remove this once we can batch compile compute pipelines + // during prepare(). compute_graph->encode_execute(); return Error::Ok; @@ -567,9 +569,14 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { } } + // propagate_resize() will re-encode the command buffer so that push + // constants are updated and DynamicDispatchNode can update the compute + // shader, global workgroup size, and local workgroup size to perform the + // model inference. if (should_propagate_resize) { compute_graph->propagate_resize(); } + compute_graph->execute(); for (size_t i = 0; i < compute_graph->outputs().size(); i++) { diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index a4a6abdd63f..be9eae352ec 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -678,11 +678,12 @@ void ComputeGraph::encode_execute() { } } -void ComputeGraph::execute() const { +void ComputeGraph::execute() { vkapi::VulkanFence fence = context_->fences().get_fence(); context_->submit_cmd_to_gpu(fence.get_submit_handle()); fence.wait(); context_->fences().return_fence(fence); + execute_count_++; } void ComputeGraph::resize_input( @@ -696,6 +697,7 @@ void ComputeGraph::propagate_resize() { for (std::unique_ptr& node : execute_nodes_) { node->trigger_resize(this); } + encode_execute(); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 9f4bab3ac04..9f56941b184 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -187,6 +187,7 @@ class ComputeGraph final { protected: size_t values_in_use_ = 0; + size_t execute_count_ = 0; public: // @@ -745,7 +746,7 @@ class ComputeGraph final { // void encode_execute(); - void execute() const; + void execute(); // // Dynamic Shape support @@ -762,6 +763,10 @@ class ComputeGraph final { return context_->adapter_ptr()->supports_int16_shader_types(); } + inline size_t execute_count() const { + return execute_count_; + } + /* * Check whether the GPU supports 8 bit buffers. */ diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp index 51ff0c122b0..a0d3a4c2e5c 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp @@ -46,15 +46,7 @@ void DispatchNode::encode(ComputeGraph* graph) { std::unique_lock cmd_lock = context->dispatch_lock(); - std::array push_constants_data; - uint32_t push_constants_offset = 0; - - for (const auto& push_constant : push_constants_) { - push_constants_offset += push_constant.write( - push_constants_data.data(), - push_constants_offset, - kMaxPushConstantSize); - } + write_push_constant_data(); context->report_shader_dispatch_start( shader_.kernel_name, @@ -63,7 +55,7 @@ void DispatchNode::encode(ComputeGraph* graph) { node_id_); vkapi::DescriptorSet descriptor_set = context->get_descriptor_set( - shader_, local_workgroup_size_, spec_vars_, push_constants_offset); + shader_, local_workgroup_size_, spec_vars_, push_constants_offset_); uint32_t idx = 0; idx = bind_values_to_descriptor_set( @@ -76,10 +68,20 @@ void DispatchNode::encode(ComputeGraph* graph) { pipeline_barrier, shader_, global_workgroup_size_, - push_constants_data.data(), - push_constants_offset); + push_constants_data_.data(), + push_constants_offset_); context->report_shader_dispatch_end(); } +void DispatchNode::write_push_constant_data() { + push_constants_offset_ = 0; + for (const auto& push_constant : push_constants_) { + push_constants_offset_ += push_constant.write( + push_constants_data_.data(), + push_constants_offset_, + kMaxPushConstantSize); + } +} + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.h b/backends/vulkan/runtime/graph/ops/DispatchNode.h index c45f0a741fd..db95adfee39 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.h +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.h @@ -50,6 +50,12 @@ class DispatchNode : public ExecuteNode { const vkapi::SpecVarList spec_vars_; const std::vector push_constants_; + // For push constants + std::array push_constants_data_{}; + uint32_t push_constants_offset_ = 0; + + void write_push_constant_data(); + public: operator bool() const { return shader_; diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.h b/backends/vulkan/runtime/graph/ops/ExecuteNode.h index 7563fc63c71..0731722e13a 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.h +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.h @@ -65,7 +65,7 @@ class ExecuteNode { (void)graph; } - inline void trigger_resize(ComputeGraph* graph) { + virtual inline void trigger_resize(ComputeGraph* graph) { if (resize_fn_ != nullptr) { resize_fn_(graph, args_, resize_args_); } diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index a6475d95d07..f014cc79f56 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -1660,9 +1660,8 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) { for (auto& new_sizes : new_sizes_list) { graph.get_tensor(a.value)->virtual_resize(new_sizes); graph.get_tensor(b.value)->virtual_resize(new_sizes); - graph.get_tensor(c)->virtual_resize(new_sizes); graph.get_tensor(d.value)->virtual_resize(new_sizes); - graph.get_tensor(e)->virtual_resize(new_sizes); + graph.propagate_resize(); float val_a = new_sizes[1] + 4.0f; float val_b = new_sizes[2] + 1.5f; From f95e6cb5aafcd595dc209865925ce2c333f9612f Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Mon, 2 Jun 2025 15:30:51 -0700 Subject: [PATCH 2/2] [ET-VK][ez] Updates to DynamicDispatchNode Pull Request resolved: https://github.com/pytorch/executorch/pull/11254 ## Changes For `DynamicDispatchNode`: * Pass in global work group size to the local work group size determination function * Add additional constructor for which the shader is not dynamic * During `encode`, check that pick functions are not `nullptr` ## Motivation Oftentimes it is useful to know what the global work group size is when determining what the local group group size should be. ## Performance Impact None. ghstack-source-id: 287711100 @exported-using-ghexport Differential Revision: [D75686047](https://our.internmc.facebook.com/intern/diff/D75686047/) --- .../runtime/graph/ops/DynamicDispatchNode.cpp | 58 ++++++++++++++++--- .../runtime/graph/ops/DynamicDispatchNode.h | 15 +++++ .../vulkan/test/vulkan_compute_api_test.cpp | 13 ++++- 3 files changed, 76 insertions(+), 10 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp index ac84916c6fa..a8d2fe2e99d 100644 --- a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp +++ b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp @@ -25,9 +25,9 @@ DynamicDispatchNode::DynamicDispatchNode( const ResizeFunction& resize_fn) : DispatchNode( graph, - pick_shader_fn(&graph, args, resize_args), - pick_global_wg_fn(&graph, args, resize_args), - pick_local_wg_fn(&graph, args, resize_args), + vkapi::ShaderInfo(), + {1u, 1u, 1u}, + {1u, 1u, 1u}, args, params, push_constants, @@ -36,13 +36,57 @@ DynamicDispatchNode::DynamicDispatchNode( resize_fn), pick_shader_fn_(pick_shader_fn), pick_global_wg_fn_(pick_global_wg_fn), + pick_local_wg_fn_(pick_local_wg_fn) { + shader_ = pick_shader_fn(&graph, args, resize_args); + global_workgroup_size_ = + pick_global_wg_fn(&graph, shader_, args, resize_args); + local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn( + &graph, shader_, global_workgroup_size_, args, resize_args)); +} + +DynamicDispatchNode::DynamicDispatchNode( + ComputeGraph& graph, + const vkapi::ShaderInfo& shader, + const PickGlobalFn& pick_global_wg_fn, + const PickLocalFn& pick_local_wg_fn, + const std::vector& args, + const vkapi::ParamsBindList& params, + const std::vector& push_constants, + const vkapi::SpecVarList& spec_vars, + const std::vector& resize_args, + const ResizeFunction& resize_fn) + : DispatchNode( + graph, + shader, + pick_global_wg_fn(&graph, shader, args, resize_args), + pick_local_wg_fn( + &graph, + shader, + pick_global_wg_fn(&graph, shader, args, resize_args), + args, + resize_args), + args, + params, + push_constants, + spec_vars, + resize_args, + resize_fn), + pick_shader_fn_{nullptr}, + pick_global_wg_fn_(pick_global_wg_fn), pick_local_wg_fn_(pick_local_wg_fn) {} void DynamicDispatchNode::encode(ComputeGraph* graph) { - shader_ = pick_shader_fn_(graph, args_, resize_args_); - global_workgroup_size_ = pick_global_wg_fn_(graph, args_, resize_args_); - local_workgroup_size_ = - utils::WorkgroupSize(pick_local_wg_fn_(graph, args_, resize_args_)); + if (pick_shader_fn_) { + shader_ = pick_shader_fn_(graph, args_, resize_args_); + } + if (pick_global_wg_fn_) { + global_workgroup_size_ = + pick_global_wg_fn_(graph, shader_, args_, resize_args_); + } + if (pick_local_wg_fn_) { + local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn_( + graph, shader_, global_workgroup_size_, args_, resize_args_)); + } DispatchNode::encode(graph); } diff --git a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h index ede50941415..005151272c3 100644 --- a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h +++ b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h @@ -32,10 +32,13 @@ class DynamicDispatchNode final : public DispatchNode { const std::vector&)>; using PickGlobalFn = const std::function&, const std::vector&)>; using PickLocalFn = const std::function&, const std::vector&)>; @@ -51,6 +54,18 @@ class DynamicDispatchNode final : public DispatchNode { const std::vector& resize_args, const ResizeFunction& resize_fn = nullptr); + explicit DynamicDispatchNode( + ComputeGraph& graph, + const vkapi::ShaderInfo& shader, + const PickGlobalFn& pick_global_wg_fn, + const PickLocalFn& pick_local_wg_fn, + const std::vector& args, + const vkapi::ParamsBindList& params, + const std::vector& push_constants, + const vkapi::SpecVarList& spec_vars, + const std::vector& resize_args, + const ResizeFunction& resize_fn = nullptr); + ~DynamicDispatchNode() override = default; void encode(ComputeGraph* graph) override; diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index f014cc79f56..60dfb3b8606 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -3314,17 +3315,23 @@ vkapi::ShaderInfo pick_dynamic_dispatch_shader( utils::uvec3 pick_dynamic_dispatch_global_wg_size( ComputeGraph* graph, + const vkapi::ShaderInfo& shader, const std::vector& args, - const std::vector& additional_args) { + const std::vector& resize_args) { + (void)shader; const ValueRef out = args[0].refs[0]; - return graph->logical_limits_of(out); } utils::uvec3 pick_dynamic_dispatch_local_wg_size( ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, const std::vector& args, - const std::vector& additional_args) { + const std::vector& resize_args) { + (void)graph; + (void)shader; + (void)global_workgroup_size; return {64, 1, 1}; }