Skip to content

Commit f93a361

Browse files
committed
[ET-VK] Using shared memory offsetting in conv2d pw and saving ivec3 pos instead of ivec2 to improve performance.
Pull Request resolved: pytorch/executorch#7817 This diff changes the conv2d pw op shader to offset its shared memory indices based on the thread's local index to improve performance. The change also saves pos as an ivec3 instead of an ivec2. ghstack-source-id: 262139413 @exported-using-ghexport Differential Revision: [D68400786](https://our.internmc.facebook.com/intern/diff/D68400786/)
1 parent 91519f9 commit f93a361

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#define TILE_SIZE_X ${TILE_SIZE_X}
1616
#define TILE_SIZE_Y ${TILE_SIZE_Y}
17+
#define LOCAL_WG_SIZE 64
1718

1819
#define op(X, A, B) ${OPERATOR}
1920

@@ -42,10 +43,11 @@ layout(push_constant) uniform restrict Block {
4243

4344
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4445

46+
// Maps a linear index to a padded index by adding (index >> 4), i.e. one extra
// slot is skipped for every 16 entries. Every use of the argument is
// parenthesized so expressions with low-precedence operators expand correctly.
#define resolve_pos_index(offset) ((offset) + ((offset) >> 4))
47+
4548
// Shared memory to hold calculated positions; this reduces register usage, thus improving performance.
4649
// 64 is the number of threads in the local wg
47-
$num_shared = 64 * TILE_SIZE_X * TILE_SIZE_Y
48-
shared ivec2 pos_shared[${num_shared}];
50+
shared ivec3 pos_shared[resolve_pos_index(LOCAL_WG_SIZE * TILE_SIZE_X * TILE_SIZE_Y)];
4951

5052
/*
5153
* Computes a 2D pointwise convolution of an NxN output tile. Calculating an
@@ -54,7 +56,7 @@ shared ivec2 pos_shared[${num_shared}];
5456
*/
5557
void main() {
5658
const ivec2 out_limits_scaled = (out_limits.xy + ivec2(TILE_SIZE_X - 1, TILE_SIZE_Y - 1)) / ivec2(TILE_SIZE_X, TILE_SIZE_Y);
57-
const uint shared_mem_stride = 64;
59+
const uint shared_mem_stride = LOCAL_WG_SIZE;
5860

5961
const uint div_by_x = gl_GlobalInvocationID.x / out_limits_scaled.x;
6062
const ivec3 gpos = ivec3(
@@ -72,7 +74,7 @@ void main() {
7274
for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) {
7375
for (int x = 0; x < TILE_SIZE_X; ++x) {
7476
pos[i] = ivec2(gpos.x * TILE_SIZE_X + x, gpos.y * TILE_SIZE_Y + y);
75-
pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i];
77+
pos_shared[resolve_pos_index((shared_mem_stride * i) + gl_LocalInvocationIndex)] = ivec3(pos[i], gpos.z);
7678
i++;
7779
}
7880
}
@@ -152,9 +154,10 @@ void main() {
152154
}
153155

154156
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
155-
const ivec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
156-
if (all(lessThan(ivec3(pos, gpos.z), out_limits.xyz))) {
157-
imageStore(t_out, ivec3(pos, gpos.z), op(sum[i], out_min, out_max));
157+
const uint index = (shared_mem_stride * i) + gl_LocalInvocationIndex;
158+
const ivec3 pos = pos_shared[resolve_pos_index(index)];
159+
if (all(lessThan(pos, out_limits.xyz))) {
160+
imageStore(t_out, pos, op(sum[i], out_min, out_max));
158161
}
159162
}
160163
}

0 commit comments

Comments
 (0)