Skip to content

Commit 7b410aa

Browse files
authored
[ET-VK] Replace Uniform buffers with push constants for permute op
Differential Revision: D66890825 Pull Request resolved: #7231
1 parent f3d5fec commit 7b410aa

File tree

2 files changed

+10
-20
lines changed

2 files changed

+10
-20
lines changed

backends/vulkan/runtime/graph/ops/glsl/permute.glsl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,9 @@ layout(std430) buffer;
1919
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
2020
layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} image_in;
2121

22-
layout(set = 0, binding = 2) uniform PRECISION restrict OutLimits {
23-
ivec3 out_limits;
24-
};
25-
26-
layout(set = 0, binding = 3) uniform PRECISION restrict Sizes {
22+
layout(push_constant) uniform PRECISION restrict Block {
23+
ivec4 out_limits;
2724
ivec4 sizes;
28-
};
29-
30-
layout(set = 0, binding = 4) uniform PRECISION restrict Block {
3125
// output dims
3226
ivec4 out_ndims;
3327
// x = output channels aligned to 4, y = input channels aligned to 4
@@ -41,7 +35,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4135
void main() {
4236
const u16vec3 pos = u16vec3(gl_GlobalInvocationID);
4337

44-
if (any(greaterThanEqual(pos, out_limits))) {
38+
if (any(greaterThanEqual(pos, out_limits.xyz))) {
4539
return;
4640
}
4741

backends/vulkan/runtime/graph/ops/impl/Permute.cpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,7 @@ void add_permute_node(
7575
int32_t out_c_aligned = utils::align_up_4(out_channels);
7676
int32_t in_c_aligned = utils::align_up_4(in_channels);
7777

78-
const struct Block final {
79-
ivec4 out_ndims;
80-
ivec2 ch_info;
81-
} params{
82-
out_dims,
83-
{out_c_aligned, in_c_aligned},
84-
};
78+
const ivec2 ch_info = {out_c_aligned, in_c_aligned};
8579

8680
graph.execute_nodes().emplace_back(new DispatchNode(
8781
graph,
@@ -90,14 +84,16 @@ void add_permute_node(
9084
graph.create_local_wg_size(out),
9185
{{out, vkapi::MemoryAccessType::WRITE},
9286
{in, vkapi::MemoryAccessType::READ}},
93-
{t_out->logical_limits_ubo(),
94-
t_out->sizes_ubo(),
95-
graph.create_params_buffer(params)},
87+
{},
9688
// Specialization Constants
9789
{},
9890
// Resizing Logic
9991
nullptr,
100-
{}));
92+
{},
93+
{{graph.logical_limits_pc_of(out),
94+
graph.sizes_pc_of(out),
95+
PushConstantDataInfo(&out_dims, sizeof(out_dims)),
96+
PushConstantDataInfo(&ch_info, sizeof(ch_info))}}));
10197
}
10298

10399
void add_permute_node(

0 commit comments

Comments
 (0)