Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion backends/vulkan/runtime/api/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ void Context::register_shader_dispatch(
const vkapi::DescriptorSet& descriptors,
vkapi::PipelineBarrier& pipeline_barrier,
const vkapi::ShaderInfo& shader_descriptor,
const utils::uvec3& global_workgroup_size) {
const utils::uvec3& global_workgroup_size,
const void* push_constants_data,
const uint32_t push_constants_size) {
// Adjust the global workgroup size based on the output tile size
uint32_t global_wg_w = utils::div_up(
global_workgroup_size[0u], shader_descriptor.out_tile_size[0u]);
Expand All @@ -145,6 +147,15 @@ void Context::register_shader_dispatch(
cmd_.bind_descriptors(descriptors.get_bind_handle());
cmd_.insert_barrier(pipeline_barrier);

if (push_constants_size > 0 && push_constants_data != nullptr) {
const VkDescriptorSetLayout shader_layout =
shader_layout_cache().retrieve(shader_descriptor.kernel_layout);
const VkPipelineLayout pipeline_layout =
pipeline_layout_cache().retrieve(shader_layout);
cmd_.set_push_constants(
pipeline_layout, push_constants_data, push_constants_size);
}

cmd_.dispatch(effective_global_wg);
}

Expand Down
4 changes: 3 additions & 1 deletion backends/vulkan/runtime/api/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,9 @@ class Context final {
const vkapi::DescriptorSet&,
vkapi::PipelineBarrier&,
const vkapi::ShaderInfo&,
const utils::uvec3&);
const utils::uvec3&,
const void* = nullptr,
const uint32_t = 0);

void register_blit(
vkapi::PipelineBarrier&,
Expand Down
15 changes: 15 additions & 0 deletions backends/vulkan/runtime/vk_api/Command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,21 @@ void CommandBuffer::bind_descriptors(VkDescriptorSet descriptors) {
state_ = CommandBuffer::State::DESCRIPTORS_BOUND;
}

void CommandBuffer::set_push_constants(
VkPipelineLayout pipeline_layout,
const void* push_constants_data,
uint32_t push_constants_size) {
if (push_constants_data != nullptr && push_constants_size > 0) {
vkCmdPushConstants(
handle_,
pipeline_layout,
VK_SHADER_STAGE_COMPUTE_BIT,
0,
push_constants_size,
push_constants_data);
}
}

void CommandBuffer::insert_barrier(PipelineBarrier& pipeline_barrier) {
VK_CHECK_COND(
state_ == CommandBuffer::State::DESCRIPTORS_BOUND ||
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/runtime/vk_api/Command.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class CommandBuffer final {

void bind_pipeline(VkPipeline, VkPipelineLayout, const utils::uvec3);
void bind_descriptors(VkDescriptorSet);
void set_push_constants(VkPipelineLayout, const void*, uint32_t);

void insert_barrier(PipelineBarrier& pipeline_barrier);
void dispatch(const utils::uvec3&);
Expand Down