From 9c199808987df42afdca986ae7d6903fc339ec0c Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Fri, 23 May 2025 18:09:34 -0700 Subject: [PATCH] [ET-VK] Fully De vectorise conv2d pw shader to improve perf. This improves the performance of the conv2d pw shader by fully de-vectorizing it. The optimization involved replacing the `ivec3 pos` array with a plain `int pos` array to store the position values. The `x` and `y` coordinates are now stored in separate elements of the array instead of being stored together in an `ivec3`. This change allows for more efficient memory access and computation. Differential Revision: [D75335802](https://our.internmc.facebook.com/intern/diff/D75335802/) [ghstack-poisoned] --- .../runtime/graph/ops/glsl/conv2d_pw.glsl | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl index e44a41fc9bc..ed07979afc0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl @@ -64,10 +64,11 @@ void main() { // +--------+--------+ // | pos[2] | pos[3] | // +--------+--------+ - ivec3 pos[TILE_SIZE_X * TILE_SIZE_Y]; + int pos[TILE_SIZE_X * TILE_SIZE_Y * 2]; for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) { for (int x = 0; x < TILE_SIZE_X; ++x) { - pos[i] = ivec3(gpos.x * TILE_SIZE_X + x, gpos.y * TILE_SIZE_Y + y, gpos.z); + pos[i * 2] = gpos.x * TILE_SIZE_X + x; + pos[i * 2 + 1] = gpos.y * TILE_SIZE_Y + y; i++; } } @@ -75,9 +76,10 @@ void main() { // Compute the index of the input texture that needs to be loaded for each // output position. Note that negative indices can be produced indicating that // the top-left element is in a region added by padding. - ivec2 ipos[TILE_SIZE_X * TILE_SIZE_Y]; + int ipos[TILE_SIZE_X * TILE_SIZE_Y * 2]; for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) { - ipos[i] = pos[i].xy * stride - padding; + ipos[i * 2] = pos[i * 2] * stride.x - padding.x; + ipos[i * 2 + 1] = pos[i * 2 + 1] * stride.y - padding.y; } // Final output array where each element is a tensor value. @@ -112,7 +114,7 @@ void main() { } for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) { - const vec4 in_tex = texelFetch(t_in, ivec3(ipos[i], z4), 0); + const vec4 in_tex = texelFetch(t_in, ivec3(ipos[i * 2], ipos[i * 2 + 1], z4), 0); // Load the input texel into an array float tex_values[4]; tex_values[0] = in_tex.x; @@ -163,8 +165,9 @@ void main() { } for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) { - if (all(lessThan(pos[i], out_limits.xyz))) { - imageStore(t_out, pos[i], op(vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]), out_min, out_max)); + const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], gpos.z); + if (all(lessThan(pos_l, out_limits.xyz))) { + imageStore(t_out, pos_l, op(vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]), out_min, out_max)); } } }