Skip to content

Commit f6a243f

Browse files
committed
[ET-VK] Using shared memory offsetting in conv2d pw and saving ivec3 pos instead of ivec2 to improve performance.
Pull Request resolved: #7817 This diff changes conv2d pw op shader to offset shared memory based on thread local index to improve performance. Change also saves pos as ivec3 pos instead of ivec2. ghstack-source-id: 262858897 @exported-using-ghexport Differential Revision: [D68400786](https://our.internmc.facebook.com/intern/diff/D68400786/)
1 parent 7a26f6b commit f6a243f

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#define TILE_SIZE_X ${TILE_SIZE_X}
1616
#define TILE_SIZE_Y ${TILE_SIZE_Y}
17+
#define LOCAL_WG_SIZE 64
1718

1819
#define op(X, A, B) ${OPERATOR}
1920

@@ -42,10 +43,12 @@ layout(push_constant) uniform restrict Block {
4243

4344
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4445

46+
// macro to offset shared memory access index. Padding position index by 1 offset per 16 positions avoidd bank access conflict and thus improves performance.
47+
#define offset_pos_index(index) (index + ((index) >> 4))
48+
4549
// shared memory to hold calculated positions, this would reduce register usage thus improving performance.
4650
// 64 is the number of threads in the local wg
47-
$num_shared = 64 * TILE_SIZE_X * TILE_SIZE_Y
48-
shared ivec2 pos_shared[${num_shared}];
51+
shared ivec3 pos_shared[offset_pos_index(LOCAL_WG_SIZE * TILE_SIZE_X * TILE_SIZE_Y)];
4952

5053
/*
5154
* Computes a 2D pointwise convolution of an NxN output tile. Calculating an
@@ -54,7 +57,7 @@ shared ivec2 pos_shared[${num_shared}];
5457
*/
5558
void main() {
5659
const ivec2 out_limits_scaled = (out_limits.xy + ivec2(TILE_SIZE_X - 1, TILE_SIZE_Y - 1)) / ivec2(TILE_SIZE_X, TILE_SIZE_Y);
57-
const uint shared_mem_stride = 64;
60+
const uint shared_mem_stride = LOCAL_WG_SIZE;
5861

5962
const uint div_by_x = gl_GlobalInvocationID.x / out_limits_scaled.x;
6063
const ivec3 gpos = ivec3(
@@ -72,7 +75,7 @@ void main() {
7275
for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) {
7376
for (int x = 0; x < TILE_SIZE_X; ++x) {
7477
pos[i] = ivec2(gpos.x * TILE_SIZE_X + x, gpos.y * TILE_SIZE_Y + y);
75-
pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i];
78+
pos_shared[offset_pos_index((shared_mem_stride * i) + gl_LocalInvocationIndex)] = ivec3(pos[i], gpos.z);
7679
i++;
7780
}
7881
}
@@ -152,9 +155,10 @@ void main() {
152155
}
153156

154157
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
155-
const ivec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
156-
if (all(lessThan(ivec3(pos, gpos.z), out_limits.xyz))) {
157-
imageStore(t_out, ivec3(pos, gpos.z), op(sum[i], out_min, out_max));
158+
const uint index = (shared_mem_stride * i) + gl_LocalInvocationIndex;
159+
const ivec3 pos = pos_shared[offset_pos_index(index)];
160+
if (all(lessThan(pos, out_limits.xyz))) {
161+
imageStore(t_out, pos, op(sum[i], out_min, out_max));
158162
}
159163
}
160164
}

0 commit comments

Comments
 (0)