Skip to content

Commit 88a8aad

Browse files
committed
[ET-VK] Add PushConstantDataInfo and vector to hold push constants data in DispatchNode.
This diff adds a new class called `PushConstantDataInfo` to the `DispatchNode` class in the Vulkan backend for Executorch. This class represents a push constant data entry, which can either be a shared pointer to a tensor's uniform data with an attribute or data with a maximum size of 16 bytes. The `write` method is also added to this class, which writes the data to a destination buffer. Differential Revision: [D66796049](https://our.internmc.facebook.com/intern/diff/D66796049/) ghstack-source-id: 256911523 Pull Request resolved: #7223
1 parent ef48bbf commit 88a8aad

File tree

2 files changed

+80
-4
lines changed

2 files changed

+80
-4
lines changed

backends/vulkan/runtime/graph/ops/DispatchNode.cpp

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,22 @@
1414

1515
namespace vkcompute {
1616

17+
uint32_t PushConstantDataInfo::write(
18+
void* dst,
19+
const uint32_t dst_offset,
20+
const uint32_t max_dst_size) const {
21+
if (tensorUniformData != nullptr) {
22+
return tensorUniformData->write_attribute(
23+
dst, dst_offset, max_dst_size, payload_.attr);
24+
}
25+
26+
VK_CHECK_COND(
27+
(dst_offset + payload_.dataSize) <= max_dst_size,
28+
"Attempting to write push constant data outside data boundary.");
29+
memcpy((uint8_t*)dst + dst_offset, payload_.data, payload_.dataSize);
30+
return payload_.dataSize;
31+
}
32+
1733
DispatchNode::DispatchNode(
1834
ComputeGraph& graph,
1935
const vkapi::ShaderInfo& shader,
@@ -23,13 +39,15 @@ DispatchNode::DispatchNode(
2339
const vkapi::ParamsBindList& params,
2440
const vkapi::SpecVarList& spec_vars,
2541
const ResizeFunction& resize_fn,
26-
const std::vector<ValueRef>& resize_args)
42+
const std::vector<ValueRef>& resize_args,
43+
const std::vector<PushConstantDataInfo>& push_constants)
2744
: ExecuteNode(resize_fn, resize_args, args, shader.kernel_name),
2845
shader_(shader),
2946
global_workgroup_size_(global_workgroup_size),
3047
local_workgroup_size_(local_workgroup_size),
3148
params_(params),
32-
spec_vars_(spec_vars) {
49+
spec_vars_(spec_vars),
50+
push_constants_(push_constants) {
3351
graph.update_descriptor_counts(shader, /*execute = */ true);
3452
}
3553

@@ -57,8 +75,20 @@ void DispatchNode::encode(ComputeGraph* graph) {
5775

5876
bind_params_to_descriptor_set(params_, descriptor_set, idx);
5977

78+
uint8_t push_constants_data[128];
79+
uint32_t push_constants_offset = 0;
80+
81+
for (const auto& push_constant : push_constants_) {
82+
push_constants_offset +=
83+
push_constant.write(push_constants_data, push_constants_offset, 128);
84+
}
6085
context->register_shader_dispatch(
61-
descriptor_set, pipeline_barrier, shader_, global_workgroup_size_);
86+
descriptor_set,
87+
pipeline_barrier,
88+
shader_,
89+
global_workgroup_size_,
90+
push_constants_data,
91+
push_constants_offset);
6292

6393
context->report_shader_dispatch_end();
6494
}

backends/vulkan/runtime/graph/ops/DispatchNode.h

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,50 @@ namespace vkcompute {
1818

1919
class ComputeGraph;
2020

21+
/*
22+
* Represents a push constant data entry
23+
* Which is either shared pointer to a tensor's uniform data with an attribute
24+
* Or data with a maximum size of 16 bytes
25+
*/
26+
class PushConstantDataInfo {
27+
std::shared_ptr<api::vTensor::UniformData> tensorUniformData;
28+
union Payload {
29+
struct {
30+
api::vTensor::Attribute attr;
31+
};
32+
struct {
33+
uint8_t data[16];
34+
uint32_t dataSize;
35+
};
36+
};
37+
38+
Payload payload_;
39+
40+
public:
41+
explicit PushConstantDataInfo(
42+
const std::shared_ptr<api::vTensor::UniformData>& tensorUniformData,
43+
api::vTensor::Attribute attr)
44+
: tensorUniformData(tensorUniformData) {
45+
payload_.attr = attr;
46+
}
47+
48+
explicit PushConstantDataInfo(const void* data, uint32_t dataLen)
49+
: tensorUniformData(nullptr) {
50+
VK_CHECK_COND(
51+
dataLen <= 16, "Single push constant data size must be <= 16 bytes");
52+
payload_.dataSize = dataLen;
53+
memcpy(payload_.data, data, payload_.dataSize);
54+
}
55+
56+
/*
57+
* Function writes push constant data to the destination buffer
58+
*/
59+
uint32_t write(
60+
void* dst,
61+
const uint32_t dst_offset,
62+
const uint32_t max_dst_size) const;
63+
};
64+
2165
/*
2266
* Represents a single shader execution op in a ML model.
2367
*/
@@ -34,7 +78,8 @@ class DispatchNode final : public ExecuteNode {
3478
const vkapi::ParamsBindList& params,
3579
const vkapi::SpecVarList& spec_vars = {},
3680
const ResizeFunction& resize_fn = nullptr,
37-
const std::vector<ValueRef>& resize_args = {});
81+
const std::vector<ValueRef>& resize_args = {},
82+
const std::vector<PushConstantDataInfo>& push_constants = {});
3883

3984
~DispatchNode() override = default;
4085

@@ -46,6 +91,7 @@ class DispatchNode final : public ExecuteNode {
4691
const utils::uvec3 local_workgroup_size_;
4792
const vkapi::ParamsBindList params_;
4893
const vkapi::SpecVarList spec_vars_;
94+
const std::vector<PushConstantDataInfo> push_constants_;
4995

5096
public:
5197
operator bool() const {

0 commit comments

Comments
 (0)