Vulkan backend: thread the push-constant size into pipeline layout creation.

Summary of the hunks below:
  * Context::get_descriptor_set() takes a new push_constants_size argument and
    forwards it to PipelineLayoutCache::retrieve(). Callers that do not use
    push constants (the inline two-argument overload, submit_compute_job and
    PrepackNode::encode) pass 0u.
  * DispatchNode::encode() writes its push constants into the staging array
    before requesting the descriptor set, so that the accumulated
    push_constants_offset can be passed as the size.
  * PipelineLayout's constructor declares a single compute-stage
    VkPushConstantRange at offset 0 when the size is non-zero, and no range
    otherwise.
  * PipelineLayoutCache is keyed on the (descriptor set layout, push constant
    size) pair, with a hasher that combines both members.

NOTE(review): this patch text was recovered from a copy in which newlines and
indentation had been collapsed and every angle-bracketed template argument
list had been dropped. The arguments on std::pair, std::hash and
std::unordered_map follow from the surrounding hunks. The spellings
std::array<uint8_t, kMaxPushConstantSize>, std::unique_lock<std::mutex> and
std::lock_guard<std::mutex>, the positions of blank lines and the indentation
are reconstructed (hunk line counts were checked against the @@ headers) --
confirm against the tree before applying.

diff --git a/backends/vulkan/runtime/api/Context.cpp b/backends/vulkan/runtime/api/Context.cpp
index 5426ea4e60b..9517941f364 100644
--- a/backends/vulkan/runtime/api/Context.cpp
+++ b/backends/vulkan/runtime/api/Context.cpp
@@ -90,12 +90,13 @@ void Context::report_shader_dispatch_end() {
 vkapi::DescriptorSet Context::get_descriptor_set(
     const vkapi::ShaderInfo& shader_descriptor,
     const utils::uvec3& local_workgroup_size,
-    const vkapi::SpecVarList& additional_constants) {
+    const vkapi::SpecVarList& additional_constants,
+    const uint32_t push_constants_size) {
   VkDescriptorSetLayout shader_layout =
       shader_layout_cache().retrieve(shader_descriptor.kernel_layout);
 
   VkPipelineLayout pipeline_layout =
-      pipeline_layout_cache().retrieve(shader_layout);
+      pipeline_layout_cache().retrieve(shader_layout, push_constants_size);
 
   vkapi::SpecVarList spec_constants = {
       SV(local_workgroup_size[0u]),
@@ -105,7 +106,7 @@ vkapi::DescriptorSet Context::get_descriptor_set(
   spec_constants.append(additional_constants);
 
   VkPipeline pipeline = pipeline_cache().retrieve(
-      {pipeline_layout_cache().retrieve(shader_layout),
+      {pipeline_layout_cache().retrieve(shader_layout, push_constants_size),
        shader_cache().retrieve(shader_descriptor),
        spec_constants});
 
@@ -151,7 +152,7 @@ void Context::register_shader_dispatch(
     const VkDescriptorSetLayout shader_layout =
         shader_layout_cache().retrieve(shader_descriptor.kernel_layout);
     const VkPipelineLayout pipeline_layout =
-        pipeline_layout_cache().retrieve(shader_layout);
+        pipeline_layout_cache().retrieve(shader_layout, push_constants_size);
     cmd_.set_push_constants(
         pipeline_layout, push_constants_data, push_constants_size);
   }
diff --git a/backends/vulkan/runtime/api/Context.h b/backends/vulkan/runtime/api/Context.h
index 65f3adb511d..300fd3995dd 100644
--- a/backends/vulkan/runtime/api/Context.h
+++ b/backends/vulkan/runtime/api/Context.h
@@ -188,12 +188,13 @@ class Context final {
   vkapi::DescriptorSet get_descriptor_set(
       const vkapi::ShaderInfo&,
       const utils::uvec3&,
-      const vkapi::SpecVarList&);
+      const vkapi::SpecVarList&,
+      const uint32_t push_constants_size);
 
   inline vkapi::DescriptorSet get_descriptor_set(
       const vkapi::ShaderInfo& shader_descriptor,
       const utils::uvec3& local_work_group_size) {
-    return get_descriptor_set(shader_descriptor, local_work_group_size, {});
+    return get_descriptor_set(shader_descriptor, local_work_group_size, {}, 0u);
   }
 
   void register_shader_dispatch(
@@ -333,8 +334,10 @@ inline bool Context::submit_compute_job(
       dispatch_id);
 
   // Factor out template parameter independent code to minimize code bloat.
+  // Note that push constants are not exposed yet via this API, therefore the
+  // push constants size is assumed to be 0.
   vkapi::DescriptorSet descriptor_set = get_descriptor_set(
-      shader, local_work_group_size, specialization_constants);
+      shader, local_work_group_size, specialization_constants, 0u);
 
   detail::bind(
       descriptor_set,
diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp
index 87b4b5b5480..a163a0d7aea 100644
--- a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp
+++ b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp
@@ -60,14 +60,24 @@ void DispatchNode::encode(ComputeGraph* graph) {
 
   std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();
 
+  std::array<uint8_t, kMaxPushConstantSize> push_constants_data;
+  uint32_t push_constants_offset = 0;
+
+  for (const auto& push_constant : push_constants_) {
+    push_constants_offset += push_constant.write(
+        push_constants_data.data(),
+        push_constants_offset,
+        kMaxPushConstantSize);
+  }
+
   context->report_shader_dispatch_start(
       shader_.kernel_name,
       global_workgroup_size_,
       local_workgroup_size_,
       node_id_);
 
-  vkapi::DescriptorSet descriptor_set =
-      context->get_descriptor_set(shader_, local_workgroup_size_, spec_vars_);
+  vkapi::DescriptorSet descriptor_set = context->get_descriptor_set(
+      shader_, local_workgroup_size_, spec_vars_, push_constants_offset);
 
   uint32_t idx = 0;
   idx = bind_values_to_descriptor_set(
@@ -75,15 +85,6 @@ void DispatchNode::encode(ComputeGraph* graph) {
 
   bind_params_to_descriptor_set(params_, descriptor_set, idx);
 
-  std::array<uint8_t, kMaxPushConstantSize> push_constants_data;
-  uint32_t push_constants_offset = 0;
-
-  for (const auto& push_constant : push_constants_) {
-    push_constants_offset += push_constant.write(
-        push_constants_data.data(),
-        push_constants_offset,
-        kMaxPushConstantSize);
-  }
   context->register_shader_dispatch(
       descriptor_set,
       pipeline_barrier,
diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp
index 89719fb0dd3..e27723468ab 100644
--- a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp
+++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp
@@ -75,8 +75,8 @@ void PrepackNode::encode(ComputeGraph* graph) {
 
   {
     vkapi::PipelineBarrier pipeline_barrier{};
-    vkapi::DescriptorSet descriptor_set =
-        context->get_descriptor_set(shader_, local_workgroup_size_, spec_vars_);
+    vkapi::DescriptorSet descriptor_set = context->get_descriptor_set(
+        shader_, local_workgroup_size_, spec_vars_, 0u);
 
     uint32_t idx = 0;
     bind_tensor_to_descriptor_set(
diff --git a/backends/vulkan/runtime/vk_api/Pipeline.cpp b/backends/vulkan/runtime/vk_api/Pipeline.cpp
index 49bbf083359..3856d406c24 100644
--- a/backends/vulkan/runtime/vk_api/Pipeline.cpp
+++ b/backends/vulkan/runtime/vk_api/Pipeline.cpp
@@ -205,17 +205,29 @@ bool operator==(const SpecVarList& lhs, const SpecVarList& rhs) {
 
 PipelineLayout::PipelineLayout(
     VkDevice device,
-    VkDescriptorSetLayout descriptor_layout)
+    VkDescriptorSetLayout descriptor_layout,
+    const uint32_t push_constants_size)
     : device_(device), handle_{VK_NULL_HANDLE} {
-  // TODO: Enable push constants
+  VkPushConstantRange pc_range{
+      VK_SHADER_STAGE_COMPUTE_BIT, // stageFlags
+      0u, // offset
+      push_constants_size, // size
+  };
+  uint32_t num_push_constants = 0u;
+  VkPushConstantRange* pc_ranges_ptr = nullptr;
+  if (push_constants_size > 0u) {
+    num_push_constants = 1u;
+    pc_ranges_ptr = &pc_range;
+  }
+
   const VkPipelineLayoutCreateInfo pipeline_layout_create_info{
       VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO, // sType
       nullptr, // pNext
       0u, // flags
       1u, // setLayoutCount
       &descriptor_layout, // pSetLayouts
-      0u, // pushConstantRangeCount
-      nullptr, // pPushConstantRanges
+      num_push_constants, // pushConstantRangeCount
+      pc_ranges_ptr, // pPushConstantRanges
   };
 
   VK_CHECK(vkCreatePipelineLayout(
@@ -344,12 +356,19 @@ PipelineLayoutCache::~PipelineLayoutCache() {
 }
 
 VkPipelineLayout PipelineLayoutCache::retrieve(
-    const PipelineLayoutCache::Key& key) {
+    const VkDescriptorSetLayout layout,
+    const uint32_t push_constants_size) {
+  PipelineLayoutCache::Key key{layout, push_constants_size};
   std::lock_guard<std::mutex> lock(cache_mutex_);
 
   auto it = cache_.find(key);
   if (cache_.cend() == it) {
-    it = cache_.insert({key, PipelineLayoutCache::Value(device_, key)}).first;
+    it = cache_
+             .insert(
+                 {key,
+                  PipelineLayoutCache::Value(
+                      device_, layout, push_constants_size)})
+             .first;
   }
 
   return it->second.handle();
diff --git a/backends/vulkan/runtime/vk_api/Pipeline.h b/backends/vulkan/runtime/vk_api/Pipeline.h
index 4f42a9bf6bb..5460a0acba7 100644
--- a/backends/vulkan/runtime/vk_api/Pipeline.h
+++ b/backends/vulkan/runtime/vk_api/Pipeline.h
@@ -121,7 +121,7 @@ VkImageLayout vk_layout(const PipelineStageFlags, const MemoryAccessFlags);
 
 class PipelineLayout final {
  public:
-  explicit PipelineLayout(VkDevice, VkDescriptorSetLayout);
+  explicit PipelineLayout(VkDevice, VkDescriptorSetLayout, const uint32_t);
 
   PipelineLayout(const PipelineLayout&) = delete;
   PipelineLayout& operator=(const PipelineLayout&) = delete;
@@ -193,13 +193,17 @@ class PipelineLayoutCache final {
   PipelineLayoutCache& operator=(PipelineLayoutCache&&) = delete;
 
   ~PipelineLayoutCache();
-
-  using Key = VkDescriptorSetLayout;
+  using Key = std::pair<VkDescriptorSetLayout, uint32_t>;
   using Value = PipelineLayout;
 
   struct Hasher {
-    inline size_t operator()(VkDescriptorSetLayout descriptor_layout) const {
-      return std::hash<VkDescriptorSetLayout>()(descriptor_layout);
+    inline size_t operator()(
+        std::pair<VkDescriptorSetLayout, uint32_t> key) const {
+      size_t seed = 0;
+      seed = utils::hash_combine(
+          seed, std::hash<VkDescriptorSetLayout>()(key.first));
+      seed = utils::hash_combine(seed, std::hash<uint32_t>()(key.second));
+      return seed;
     }
   };
 
@@ -212,7 +216,7 @@ class PipelineLayoutCache final {
   std::unordered_map<Key, Value, Hasher> cache_;
 
  public:
-  VkPipelineLayout retrieve(const Key&);
+  VkPipelineLayout retrieve(const VkDescriptorSetLayout, const uint32_t);
   void purge();
 };
 