diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp
index 8be4553b060..73b726bd32e 100644
--- a/backends/vulkan/runtime/VulkanBackend.cpp
+++ b/backends/vulkan/runtime/VulkanBackend.cpp
@@ -583,13 +583,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
       }
     }
 
-    // propagate_resize() will re-encode the command buffer so that push
-    // constants are updated and DynamicDispatchNode can update the compute
-    // shader, global workgroup size, and local workgroup size to perform the
-    // model inference.
-    if (should_propagate_resize ||
-        (compute_graph->graphconfig().expect_dynamic_shapes &&
-         compute_graph->execute_count() == 0u)) {
+    if (should_propagate_resize) {
       compute_graph->propagate_resize();
     }
 
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp
index 3b9061701e6..acd20c9ee44 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.cpp
+++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -206,6 +206,29 @@ utils::StorageType ComputeGraph::suggested_storage_type() {
   return utils::kTexture3D;
 }
 
+bool ComputeGraph::was_value_updated(const ValueRef idx) const noexcept {
+  if (!is_valid_value_idx(idx)) {
+    return false;
+  }
+
+  // Check if this ValueRef itself was updated
+  if (updated_values_.find(idx) != updated_values_.end()) {
+    return true;
+  }
+
+  // If this is a ValueList, check each ValueRef in the list
+  if (val_is_value_list(idx)) {
+    const auto& value_list = values_.at(idx).toConstValueList();
+    for (const auto& nested_idx : value_list) {
+      if (was_value_updated(nested_idx)) {
+        return true;
+      }
+    }
+  }
+
+  return false;
+}
+
 utils::GPUMemoryLayout ComputeGraph::suggested_memory_layout(
     const std::vector<int64_t>& sizes) {
   if (config_.enable_memory_layout_override) {
@@ -236,6 +259,10 @@ void ComputeGraph::check_no_active_value_ptrs() {
       "invalidated.");
 }
 
+bool ComputeGraph::is_valid_value_idx(const ValueRef idx) const noexcept {
+  return idx >= 0 && idx < static_cast<int>(values_.size());
+}
+
 std::vector<int64_t> ComputeGraph::sizes_of(const ValueRef idx) const {
   const Value& val = values_.at(idx);
   if (val.isTensor()) {
@@ -569,7 +596,12 @@ vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
 }
 
 void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
-  get_symint(idx)->set(val);
+  int32_t cur_val = read_symint(idx);
+  if (cur_val != val) {
+    get_symint(idx)->set(val);
+    // Track that this ValueRef was updated
+    updated_values_.insert(idx);
+  }
 }
 
 int32_t ComputeGraph::read_symint(const ValueRef idx) {
@@ -951,6 +983,12 @@ void ComputeGraph::execute() {
   }
 
   execute_count_++;
+
+  // Clear the set of updated values at the end of inference
+  updated_values_.clear();
+
+  // Reset the re-encoding flag at the end of inference
+  requires_reencode_ = false;
 }
 
 void ComputeGraph::virtual_clone(const ValueRef dst, const ValueRef src) {
@@ -968,21 +1006,30 @@ void ComputeGraph::resize_input(
     const int64_t idx,
     const std::vector<int64_t>& new_sizes) {
   IOValueRef io_val = inputs_.at(idx);
-  get_tensor(io_val.value)->virtual_resize(new_sizes);
+  virtual_resize(io_val.value, new_sizes);
+  updated_values_.insert(io_val.staging);
 }
 
 void ComputeGraph::virtual_resize(
     const ValueRef idx,
     const std::vector<int64_t>& new_sizes) {
-  get_tensor(idx)->virtual_resize(new_sizes);
+  std::vector<int64_t> cur_sizes = sizes_of(idx);
+  if (cur_sizes != new_sizes) {
+    get_tensor(idx)->virtual_resize(new_sizes);
+    // Track that this ValueRef was updated
+    updated_values_.insert(idx);
+  }
 }
 
 void ComputeGraph::propagate_resize() {
   for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
     node->trigger_resize(this);
   }
-  // Only re-encode on resize if dynamic shapes are expected
-  if (config_.expect_dynamic_shapes) {
+  // A command buffer re-encode will be needed if:
+  // 1. Any push constant data (used for tensor metadata) was updated
+  // 2. Compute shader dispatch parameters (i.e. compute shader, global and
+  //    local work group sizes) were updated
+  if (requires_reencode_) {
     clear_deferred_cmds();
   }
 }
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h
index 3baa4df4de6..e4556a9efe6 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.h
+++ b/backends/vulkan/runtime/graph/ComputeGraph.h
@@ -196,6 +196,12 @@ class ComputeGraph final {
   // List of command buffers deferred for submission
   std::vector<vkapi::CommandBuffer> deferred_cmd_list_;
 
+  // Set to track which ValueRefs were updated during inference
+  std::unordered_set<ValueRef> updated_values_;
+
+  // Flag to indicate if re-encoding is required
+  bool requires_reencode_ = false;
+
 protected:
   size_t values_in_use_ = 0;
   size_t execute_count_ = 0;
@@ -244,6 +250,9 @@ class ComputeGraph final {
     return config_;
   }
 
+  // Check if the ComputeGraph has a value at the specified index
+  bool is_valid_value_idx(const ValueRef idx) const noexcept;
+
   //
   // Value Extraction
   //
@@ -427,31 +436,41 @@ class ComputeGraph final {
   }
 
   inline PushConstantDataInfo sizes_pc_of(const ValueRef idx) const {
-    return PushConstantDataInfo(
+    PushConstantDataInfo pc_data = PushConstantDataInfo(
         values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorSizes);
+    pc_data.set_value(idx);
+    return pc_data;
   }
 
   inline PushConstantDataInfo dim_order_pc_of(const ValueRef idx) const {
-    return PushConstantDataInfo(
+    PushConstantDataInfo pc_data = PushConstantDataInfo(
         values_.at(idx).toConstTensor().get_uniform_data(),
         api::kTensorDimOrder);
+    pc_data.set_value(idx);
+    return pc_data;
   }
 
   inline PushConstantDataInfo strides_pc_of(const ValueRef idx) const {
-    return PushConstantDataInfo(
+    PushConstantDataInfo pc_data = PushConstantDataInfo(
        values_.at(idx).toConstTensor().get_uniform_data(),
         api::kTensorStrides);
+    pc_data.set_value(idx);
+    return pc_data;
   }
 
   inline PushConstantDataInfo logical_limits_pc_of(const ValueRef idx) const {
-    return PushConstantDataInfo(
+    PushConstantDataInfo pc_data = PushConstantDataInfo(
         values_.at(idx).toConstTensor().get_uniform_data(),
         api::kTensorLogicalLimits);
+    pc_data.set_value(idx);
+    return pc_data;
   }
 
   inline PushConstantDataInfo numel_pc_of(const ValueRef idx) const {
-    return PushConstantDataInfo(
+    PushConstantDataInfo pc_data = PushConstantDataInfo(
         values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorNumel);
+    pc_data.set_value(idx);
+    return pc_data;
   }
 
   //
@@ -948,6 +967,15 @@ class ComputeGraph final {
 
   void propagate_resize();
 
+  // Check if a specific ValueRef (or ValueList) was updated, with recursive
+  // handling
+  bool was_value_updated(const ValueRef idx) const noexcept;
+
+  // Set the flag to indicate that re-encoding is required
+  inline void set_requires_reencode() noexcept {
+    requires_reencode_ = true;
+  }
+
   //
   // Miscellaneous Utilities
   //
diff --git a/backends/vulkan/runtime/graph/containers/PushConstantData.h b/backends/vulkan/runtime/graph/containers/PushConstantData.h
index 39cde4722a7..c86232983ea 100644
--- a/backends/vulkan/runtime/graph/containers/PushConstantData.h
+++ b/backends/vulkan/runtime/graph/containers/PushConstantData.h
@@ -10,6 +10,8 @@
 
 #include <executorch/backends/vulkan/runtime/api/api.h>
 
+#include <executorch/backends/vulkan/runtime/graph/containers/Types.h>
+
 namespace vkcompute {
 
 class ComputeGraph;
@@ -33,6 +35,9 @@ class PushConstantDataInfo {
   };
 
   Payload payload_;
+  // The value in a compute graph that this push constant data is associated
+  // with, if any.
+  ValueRef value_ = kDummyValueRef;
 
  public:
   explicit PushConstantDataInfo(
@@ -60,6 +65,18 @@ class PushConstantDataInfo {
       void* dst,
       const uint32_t dst_offset,
       const uint32_t max_dst_size) const;
+
+  inline bool is_tensor_metadata() const noexcept {
+    return tensorUniformData != nullptr;
+  }
+
+  inline void set_value(ValueRef value) noexcept {
+    value_ = value;
+  }
+
+  inline ValueRef value() const noexcept {
+    return value_;
+  }
 };
 
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp
index b5644cf3dcd..898a3415b7e 100644
--- a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp
+++ b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp
@@ -89,4 +89,21 @@ void DispatchNode::write_push_constant_data() {
   }
 }
 
+bool DispatchNode::trigger_resize(ComputeGraph* graph) {
+  const bool any_arg_updated = ExecuteNode::trigger_resize(graph);
+
+  if (any_arg_updated) {
+    // If this shader uses push constants, and the tensor metadata associated
+    // with the push constants has changed, then the command buffer needs to be
+    // re-encoded since push constants cannot be updated.
+    for (const auto& push_constant : push_constants_) {
+      if (push_constant.is_tensor_metadata() &&
+          graph->was_value_updated(push_constant.value())) {
+        graph->set_requires_reencode();
+      }
+    }
+  }
+  return any_arg_updated;
+}
+
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.h b/backends/vulkan/runtime/graph/ops/DispatchNode.h
index b6eb8624c26..89d24a77d6e 100644
--- a/backends/vulkan/runtime/graph/ops/DispatchNode.h
+++ b/backends/vulkan/runtime/graph/ops/DispatchNode.h
@@ -44,6 +44,8 @@ class DispatchNode : public ExecuteNode {
 
   void encode(ComputeGraph* graph) override;
 
+  bool trigger_resize(ComputeGraph* graph) override;
+
  protected:
   vkapi::ShaderInfo shader_;
   utils::uvec3 global_workgroup_size_;
diff --git a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp
index ea2061d3d7c..5a88bba88c9 100644
--- a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp
+++ b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp
@@ -41,6 +41,12 @@ DynamicDispatchNode::DynamicDispatchNode(
       pick_global_wg_fn(&graph, shader_, args, resize_args);
   local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn(
       &graph, shader_, global_workgroup_size_, args, resize_args));
+
+  // Calculate dispatch grid similar to Context.cpp register_shader_dispatch
+  wg_dispatch_grid_ = {
+      utils::div_up(global_workgroup_size_[0], local_workgroup_size_[0]),
+      utils::div_up(global_workgroup_size_[1], local_workgroup_size_[1]),
+      utils::div_up(global_workgroup_size_[2], local_workgroup_size_[2])};
 }
 
 DynamicDispatchNode::DynamicDispatchNode(
@@ -72,21 +78,74 @@ DynamicDispatchNode::DynamicDispatchNode(
       pick_global_wg_fn(&graph, shader_, args, resize_args);
   local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn(
       &graph, shader_, global_workgroup_size_, args, resize_args));
+  // Calculate the work group grid that will be dispatched
+  wg_dispatch_grid_ = {
+      utils::div_up(global_workgroup_size_[0], local_workgroup_size_[0]),
+      utils::div_up(global_workgroup_size_[1], local_workgroup_size_[1]),
+      utils::div_up(global_workgroup_size_[2], local_workgroup_size_[2])};
 }
 
-void DynamicDispatchNode::encode(ComputeGraph* graph) {
+bool DynamicDispatchNode::trigger_resize(ComputeGraph* graph) {
+  // DispatchNode::trigger_resize() will return true if any of the values
+  // participating in this operation were updated.
+  const bool any_arg_updated = DispatchNode::trigger_resize(graph);
+  // Only re-compute the shader, global workgroup size, and local workgroup size
+  // if any of the values participating in this operation were updated.
+  // Otherwise, assume that these will not have changed.
+  if (!any_arg_updated) {
+    return false;
+  }
+
+  // Indicates if the shader dispatch should be changed since the last time the
+  // command buffer was encoded.
+  bool dispatch_params_changed = false;
+
   if (pick_shader_fn_) {
-    shader_ = pick_shader_fn_(graph, args_, resize_args_);
+    vkapi::ShaderInfo new_shader = pick_shader_fn_(graph, args_, resize_args_);
+    // Compare shader kernel names as a proxy for shader equality
+    if (shader_.kernel_name != new_shader.kernel_name) {
+      shader_ = new_shader;
+      dispatch_params_changed = true;
+    }
   }
   if (pick_global_wg_fn_) {
+    // Note that if global workgroup size changes, then the dispatch params
+    // may not actually be different. The actual value to check is the
+    // work group grid size that will be dispatched, which is calculated
+    // below.
     global_workgroup_size_ =
         pick_global_wg_fn_(graph, shader_, args_, resize_args_);
   }
   if (pick_local_wg_fn_) {
-    local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn_(
-        graph, shader_, global_workgroup_size_, args_, resize_args_));
+    utils::uvec3 new_local_wg_uvec3 = pick_local_wg_fn_(
+        graph, shader_, global_workgroup_size_, args_, resize_args_);
+    utils::WorkgroupSize new_local_wg =
+        utils::WorkgroupSize(new_local_wg_uvec3);
+    if (local_workgroup_size_ != new_local_wg) {
+      local_workgroup_size_ = new_local_wg;
+      dispatch_params_changed = true;
+    }
+  }
+
+  // Always recompute the new dispatch grid and check if it's different
+  utils::uvec3 new_wg_dispatch_grid = {
+      utils::div_up(global_workgroup_size_[0], local_workgroup_size_[0]),
+      utils::div_up(global_workgroup_size_[1], local_workgroup_size_[1]),
+      utils::div_up(global_workgroup_size_[2], local_workgroup_size_[2])};
+
+  // Check if the new dispatch grid is different from the old one
+  if (wg_dispatch_grid_ != new_wg_dispatch_grid) {
+    dispatch_params_changed = true;
   }
-  DispatchNode::encode(graph);
+  wg_dispatch_grid_ = new_wg_dispatch_grid;
+
+  // If any of the dispatch params have changed, then the command buffer must
+  // be re-encoded.
+  if (dispatch_params_changed) {
+    graph->set_requires_reencode();
+  }
+
+  return true;
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h
index 005151272c3..d3b82968eb2 100644
--- a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h
+++ b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h
@@ -68,13 +68,15 @@ class DynamicDispatchNode final : public DispatchNode {
 
   ~DynamicDispatchNode() override = default;
 
-  void encode(ComputeGraph* graph) override;
+  bool trigger_resize(ComputeGraph* graph) override;
 
  protected:
   const PickShaderFn pick_shader_fn_;
   const PickGlobalFn pick_global_wg_fn_;
   const PickLocalFn pick_local_wg_fn_;
 
+  utils::uvec3 wg_dispatch_grid_{1u, 1u, 1u};
+
  public:
   operator bool() const {
     return shader_;
diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
index 7335ce2703b..953f15e7b4d 100644
--- a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
+++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
@@ -6,6 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
 
 namespace vkcompute {
@@ -18,4 +19,33 @@ ExecuteNode::ExecuteNode(
       resize_args_(resize_args),
       args_(args),
       name_(name) {}
+
+bool ExecuteNode::trigger_resize(ComputeGraph* graph) {
+  const bool any_arg_updated = was_any_arg_updated(graph);
+  if (resize_fn_ && any_arg_updated) {
+    resize_fn_(graph, args_, resize_args_);
+  }
+  return any_arg_updated;
+}
+
+bool ExecuteNode::was_any_arg_updated(const ComputeGraph* const graph) const {
+  // Check all ValueRefs in ArgGroups
+  for (const auto& arg_group : args_) {
+    for (const auto& value_ref : arg_group.refs) {
+      if (graph->was_value_updated(value_ref)) {
+        return true;
+      }
+    }
+  }
+
+  // Check all ValueRefs in resize_args
+  for (const auto& value_ref : resize_args_) {
+    if (graph->was_value_updated(value_ref)) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.h b/backends/vulkan/runtime/graph/ops/ExecuteNode.h
index 4ea1ba57796..323036cef90 100644
--- a/backends/vulkan/runtime/graph/ops/ExecuteNode.h
+++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.h
@@ -69,11 +69,9 @@ class ExecuteNode {
     (void)graph;
   }
 
-  virtual inline void trigger_resize(ComputeGraph* graph) {
-    if (resize_fn_ != nullptr) {
-      resize_fn_(graph, args_, resize_args_);
-    }
-  }
+  virtual bool trigger_resize(ComputeGraph* graph);
+
+  bool was_any_arg_updated(const ComputeGraph* const graph) const;
 
   inline void set_node_id(uint32_t node_id) {
     node_id_ = node_id;
diff --git a/backends/vulkan/runtime/utils/VecUtils.h b/backends/vulkan/runtime/utils/VecUtils.h
index 6d2e8c63bb9..d84eb54d2b9 100644
--- a/backends/vulkan/runtime/utils/VecUtils.h
+++ b/backends/vulkan/runtime/utils/VecUtils.h
@@ -275,6 +275,19 @@ struct vec final {
     VK_CHECK_COND(i >= 0 && i < N, "Index out of bounds!");
     return data[i];
   }
+
+  bool operator==(const vec& other) const {
+    for (uint32_t i = 0; i < N; ++i) {
+      if (data[i] != other.data[i]) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  bool operator!=(const vec& other) const {
+    return !(*this == other);
+  }
 };
 
 } // namespace detail
@@ -527,6 +540,16 @@ class WorkgroupSize final {
   inline constexpr uint32_t operator[](const int idx) const {
     return (val >> (11 * idx)) & 0x7ffu;
   }
+
+  // Equality operator
+  bool operator==(const WorkgroupSize& other) const {
+    return val == other.val;
+  }
+
+  // Inequality operator (optional, for completeness)
+  bool operator!=(const WorkgroupSize& other) const {
+    return !(*this == other);
+  }
 };
 
 } // namespace utils