Skip to content

Commit 88af067

Browse files
authored
[ET-VK] Create Pipeline layouts with push constant ranges when required
Differential Revision: D67770793 Pull Request resolved: #7479
1 parent bc0facd commit 88af067

File tree

6 files changed

+60
-32
lines changed

6 files changed

+60
-32
lines changed

backends/vulkan/runtime/api/Context.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,13 @@ void Context::report_shader_dispatch_end() {
9090
vkapi::DescriptorSet Context::get_descriptor_set(
9191
const vkapi::ShaderInfo& shader_descriptor,
9292
const utils::uvec3& local_workgroup_size,
93-
const vkapi::SpecVarList& additional_constants) {
93+
const vkapi::SpecVarList& additional_constants,
94+
const uint32_t push_constants_size) {
9495
VkDescriptorSetLayout shader_layout =
9596
shader_layout_cache().retrieve(shader_descriptor.kernel_layout);
9697

9798
VkPipelineLayout pipeline_layout =
98-
pipeline_layout_cache().retrieve(shader_layout);
99+
pipeline_layout_cache().retrieve(shader_layout, push_constants_size);
99100

100101
vkapi::SpecVarList spec_constants = {
101102
SV(local_workgroup_size[0u]),
@@ -105,7 +106,7 @@ vkapi::DescriptorSet Context::get_descriptor_set(
105106
spec_constants.append(additional_constants);
106107

107108
VkPipeline pipeline = pipeline_cache().retrieve(
108-
{pipeline_layout_cache().retrieve(shader_layout),
109+
{pipeline_layout_cache().retrieve(shader_layout, push_constants_size),
109110
shader_cache().retrieve(shader_descriptor),
110111
spec_constants});
111112

@@ -151,7 +152,7 @@ void Context::register_shader_dispatch(
151152
const VkDescriptorSetLayout shader_layout =
152153
shader_layout_cache().retrieve(shader_descriptor.kernel_layout);
153154
const VkPipelineLayout pipeline_layout =
154-
pipeline_layout_cache().retrieve(shader_layout);
155+
pipeline_layout_cache().retrieve(shader_layout, push_constants_size);
155156
cmd_.set_push_constants(
156157
pipeline_layout, push_constants_data, push_constants_size);
157158
}

backends/vulkan/runtime/api/Context.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,12 +188,13 @@ class Context final {
188188
vkapi::DescriptorSet get_descriptor_set(
189189
const vkapi::ShaderInfo&,
190190
const utils::uvec3&,
191-
const vkapi::SpecVarList&);
191+
const vkapi::SpecVarList&,
192+
const uint32_t push_constants_size);
192193

193194
inline vkapi::DescriptorSet get_descriptor_set(
194195
const vkapi::ShaderInfo& shader_descriptor,
195196
const utils::uvec3& local_work_group_size) {
196-
return get_descriptor_set(shader_descriptor, local_work_group_size, {});
197+
return get_descriptor_set(shader_descriptor, local_work_group_size, {}, 0u);
197198
}
198199

199200
void register_shader_dispatch(
@@ -333,8 +334,10 @@ inline bool Context::submit_compute_job(
333334
dispatch_id);
334335

335336
// Factor out template parameter independent code to minimize code bloat.
337+
// Note that push constants are not exposed yet via this API, therefore the
338+
// push constants size is assumed to be 0.
336339
vkapi::DescriptorSet descriptor_set = get_descriptor_set(
337-
shader, local_work_group_size, specialization_constants);
340+
shader, local_work_group_size, specialization_constants, 0u);
338341

339342
detail::bind(
340343
descriptor_set,

backends/vulkan/runtime/graph/ops/DispatchNode.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,30 +60,31 @@ void DispatchNode::encode(ComputeGraph* graph) {
6060

6161
std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();
6262

63+
std::array<uint8_t, kMaxPushConstantSize> push_constants_data;
64+
uint32_t push_constants_offset = 0;
65+
66+
for (const auto& push_constant : push_constants_) {
67+
push_constants_offset += push_constant.write(
68+
push_constants_data.data(),
69+
push_constants_offset,
70+
kMaxPushConstantSize);
71+
}
72+
6373
context->report_shader_dispatch_start(
6474
shader_.kernel_name,
6575
global_workgroup_size_,
6676
local_workgroup_size_,
6777
node_id_);
6878

69-
vkapi::DescriptorSet descriptor_set =
70-
context->get_descriptor_set(shader_, local_workgroup_size_, spec_vars_);
79+
vkapi::DescriptorSet descriptor_set = context->get_descriptor_set(
80+
shader_, local_workgroup_size_, spec_vars_, push_constants_offset);
7181

7282
uint32_t idx = 0;
7383
idx = bind_values_to_descriptor_set(
7484
graph, args_, pipeline_barrier, descriptor_set, idx);
7585

7686
bind_params_to_descriptor_set(params_, descriptor_set, idx);
7787

78-
std::array<uint8_t, kMaxPushConstantSize> push_constants_data;
79-
uint32_t push_constants_offset = 0;
80-
81-
for (const auto& push_constant : push_constants_) {
82-
push_constants_offset += push_constant.write(
83-
push_constants_data.data(),
84-
push_constants_offset,
85-
kMaxPushConstantSize);
86-
}
8788
context->register_shader_dispatch(
8889
descriptor_set,
8990
pipeline_barrier,

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ void PrepackNode::encode(ComputeGraph* graph) {
7575

7676
{
7777
vkapi::PipelineBarrier pipeline_barrier{};
78-
vkapi::DescriptorSet descriptor_set =
79-
context->get_descriptor_set(shader_, local_workgroup_size_, spec_vars_);
78+
vkapi::DescriptorSet descriptor_set = context->get_descriptor_set(
79+
shader_, local_workgroup_size_, spec_vars_, 0u);
8080

8181
uint32_t idx = 0;
8282
bind_tensor_to_descriptor_set(

backends/vulkan/runtime/vk_api/Pipeline.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -205,17 +205,29 @@ bool operator==(const SpecVarList& lhs, const SpecVarList& rhs) {
205205

206206
PipelineLayout::PipelineLayout(
207207
VkDevice device,
208-
VkDescriptorSetLayout descriptor_layout)
208+
VkDescriptorSetLayout descriptor_layout,
209+
const uint32_t push_constants_size)
209210
: device_(device), handle_{VK_NULL_HANDLE} {
210-
// TODO: Enable push constants
211+
VkPushConstantRange pc_range{
212+
VK_SHADER_STAGE_COMPUTE_BIT, // stageFlags
213+
0u, // offset
214+
push_constants_size, // size
215+
};
216+
uint32_t num_push_constants = 0u;
217+
VkPushConstantRange* pc_ranges_ptr = nullptr;
218+
if (push_constants_size > 0u) {
219+
num_push_constants = 1u;
220+
pc_ranges_ptr = &pc_range;
221+
}
222+
211223
const VkPipelineLayoutCreateInfo pipeline_layout_create_info{
212224
VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO, // sType
213225
nullptr, // pNext
214226
0u, // flags
215227
1u, // setLayoutCount
216228
&descriptor_layout, // pSetLayouts
217-
0u, // pushConstantRangeCount
218-
nullptr, // pPushConstantRanges
229+
num_push_constants, // pushConstantRangeCount
230+
pc_ranges_ptr, // pPushConstantRanges
219231
};
220232

221233
VK_CHECK(vkCreatePipelineLayout(
@@ -344,12 +356,19 @@ PipelineLayoutCache::~PipelineLayoutCache() {
344356
}
345357

346358
VkPipelineLayout PipelineLayoutCache::retrieve(
347-
const PipelineLayoutCache::Key& key) {
359+
const VkDescriptorSetLayout layout,
360+
const uint32_t push_constants_size) {
361+
PipelineLayoutCache::Key key{layout, push_constants_size};
348362
std::lock_guard<std::mutex> lock(cache_mutex_);
349363

350364
auto it = cache_.find(key);
351365
if (cache_.cend() == it) {
352-
it = cache_.insert({key, PipelineLayoutCache::Value(device_, key)}).first;
366+
it = cache_
367+
.insert(
368+
{key,
369+
PipelineLayoutCache::Value(
370+
device_, layout, push_constants_size)})
371+
.first;
353372
}
354373

355374
return it->second.handle();

backends/vulkan/runtime/vk_api/Pipeline.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ VkImageLayout vk_layout(const PipelineStageFlags, const MemoryAccessFlags);
121121

122122
class PipelineLayout final {
123123
public:
124-
explicit PipelineLayout(VkDevice, VkDescriptorSetLayout);
124+
explicit PipelineLayout(VkDevice, VkDescriptorSetLayout, const uint32_t);
125125

126126
PipelineLayout(const PipelineLayout&) = delete;
127127
PipelineLayout& operator=(const PipelineLayout&) = delete;
@@ -193,13 +193,17 @@ class PipelineLayoutCache final {
193193
PipelineLayoutCache& operator=(PipelineLayoutCache&&) = delete;
194194

195195
~PipelineLayoutCache();
196-
197-
using Key = VkDescriptorSetLayout;
196+
using Key = std::pair<VkDescriptorSetLayout, uint32_t>;
198197
using Value = PipelineLayout;
199198

200199
struct Hasher {
201-
inline size_t operator()(VkDescriptorSetLayout descriptor_layout) const {
202-
return std::hash<VkDescriptorSetLayout>()(descriptor_layout);
200+
inline size_t operator()(
201+
std::pair<VkDescriptorSetLayout, uint32_t> key) const {
202+
size_t seed = 0;
203+
seed = utils::hash_combine(
204+
seed, std::hash<VkDescriptorSetLayout>()(key.first));
205+
seed = utils::hash_combine(seed, std::hash<uint32_t>()(key.second));
206+
return seed;
203207
}
204208
};
205209

@@ -212,7 +216,7 @@ class PipelineLayoutCache final {
212216
std::unordered_map<Key, Value, Hasher> cache_;
213217

214218
public:
215-
VkPipelineLayout retrieve(const Key&);
219+
VkPipelineLayout retrieve(const VkDescriptorSetLayout, const uint32_t);
216220
void purge();
217221
};
218222

0 commit comments

Comments
 (0)