From c7f3dfa61e0dccda05348b4f0e65ecbe63675e24 Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Thu, 29 May 2025 16:33:46 -0700 Subject: [PATCH] [ET-VK] De-vectorizing sum and moving bias application to the end in conv 2d op to improve performance. This diff optimizes the conv 2d op in the Vulkan runtime by de-vectorizing the sum (accumulating into four scalar components instead of a single vec4) and moving the bias application to the end (the bias is added once after the accumulation loops instead of seeding the accumulator). Differential Revision: [D75551846](https://our.internmc.facebook.com/intern/diff/D75551846/) [ghstack-poisoned] --- .../vulkan/runtime/graph/ops/glsl/conv2d.glsl | 43 +++++++++++++++---- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl index c0ed9204227..5edccc47031 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl @@ -12,6 +12,8 @@ #define VEC4_T ${texel_type(DTYPE)} +#define T ${texel_component_type(DTYPE)} + #define op(X, A, B) ${OPERATOR} #include "indexing_utils.h" @@ -72,13 +74,24 @@ void main() { kstart.y += pos.z * kernel_size.y; // Perform the convolution by iterating over the overlay region. 
- VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0); + T sum[4]; + sum[0] = T(0); + sum[1] = T(0); + sum[2] = T(0); + sum[3] = T(0); + const int ic4 = in_group_size / 4; for (int z4 = 0; z4 < ic4; ++z4, kstart.x += kernel_size.x * 4) { for (int y = start.y, ky = kstart.y; y < end.y; y += dilation.y, ++ky) { for (int x = start.x, kx = kstart.x; x < end.x; x += dilation.x, kx += 4) { - const VEC4_T in_texel = texelFetch(t_in, ivec3(x, y, z4), 0); - const ivec4 kxs = kx + ivec4(0, 1, 2, 3); + const VEC4_T in_texel_v = texelFetch(t_in, ivec3(x, y, z4), 0); + + T in_texel[4]; + in_texel[0] = in_texel_v.x; + in_texel[1] = in_texel_v.y; + in_texel[2] = in_texel_v.z; + in_texel[3] = in_texel_v.w; + // To explain the calculation below, the contents of in_texel and the // group of 4 texels loaded from t_kernel are shown: @@ -112,13 +125,27 @@ void main() { // // which is expressed in the following statements. - sum = fma(in_texel.xxxx, texelFetch(t_kernel, ivec2(kxs.x, ky), 0), sum); - sum = fma(in_texel.yyyy, texelFetch(t_kernel, ivec2(kxs.y, ky), 0), sum); - sum = fma(in_texel.zzzz, texelFetch(t_kernel, ivec2(kxs.z, ky), 0), sum); - sum = fma(in_texel.wwww, texelFetch(t_kernel, ivec2(kxs.w, ky), 0), sum); + T k_tex_arr[16]; + for (int kc = 0; kc < 4; kc++) { + const VEC4_T k_tex = texelFetch(t_kernel, ivec2(kx + kc, ky), 0); + k_tex_arr[kc * 4 + 0] = k_tex.x; + k_tex_arr[kc * 4 + 1] = k_tex.y; + k_tex_arr[kc * 4 + 2] = k_tex.z; + k_tex_arr[kc * 4 + 3] = k_tex.w; + } + + for (int sc = 0; sc < 4; sc++) { + sum[0] = fma(in_texel[sc], k_tex_arr[sc * 4 + 0], sum[0]); + sum[1] = fma(in_texel[sc], k_tex_arr[sc * 4 + 1], sum[1]); + sum[2] = fma(in_texel[sc], k_tex_arr[sc * 4 + 2], sum[2]); + sum[3] = fma(in_texel[sc], k_tex_arr[sc * 4 + 3], sum[3]); + } } } } - imageStore(t_out, pos, op(sum, out_min, out_max)); + const VEC4_T bias = texelFetch(t_bias, ivec2(pos.z, 0), 0); + const VEC4_T out_sum = VEC4_T(sum[0], sum[1], sum[2], sum[3]) + bias; + + imageStore(t_out, pos, 
op(out_sum, out_min, out_max)); }