Skip to content

Commit 2ceb36c

Browse files
committed
[ET-VK] Using shared memory to save position in conv2d dw output op.
Pull Request resolved: #7818 ghstack-source-id: 262139744 @exported-using-ghexport Differential Revision: [D68400890](https://our.internmc.facebook.com/intern/diff/D68400890/)
1 parent f93a361 commit 2ceb36c

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
3838

3939
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4040

41+
// Shared memory that holds the calculated positions; this is intended to reduce register usage and thereby improve performance.
42+
// 64 is the number of invocations (threads) in the local work group; pos_shared is indexed by gl_LocalInvocationIndex.
43+
shared ivec3 pos_shared[64];
44+
4145
/*
4246
* Computes a depthwise convolution. Each shader invocation calculates the
4347
* output at a single output location.
@@ -63,6 +67,8 @@ void main() {
6367
return;
6468
}
6569

70+
pos_shared[gl_LocalInvocationIndex] = pos;
71+
6672
// Compute the index of the top-left element of the overlay region. Negative
6773
// indices indicate that the top-left element is in a region added by padding.
6874
const ivec2 ipos = pos.xy * stride - padding;
@@ -109,18 +115,19 @@ void main() {
109115
for (int j = 0; j < TILE_SIZE; j++, kx++) {
110116
prev_kernel_line[j] = texelFetch(t_kernel, ivec2(kx, pos.z), 0);
111117
for (int s = 0; s < BATCH_SIZE_X; s++) {
112-
sum[0][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[0][s]);
118+
sum[0][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[0][s]);
113119
}
114120
}
115121
}
116122
}
117123

124+
const ivec3 out_pos = pos_shared[gl_LocalInvocationIndex];
118125
for (int y = 0; y < BATCH_SIZE_Y; y++) {
119126
for (int x = 0; x < BATCH_SIZE_X; x++) {
120-
if (any(greaterThanEqual(ivec3(pos.x + x, pos.y + y, pos.z), out_limits))) {
127+
if (any(greaterThanEqual(ivec3(out_pos.x + x, out_pos.y + y, out_pos.z), out_limits))) {
121128
continue;
122129
}
123-
imageStore(t_out, ivec3(pos.x + x, pos.y + y, pos.z), op(sum[y][x], out_min, out_max));
130+
imageStore(t_out, ivec3(out_pos.x + x, out_pos.y + y, out_pos.z), op(sum[y][x], out_min, out_max));
124131
}
125132
}
126133
}

0 commit comments

Comments
 (0)