Skip to content

Commit 16d44b6

Browse files
authored
[ET-VK] Using shared memory offsetting in conv2d pw and saving ivec3 pos instead of ivec2 to improve performance.
Differential Revision: D68400786 Pull Request resolved: #7817
1 parent 130870a commit 16d44b6

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#define TILE_SIZE_X ${TILE_SIZE_X}
1616
#define TILE_SIZE_Y ${TILE_SIZE_Y}
17+
#define LOCAL_WG_SIZE 64
1718

1819
#define op(X, A, B) ${OPERATOR}
1920

@@ -42,10 +43,10 @@ layout(push_constant) uniform restrict Block {
4243

4344
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4445

45-
// shared memory to hold calculated positions, this would reduce register usage thus improving performance.
46-
// 64 is the number of threads in the local wg
47-
$num_shared = 64 * TILE_SIZE_X * TILE_SIZE_Y
48-
shared ivec2 pos_shared[${num_shared}];
46+
// For performance improvement, reduce register usage by caching positions in shared memory.
47+
// Offset index by 1 every 16 points to avoid bank access conflict.
48+
#define offset_pos_index(index) (index + ((index) >> 4))
49+
shared ivec3 pos_shared[offset_pos_index(LOCAL_WG_SIZE * TILE_SIZE_X * TILE_SIZE_Y)];
4950

5051
/*
5152
* Computes a 2D pointwise convolution of an NxN output tile. Calculating an
@@ -54,7 +55,7 @@ shared ivec2 pos_shared[${num_shared}];
5455
*/
5556
void main() {
5657
const ivec2 out_limits_scaled = (out_limits.xy + ivec2(TILE_SIZE_X - 1, TILE_SIZE_Y - 1)) / ivec2(TILE_SIZE_X, TILE_SIZE_Y);
57-
const uint shared_mem_stride = 64;
58+
const uint shared_mem_stride = LOCAL_WG_SIZE;
5859

5960
const uint div_by_x = gl_GlobalInvocationID.x / out_limits_scaled.x;
6061
const ivec3 gpos = ivec3(
@@ -72,7 +73,7 @@ void main() {
7273
for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) {
7374
for (int x = 0; x < TILE_SIZE_X; ++x) {
7475
pos[i] = ivec2(gpos.x * TILE_SIZE_X + x, gpos.y * TILE_SIZE_Y + y);
75-
pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i];
76+
pos_shared[offset_pos_index((shared_mem_stride * i) + gl_LocalInvocationIndex)] = ivec3(pos[i], gpos.z);
7677
i++;
7778
}
7879
}
@@ -152,9 +153,10 @@ void main() {
152153
}
153154

154155
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
155-
const ivec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
156-
if (all(lessThan(ivec3(pos, gpos.z), out_limits.xyz))) {
157-
imageStore(t_out, ivec3(pos, gpos.z), op(sum[i], out_min, out_max));
156+
const uint index = (shared_mem_stride * i) + gl_LocalInvocationIndex;
157+
const ivec3 pos = pos_shared[offset_pos_index(index)];
158+
if (all(lessThan(pos, out_limits.xyz))) {
159+
imageStore(t_out, pos, op(sum[i], out_min, out_max));
158160
}
159161
}
160162
}

0 commit comments

Comments
 (0)