Skip to content

Commit 352da28

Browse files
committed
[ET-VK] De vectorizing sum and moving bias application to the end in conv 2d op to improve performance.
This diff optimizes the conv 2d op in the Vulkan runtime by de-vectorizing the sum and moving the bias application to the end. Differential Revision: [D75551846](https://our.internmc.facebook.com/intern/diff/D75551846/) [ghstack-poisoned]
1 parent b2c1bc4 commit 352da28

File tree

1 file changed

+35
-8
lines changed

1 file changed

+35
-8
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
#define VEC4_T ${texel_type(DTYPE)}
1414

15+
#define T ${texel_component_type(DTYPE)}
16+
1517
#define op(X, A, B) ${OPERATOR}
1618

1719
#include "indexing_utils.h"
@@ -72,13 +74,24 @@ void main() {
7274
kstart.y += pos.z * kernel_size.y;
7375

7476
// Perform the convolution by iterating over the overlay region.
75-
VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
77+
T sum[4];
78+
sum[0] = T(0);
79+
sum[1] = T(0);
80+
sum[2] = T(0);
81+
sum[3] = T(0);
82+
7683
const int ic4 = in_group_size / 4;
7784
for (int z4 = 0; z4 < ic4; ++z4, kstart.x += kernel_size.x * 4) {
7885
for (int y = start.y, ky = kstart.y; y < end.y; y += dilation.y, ++ky) {
7986
for (int x = start.x, kx = kstart.x; x < end.x; x += dilation.x, kx += 4) {
80-
const VEC4_T in_texel = texelFetch(t_in, ivec3(x, y, z4), 0);
81-
const ivec4 kxs = kx + ivec4(0, 1, 2, 3);
87+
const VEC4_T in_texel_v = texelFetch(t_in, ivec3(x, y, z4), 0);
88+
89+
T in_texel[4];
90+
in_texel[0] = in_texel_v.x;
91+
in_texel[1] = in_texel_v.y;
92+
in_texel[2] = in_texel_v.z;
93+
in_texel[3] = in_texel_v.w;
94+
8295

8396
// To explain the calculation below, the contents of in_texel and the
8497
// group of 4 texels loaded from t_kernel are shown:
@@ -112,13 +125,27 @@ void main() {
112125
//
113126
// which is expressed in the following statements.
114127

115-
sum = fma(in_texel.xxxx, texelFetch(t_kernel, ivec2(kxs.x, ky), 0), sum);
116-
sum = fma(in_texel.yyyy, texelFetch(t_kernel, ivec2(kxs.y, ky), 0), sum);
117-
sum = fma(in_texel.zzzz, texelFetch(t_kernel, ivec2(kxs.z, ky), 0), sum);
118-
sum = fma(in_texel.wwww, texelFetch(t_kernel, ivec2(kxs.w, ky), 0), sum);
128+
T k_tex_arr[16];
129+
for (int kc = 0; kc < 4; kc++) {
130+
const VEC4_T k_tex = texelFetch(t_kernel, ivec2(kx + kc, ky), 0);
131+
k_tex_arr[kc * 4 + 0] = k_tex.x;
132+
k_tex_arr[kc * 4 + 1] = k_tex.y;
133+
k_tex_arr[kc * 4 + 2] = k_tex.z;
134+
k_tex_arr[kc * 4 + 3] = k_tex.w;
135+
}
136+
137+
for (int sc = 0; sc < 4; sc++) {
138+
sum[0] = fma(in_texel[sc], k_tex_arr[sc * 4 + 0], sum[0]);
139+
sum[1] = fma(in_texel[sc], k_tex_arr[sc * 4 + 1], sum[1]);
140+
sum[2] = fma(in_texel[sc], k_tex_arr[sc * 4 + 2], sum[2]);
141+
sum[3] = fma(in_texel[sc], k_tex_arr[sc * 4 + 3], sum[3]);
142+
}
119143
}
120144
}
121145
}
122146

123-
imageStore(t_out, pos, op(sum, out_min, out_max));
147+
const VEC4_T bias = texelFetch(t_bias, ivec2(pos.z, 0), 0);
148+
const VEC4_T out_sum = VEC4_T(sum[0], sum[1], sum[2], sum[3]) + bias;
149+
150+
imageStore(t_out, pos, op(out_sum, out_min, out_max));
124151
}

0 commit comments

Comments
 (0)