From 97c967308cae95b314abb65fa942b2de80cdcd99 Mon Sep 17 00:00:00 2001
From: ssjia
Date: Wed, 13 Aug 2025 06:53:46 -0700
Subject: [PATCH 1/2] [ET-VK] Better work group sizes for matmul

Pull Request resolved: https://github.com/pytorch/executorch/pull/13185

## Context

Currently `default_pick_local_wg_size()` (which internally calls `ComputeGraph::create_local_wg_size`) is used to select the local work group size for matrix multiplication ops. However, these functions bias the size of the local work group towards the largest dim of the global work group, producing local wg sizes like

```
shader                                                  global wg size  local wg size
======================================================  ==============  =============  ====
linear_qga4w_tiled_texture3d_texture3d_texture2d_float  {256, 29, 1}    {32, 2, 1}     1487
matmul_naive_texture3d_float                            {29, 115, 32}   {4, 2, 8}       712
```

for matrix multiplication shaders. This behaviour was introduced in D64418632 / https://github.com/pytorch/executorch/pull/6409.

However, experimental testing shows that a "square" work group size of `{8, 8, 1}` works much better for matrix multiplication shaders.

The theoretical explanation for this behaviour is that the local work group size determines the memory locations that need to be loaded to compute the overall work group. For a work group with size `{W, H, 1}`, the data required to compute the output would be `W * OUTPUT_TILE_W` columns of the weight tensor and `H * OUTPUT_TILE_H` rows of the input tensor. Note that all work group items with the same W index will request the same columns from the weight tensor, and all work group items with the same H index will request the same rows from the input tensor. If `H == W`, then that "balances" the amount of data that needs to be loaded from each input tensor and may result in better data sharing behaviour among all work group items.

Assuming `OUTPUT_TILE_W == OUTPUT_TILE_H == 1`, a local work group of size `{64, 1, 1}` would require 1 unique row from the input tensor and 64 unique columns from the weight tensor, resulting in `(1 + 64) * K = 65K` elements being loaded in total, where K is the size of the shared reduction dim. Conversely, a local work group of size `{8, 8, 1}` would require 8 unique rows / 8 unique columns, resulting in only `(8 + 8) * K = 16K` unique elements being loaded.

This highlights the need to use dedicated logic to compute work group sizes for matrix multiplication shaders.
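To make the arithmetic concrete, here is a minimal standalone sketch (not part of this patch) that evaluates the `(W + H) * K` unique-load count for the two 64-invocation work group shapes discussed above; the value of `K` is an arbitrary example:

```cpp
#include <cstdint>
#include <cstdio>

// Unique elements a {W, H, 1} work group must read to produce its output:
// H rows of the input tensor (H * K elements) plus W columns of the weight
// tensor (W * K elements), assuming OUTPUT_TILE_W == OUTPUT_TILE_H == 1.
uint64_t unique_loads(uint64_t W, uint64_t H, uint64_t K) {
  return (W + H) * K;
}

int main() {
  const uint64_t K = 2048; // shared reduction dim (example value)
  // Both shapes use 64 invocations, but touch very different data volumes.
  printf("{64, 1, 1}: %llu unique elements\n",
         (unsigned long long)unique_loads(64, 1, K)); // 65 * K
  printf("{8, 8, 1}:  %llu unique elements\n",
         (unsigned long long)unique_loads(8, 8, K)); // 16 * K
  return 0;
}
```

With the same invocation count, the square shape reads roughly a quarter of the unique data, which is the motivation for the `pick_hw_square_wg_size` function introduced below.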
## Changes

* Introduce `pick_hw_square_wg_size`
* Use the new local work group size determination function for Quantized Linear, Matmul, and Linear

ghstack-source-id: 302703877

Differential Revision: [D79813236](https://our.internmc.facebook.com/intern/diff/D79813236/)
---
 .../vulkan/runtime/graph/ops/impl/Common.cpp  | 23 +++++++++++++++++++
 .../vulkan/runtime/graph/ops/impl/Common.h    | 18 +++++++++++++++
 .../vulkan/runtime/graph/ops/impl/Linear.cpp  |  4 ++--
 .../vulkan/runtime/graph/ops/impl/MatMul.cpp  |  6 ++---
 .../graph/ops/impl/QuantizedLinearQGANW.cpp   |  3 ++-
 5 files changed, 48 insertions(+), 6 deletions(-)

diff --git a/backends/vulkan/runtime/graph/ops/impl/Common.cpp b/backends/vulkan/runtime/graph/ops/impl/Common.cpp
index 4c3c16417b5..6c701224f7f 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Common.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Common.cpp
@@ -33,4 +33,27 @@ utils::uvec3 default_pick_local_wg_size(
   return graph->create_local_wg_size(global_workgroup_size);
 }
 
+utils::uvec3 pick_hw_square_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const utils::uvec3& global_workgroup_size,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)graph;
+  (void)shader;
+  (void)args;
+  (void)resize_args;
+  // Some inactive invocations are okay; set 6 as the threshold to use a
+  // square wg size.
+  if (global_workgroup_size[0u] >= 6u && global_workgroup_size[1u] >= 6u) {
+    return {8u, 8u, 1u};
+  }
+  // If the width dim is sufficiently small, then bias towards the height dim
+  // to reduce the number of inactive invocations.
+  if (global_workgroup_size[0u] < 6u) {
+    return {4u, 16u, 1u};
+  }
+  return {16u, 4u, 1u};
+}
+
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/Common.h b/backends/vulkan/runtime/graph/ops/impl/Common.h
index 662fb07095a..1831ab2a845 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Common.h
+++ b/backends/vulkan/runtime/graph/ops/impl/Common.h
@@ -36,4 +36,22 @@ utils::uvec3 default_pick_local_wg_size(
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& resize_args);
 
+/**
+ * Constructs a local work group size with the shape {W, H, 1}. The function
+ * will try to set W == H == sqrt(num_invocations), where num_invocations is
+ * typically 64. This configuration is good for ops like matrix multiplication
+ * as it reduces the total volume of unique data that the entire work group
+ * will need to read from input tensors in order to produce the output data.
+ * To compute an output tile of {W, H, 1}, the work group will need to read
+ * H unique rows = H * K unique elements from the input tensor and W unique
+ * cols = W * K elements from the weight tensor, resulting in (W + H) * K
+ * unique elements in total.
+ */
+utils::uvec3 pick_hw_square_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const utils::uvec3& global_workgroup_size,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args);
+
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp
index 7ca31599cdf..38d70271f4f 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp
@@ -178,7 +178,7 @@ void add_addmm_naive_texture_node(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
       addmm_naive_texture_global_wg_size,
-      default_pick_local_wg_size,
+      pick_hw_square_wg_size,
       // Inputs and Outputs
       {{out, vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}},
       // Shader params buffers
@@ -245,7 +245,7 @@ void add_addmm_naive_buffer_node(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
       addmm_naive_buffer_global_wg_size,
-      default_pick_local_wg_size,
+      pick_hw_square_wg_size,
       // Inputs and Outputs
       {{out, vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}},
       // Shader params buffers
diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp
index 0f5556060a2..47ecf5f18d2 100644
--- a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp
@@ -102,7 +102,7 @@ void add_matmul_naive_buffer_node(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
       matmul_naive_buffer_global_wg_size,
-      default_pick_local_wg_size,
+      pick_hw_square_wg_size,
       // Inputs and Outputs
       {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}},
       // Shader params buffers
@@ -158,7 +158,7 @@ void add_matmul_naive_texture3d_node(
       graph,
       pick_matmul_naive_texture3d_shader,
       default_pick_global_wg_size,
-      default_pick_local_wg_size,
+      pick_hw_square_wg_size,
       // Inputs and Outputs
       {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}},
       // Shader params buffers
@@ -273,7 +273,7 @@ void add_matmul_optimized_node(
       graph,
       pick_matmul_optimized_shader,
       matmul_optimized_global_wg_size,
-      default_pick_local_wg_size,
+      pick_hw_square_wg_size,
       // Inputs and Outputs
       {{out, vkapi::kWrite}, {{mat1_W_packed, mat2_packed}, vkapi::kRead}},
       // Shader params buffers
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp
index 8c7c6b0cdf9..52cf75e28b5 100644
--- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp
@@ -158,7 +158,8 @@ utils::uvec3 linear_qga4w_local_wg_size(
   if (use_coop_algorithm) {
     return {64, 1, 1};
   } else {
-    return graph->create_local_wg_size(global_workgroup_size);
+    return pick_hw_square_wg_size(
+        graph, shader, global_workgroup_size, args, resize_args);
   }
 }
 

From 1fd109c3571c72d4efe07abeb230b70e0f946a95 Mon Sep 17 00:00:00 2001
From: ssjia
Date: Wed, 13 Aug 2025 06:53:48 -0700
Subject: [PATCH 2/2] [ET-VK] Add mechanism to trigger command buffer re-encode only when necessary

Pull Request resolved: https://github.com/pytorch/executorch/pull/13184

## Context

Dynamic shape models currently require the command buffer to be re-encoded for every inference, which introduces significant overhead when running such models. In reality, a command buffer re-encode is not needed every frame. A re-encode is only needed when:

1. Shader dispatch parameters change; i.e. the new tensor sizes require a completely different compute shader, a new local work group size, or a new work group grid size (i.e. global work group size / local work group size). The grid size calculation is sketched below.
2. Push constants containing tensor metadata need to be updated.
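To illustrate point 1, here is a minimal standalone sketch (not part of the diff) of the work group grid calculation; the real implementation uses `utils::div_up` and `utils::uvec3`, which are approximated here with standard types:

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

using uvec3 = std::array<uint32_t, 3>;

// Ceiling division, mirroring utils::div_up.
uint32_t div_up(uint32_t a, uint32_t b) {
  return (a + b - 1) / b;
}

// The work group grid that would actually be dispatched for a given global
// and local work group size.
uvec3 dispatch_grid(const uvec3& global, const uvec3& local) {
  return {div_up(global[0], local[0]),
          div_up(global[1], local[1]),
          div_up(global[2], local[2])};
}

int main() {
  const uvec3 local = {8, 8, 1};
  // Two different global sizes that round up to the same {4, 15, 32} grid;
  // a resize like this does not require a command buffer re-encode.
  const uvec3 grid_a = dispatch_grid({29, 115, 32}, local);
  const uvec3 grid_b = dispatch_grid({32, 120, 32}, local);
  printf("grids equal: %s\n", grid_a == grid_b ? "true" : "false");
  return 0;
}
```

Because the grid is what is actually dispatched, a change in the global work group size alone does not force a re-encode; only a change in the shader, the local work group size, or the resulting grid does.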
This diff aims to reduce the overhead of handling tensor shape changes by detecting when a command buffer re-encode is actually needed.

## Changes

`ComputeGraph`:
* Introduce a `requires_reencode` flag to `ComputeGraph` to indicate when a command buffer re-encode is needed.
* Introduce a `std::unordered_set` tracking which values were updated when propagating tensor sizes
  * an "update" can be one of two things: 1) tensor sizes changed, 2) symint value changed

`DispatchNode`:
* When propagating new tensor sizes, only execute the resize function if any of the values participating in the computation have been updated
* Mark `requires_reencode` if any push constants associated with tensor metadata need to be updated

`DynamicDispatchNode`:
* Only recompute compute shader dispatch params if any of the values participating in the computation have been updated
* Mark `requires_reencode` if 1) a new compute shader is required, 2) the local work group size changed, or 3) the work group grid size changed

ghstack-source-id: 302703876
@exported-using-ghexport

Differential Revision: [D79813237](https://our.internmc.facebook.com/intern/diff/D79813237/)
---
 backends/vulkan/runtime/VulkanBackend.cpp     |  8 +--
 .../vulkan/runtime/graph/ComputeGraph.cpp     | 57 +++++++++++++--
 backends/vulkan/runtime/graph/ComputeGraph.h  | 38 ++++++++--
 .../graph/containers/PushConstantData.h       | 17 +++++
 .../vulkan/runtime/graph/ops/DispatchNode.cpp | 17 +++++
 .../vulkan/runtime/graph/ops/DispatchNode.h   |  2 +
 .../runtime/graph/ops/DynamicDispatchNode.cpp | 69 +++++++++++++++++--
 .../runtime/graph/ops/DynamicDispatchNode.h   |  4 +-
 .../vulkan/runtime/graph/ops/ExecuteNode.cpp  | 30 ++++++++
 .../vulkan/runtime/graph/ops/ExecuteNode.h    |  8 +--
 backends/vulkan/runtime/utils/VecUtils.h      | 23 +++++++
 11 files changed, 245 insertions(+), 28 deletions(-)

diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp
index 8be4553b060..73b726bd32e 100644
--- a/backends/vulkan/runtime/VulkanBackend.cpp
+++ b/backends/vulkan/runtime/VulkanBackend.cpp
@@ -583,13 +583,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
       }
     }
 
-    // propagate_resize() will re-encode the command buffer so that push
-    // constants are updated and DynamicDispatchNode can update the compute
-    // shader, global workgroup size, and local workgroup size to perform the
-    // model inference.
-    if (should_propagate_resize ||
-        (compute_graph->graphconfig().expect_dynamic_shapes &&
-         compute_graph->execute_count() == 0u)) {
+    if (should_propagate_resize) {
       compute_graph->propagate_resize();
     }
 
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp
index 3b9061701e6..acd20c9ee44 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.cpp
+++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -206,6 +206,29 @@ utils::StorageType ComputeGraph::suggested_storage_type() {
   return utils::kTexture3D;
 }
 
+bool ComputeGraph::was_value_updated(const ValueRef idx) const noexcept {
+  if (!is_valid_value_idx(idx)) {
+    return false;
+  }
+
+  // Check if this ValueRef itself was updated
+  if (updated_values_.find(idx) != updated_values_.end()) {
+    return true;
+  }
+
+  // If this is a ValueList, check each ValueRef in the list
+  if (val_is_value_list(idx)) {
+    const auto& value_list = values_.at(idx).toConstValueList();
+    for (const auto& nested_idx : value_list) {
+      if (was_value_updated(nested_idx)) {
+        return true;
+      }
+    }
+  }
+
+  return false;
+}
+
 utils::GPUMemoryLayout ComputeGraph::suggested_memory_layout(
     const std::vector<int64_t>& sizes) {
   if (config_.enable_memory_layout_override) {
@@ -236,6 +259,10 @@ void ComputeGraph::check_no_active_value_ptrs() {
       "invalidated.");
 }
 
+bool ComputeGraph::is_valid_value_idx(const ValueRef idx) const noexcept {
+  return idx >= 0 && idx < static_cast<int>(values_.size());
+}
+
 std::vector<int64_t> ComputeGraph::sizes_of(const ValueRef idx) const {
   const Value& val = values_.at(idx);
   if (val.isTensor()) {
@@ -569,7 +596,12 @@ vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
 }
 
 void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
-  get_symint(idx)->set(val);
+  int32_t cur_val = read_symint(idx);
+  if (cur_val != val) {
+    get_symint(idx)->set(val);
+    // Track that this ValueRef was updated
+    updated_values_.insert(idx);
+  }
 }
 
 int32_t ComputeGraph::read_symint(const ValueRef idx) {
@@ -951,6 +983,12 @@ void ComputeGraph::execute() {
   }
 
   execute_count_++;
+
+  // Clear the set of updated values at the end of inference
+  updated_values_.clear();
+
+  // Reset the re-encoding flag at the end of inference
+  requires_reencode_ = false;
 }
 
 void ComputeGraph::virtual_clone(const ValueRef dst, const ValueRef src) {
@@ -968,21 +1006,30 @@ void ComputeGraph::resize_input(
     const int64_t idx,
     const std::vector<int64_t>& new_sizes) {
   IOValueRef io_val = inputs_.at(idx);
-  get_tensor(io_val.value)->virtual_resize(new_sizes);
+  virtual_resize(io_val.value, new_sizes);
+  updated_values_.insert(io_val.staging);
 }
 
 void ComputeGraph::virtual_resize(
     const ValueRef idx,
     const std::vector<int64_t>& new_sizes) {
-  get_tensor(idx)->virtual_resize(new_sizes);
+  std::vector<int64_t> cur_sizes = sizes_of(idx);
+  if (cur_sizes != new_sizes) {
+    get_tensor(idx)->virtual_resize(new_sizes);
+    // Track that this ValueRef was updated
+    updated_values_.insert(idx);
+  }
 }
 
 void ComputeGraph::propagate_resize() {
   for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
     node->trigger_resize(this);
   }
-  // Only re-encode on resize if dynamic shapes are expected
-  if (config_.expect_dynamic_shapes) {
+  // A command buffer re-encode will be needed if:
+  // 1. Any push constant data (used for tensor metadata) was updated
+  // 2. Compute shader dispatch parameters (i.e. compute shader, global and
+  //    local work group sizes) were updated
+  if (requires_reencode_) {
     clear_deferred_cmds();
   }
 }
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h
index 3baa4df4de6..e4556a9efe6 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.h
+++ b/backends/vulkan/runtime/graph/ComputeGraph.h
@@ -196,6 +196,12 @@ class ComputeGraph final {
   // List of command buffers deferred for submission
   std::vector<vkapi::CommandBuffer> deferred_cmd_list_;
 
+  // Set to track which ValueRefs were updated during inference
+  std::unordered_set<ValueRef> updated_values_;
+
+  // Flag to indicate if re-encoding is required
+  bool requires_reencode_ = false;
+
 protected:
   size_t values_in_use_ = 0;
   size_t execute_count_ = 0;
@@ -244,6 +250,9 @@ class ComputeGraph final {
     return config_;
   }
 
+  // Check if the ComputeGraph has a value at the specified index
+  bool is_valid_value_idx(const ValueRef idx) const noexcept;
+
   //
   // Value Extraction
   //
@@ -427,31 +436,41 @@ class ComputeGraph final {
   }
 
   inline PushConstantDataInfo sizes_pc_of(const ValueRef idx) const {
-    return PushConstantDataInfo(
+    PushConstantDataInfo pc_data = PushConstantDataInfo(
         values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorSizes);
+    pc_data.set_value(idx);
+    return pc_data;
   }
 
   inline PushConstantDataInfo dim_order_pc_of(const ValueRef idx) const {
-    return PushConstantDataInfo(
+    PushConstantDataInfo pc_data = PushConstantDataInfo(
         values_.at(idx).toConstTensor().get_uniform_data(),
         api::kTensorDimOrder);
+    pc_data.set_value(idx);
+    return pc_data;
  }
 
   inline PushConstantDataInfo strides_pc_of(const ValueRef idx) const {
-    return PushConstantDataInfo(
+    PushConstantDataInfo pc_data = PushConstantDataInfo(
         values_.at(idx).toConstTensor().get_uniform_data(),
         api::kTensorStrides);
+    pc_data.set_value(idx);
+    return pc_data;
   }
 
   inline PushConstantDataInfo logical_limits_pc_of(const ValueRef idx) const {
-    return PushConstantDataInfo(
+    PushConstantDataInfo pc_data = PushConstantDataInfo(
         values_.at(idx).toConstTensor().get_uniform_data(),
         api::kTensorLogicalLimits);
+    pc_data.set_value(idx);
+    return pc_data;
   }
 
   inline PushConstantDataInfo numel_pc_of(const ValueRef idx) const {
-    return PushConstantDataInfo(
+    PushConstantDataInfo pc_data = PushConstantDataInfo(
         values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorNumel);
+    pc_data.set_value(idx);
+    return pc_data;
   }
 
   //
@@ -948,6 +967,15 @@ class ComputeGraph final {
 
   void propagate_resize();
 
+  // Check if a specific ValueRef was updated; ValueLists are checked
+  // recursively
+  bool was_value_updated(const ValueRef idx) const noexcept;
+
+  // Set the flag to indicate that re-encoding is required
+  inline void set_requires_reencode() noexcept {
+    requires_reencode_ = true;
+  }
+
   //
   // Miscellaneous Utilities
   //
diff --git a/backends/vulkan/runtime/graph/containers/PushConstantData.h b/backends/vulkan/runtime/graph/containers/PushConstantData.h
index 39cde4722a7..c86232983ea 100644
--- a/backends/vulkan/runtime/graph/containers/PushConstantData.h
+++ b/backends/vulkan/runtime/graph/containers/PushConstantData.h
@@ -10,6 +10,8 @@
 
 #include
 
+#include
+
 namespace vkcompute {
 
 class ComputeGraph;
@@ -33,6 +35,9 @@ class PushConstantDataInfo {
   };
 
   Payload payload_;
+  // The value in a compute graph that this push constant data is associated
+  // with, if any.
+  ValueRef value_ = kDummyValueRef;
 
  public:
   explicit PushConstantDataInfo(
@@ -60,6 +65,18 @@ class PushConstantDataInfo {
       void* dst,
       const uint32_t dst_offset,
       const uint32_t max_dst_size) const;
+
+  inline bool is_tensor_metadata() const noexcept {
+    return tensorUniformData != nullptr;
+  }
+
+  inline void set_value(ValueRef value) noexcept {
+    value_ = value;
+  }
+
+  inline ValueRef value() const noexcept {
+    return value_;
+  }
 };
 
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp
index b5644cf3dcd..898a3415b7e 100644
--- a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp
+++ b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp
@@ -89,4 +89,21 @@ void DispatchNode::write_push_constant_data() {
   }
 }
 
+bool DispatchNode::trigger_resize(ComputeGraph* graph) {
+  const bool any_arg_updated = ExecuteNode::trigger_resize(graph);
+
+  if (any_arg_updated) {
+    // If this shader uses push constants, and the tensor metadata associated
+    // with the push constants has changed, then the command buffer needs to
+    // be re-encoded, since push constants cannot be updated once recorded.
+    for (const auto& push_constant : push_constants_) {
+      if (push_constant.is_tensor_metadata() &&
+          graph->was_value_updated(push_constant.value())) {
+        graph->set_requires_reencode();
+      }
+    }
+  }
+  return any_arg_updated;
+}
+
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.h b/backends/vulkan/runtime/graph/ops/DispatchNode.h
index b6eb8624c26..89d24a77d6e 100644
--- a/backends/vulkan/runtime/graph/ops/DispatchNode.h
+++ b/backends/vulkan/runtime/graph/ops/DispatchNode.h
@@ -44,6 +44,8 @@ class DispatchNode : public ExecuteNode {
 
   void encode(ComputeGraph* graph) override;
 
+  bool trigger_resize(ComputeGraph* graph) override;
+
  protected:
   vkapi::ShaderInfo shader_;
   utils::uvec3 global_workgroup_size_;
diff --git a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp
index ea2061d3d7c..5a88bba88c9 100644
--- a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp
+++ b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp
@@ -41,6 +41,12 @@ DynamicDispatchNode::DynamicDispatchNode(
       pick_global_wg_fn(&graph, shader_, args, resize_args);
   local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn(
       &graph, shader_, global_workgroup_size_, args, resize_args));
+
+  // Calculate the dispatch grid, similar to register_shader_dispatch in
+  // Context.cpp
+  wg_dispatch_grid_ = {
+      utils::div_up(global_workgroup_size_[0], local_workgroup_size_[0]),
+      utils::div_up(global_workgroup_size_[1], local_workgroup_size_[1]),
+      utils::div_up(global_workgroup_size_[2], local_workgroup_size_[2])};
 }
 
 DynamicDispatchNode::DynamicDispatchNode(
@@ -72,21 +78,74 @@ DynamicDispatchNode::DynamicDispatchNode(
       pick_global_wg_fn(&graph, shader_, args, resize_args);
   local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn(
       &graph, shader_, global_workgroup_size_, args, resize_args));
+  // Calculate the work group grid that will be dispatched
+  wg_dispatch_grid_ = {
+      utils::div_up(global_workgroup_size_[0], local_workgroup_size_[0]),
+      utils::div_up(global_workgroup_size_[1], local_workgroup_size_[1]),
+      utils::div_up(global_workgroup_size_[2], local_workgroup_size_[2])};
 }
 
-void DynamicDispatchNode::encode(ComputeGraph* graph) {
+bool DynamicDispatchNode::trigger_resize(ComputeGraph* graph) {
+  // DispatchNode::trigger_resize() will return true if any of the values
+  // participating in this operation were updated.
+  const bool any_arg_updated = DispatchNode::trigger_resize(graph);
+  // Only re-compute the shader, global work group size, and local work group
+  // size if any of the values participating in this operation were updated.
+  // Otherwise, assume that these will not have changed.
+  if (!any_arg_updated) {
+    return false;
+  }
+
+  // Indicates if the shader dispatch should be changed since the last time
+  // the command buffer was encoded.
+  bool dispatch_params_changed = false;
+
   if (pick_shader_fn_) {
-    shader_ = pick_shader_fn_(graph, args_, resize_args_);
+    vkapi::ShaderInfo new_shader = pick_shader_fn_(graph, args_, resize_args_);
+    // Compare shader kernel names as a proxy for shader equality
+    if (shader_.kernel_name != new_shader.kernel_name) {
+      shader_ = new_shader;
+      dispatch_params_changed = true;
+    }
   }
   if (pick_global_wg_fn_) {
+    // Note that if the global work group size changes, the dispatch params
+    // may not actually be different. The actual value to check is the work
+    // group grid size that will be dispatched, which is calculated below.
     global_workgroup_size_ =
         pick_global_wg_fn_(graph, shader_, args_, resize_args_);
   }
   if (pick_local_wg_fn_) {
-    local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn_(
-        graph, shader_, global_workgroup_size_, args_, resize_args_));
+    utils::uvec3 new_local_wg_uvec3 = pick_local_wg_fn_(
+        graph, shader_, global_workgroup_size_, args_, resize_args_);
+    utils::WorkgroupSize new_local_wg =
+        utils::WorkgroupSize(new_local_wg_uvec3);
+    if (local_workgroup_size_ != new_local_wg) {
+      local_workgroup_size_ = new_local_wg;
+      dispatch_params_changed = true;
+    }
+  }
+
+  // Always recompute the new dispatch grid and check if it's different
+  utils::uvec3 new_wg_dispatch_grid = {
+      utils::div_up(global_workgroup_size_[0], local_workgroup_size_[0]),
+      utils::div_up(global_workgroup_size_[1], local_workgroup_size_[1]),
+      utils::div_up(global_workgroup_size_[2], local_workgroup_size_[2])};
+
+  // Check if the new dispatch grid is different from the old one
+  if (wg_dispatch_grid_ != new_wg_dispatch_grid) {
+    dispatch_params_changed = true;
   }
-  DispatchNode::encode(graph);
+  wg_dispatch_grid_ = new_wg_dispatch_grid;
+
+  // If any of the dispatch params have changed, then the command buffer must
+  // be re-encoded.
+  if (dispatch_params_changed) {
+    graph->set_requires_reencode();
+  }
+
+  return true;
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h
index 005151272c3..d3b82968eb2 100644
--- a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h
+++ b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h
@@ -68,13 +68,15 @@ class DynamicDispatchNode final : public DispatchNode {
 
   ~DynamicDispatchNode() override = default;
 
-  void encode(ComputeGraph* graph) override;
+  bool trigger_resize(ComputeGraph* graph) override;
 
  protected:
   const PickShaderFn pick_shader_fn_;
   const PickGlobalFn pick_global_wg_fn_;
   const PickLocalFn pick_local_wg_fn_;
 
+  utils::uvec3 wg_dispatch_grid_{1u, 1u, 1u};
+
  public:
   operator bool() const {
     return shader_;
diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
index 7335ce2703b..953f15e7b4d 100644
--- a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
+++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
@@ -6,6 +6,7 @@
  * LICENSE file in the root directory of this source tree.
 */
 
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
 
 namespace vkcompute {
@@ -18,4 +19,33 @@ ExecuteNode::ExecuteNode(
       resize_args_(resize_args),
       args_(args),
       name_(name) {}
+
+bool ExecuteNode::trigger_resize(ComputeGraph* graph) {
+  const bool any_arg_updated = was_any_arg_updated(graph);
+  if (resize_fn_ && any_arg_updated) {
+    resize_fn_(graph, args_, resize_args_);
+  }
+  return any_arg_updated;
+}
+
+bool ExecuteNode::was_any_arg_updated(const ComputeGraph* const graph) const {
+  // Check all ValueRefs in ArgGroups
+  for (const auto& arg_group : args_) {
+    for (const auto& value_ref : arg_group.refs) {
+      if (graph->was_value_updated(value_ref)) {
+        return true;
+      }
+    }
+  }
+
+  // Check all ValueRefs in resize_args
+  for (const auto& value_ref : resize_args_) {
+    if (graph->was_value_updated(value_ref)) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.h b/backends/vulkan/runtime/graph/ops/ExecuteNode.h
index 4ea1ba57796..323036cef90 100644
--- a/backends/vulkan/runtime/graph/ops/ExecuteNode.h
+++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.h
@@ -69,11 +69,9 @@ class ExecuteNode {
     (void)graph;
   }
 
-  virtual inline void trigger_resize(ComputeGraph* graph) {
-    if (resize_fn_ != nullptr) {
-      resize_fn_(graph, args_, resize_args_);
-    }
-  }
+  virtual bool trigger_resize(ComputeGraph* graph);
+
+  bool was_any_arg_updated(const ComputeGraph* const graph) const;
 
   inline void set_node_id(uint32_t node_id) {
     node_id_ = node_id;
diff --git a/backends/vulkan/runtime/utils/VecUtils.h b/backends/vulkan/runtime/utils/VecUtils.h
index 6d2e8c63bb9..d84eb54d2b9 100644
--- a/backends/vulkan/runtime/utils/VecUtils.h
+++ b/backends/vulkan/runtime/utils/VecUtils.h
@@ -275,6 +275,19 @@ struct vec final {
     VK_CHECK_COND(i >= 0 && i < N, "Index out of bounds!");
     return data[i];
   }
+
+  bool operator==(const vec& other) const {
+    for (uint32_t i = 0; i < N; ++i) {
+      if (data[i] != other.data[i]) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  bool operator!=(const vec& other) const {
+    return !(*this == other);
+  }
 };
 
 } // namespace detail
@@ -527,6 +540,16 @@ class WorkgroupSize final {
   inline constexpr uint32_t operator[](const int idx) const {
     return (val >> (11 * idx)) & 0x7ffu;
   }
+
+  // Equality operator
+  bool operator==(const WorkgroupSize& other) const {
+    return val == other.val;
+  }
+
+  // Inequality operator, for completeness
+  bool operator!=(const WorkgroupSize& other) const {
+    return !(*this == other);
+  }
 };
 
 } // namespace utils
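For reference, a minimal standalone sketch (not part of the patch) of the bit-packing scheme that makes the single-integer comparison in `WorkgroupSize::operator==` valid: each dimension occupies 11 bits of `val`, so two work group sizes compare equal exactly when their packed values are equal.

```cpp
#include <cassert>
#include <cstdint>

// Packs a {x, y, z} work group size into one 32-bit value, 11 bits per
// dimension, mirroring how utils::WorkgroupSize stores its dimensions.
uint32_t pack_wg(uint32_t x, uint32_t y, uint32_t z) {
  return (x & 0x7ffu) | ((y & 0x7ffu) << 11) | ((z & 0x7ffu) << 22);
}

// Mirrors WorkgroupSize::operator[] shown in the diff above.
uint32_t extract_dim(uint32_t packed, int idx) {
  return (packed >> (11 * idx)) & 0x7ffu;
}

int main() {
  const uint32_t packed = pack_wg(8, 8, 1);
  assert(extract_dim(packed, 0) == 8);
  assert(extract_dim(packed, 1) == 8);
  assert(extract_dim(packed, 2) == 1);
  // Equal sizes pack to the same integer, so comparing `val` suffices.
  assert(pack_wg(8, 8, 1) == pack_wg(8, 8, 1));
  assert(pack_wg(16, 4, 1) != pack_wg(8, 8, 1));
  return 0;
}
```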