Commit 31241ef

trivedivivek authored and SS-JIA committed
[ET-VK] Removing tile input storage variable in conv_pw op and fetching the data in main loop. Also unrolling the main loop for performance improvement.
This diff removes the tile input storage array in_tex from the conv_pw op and instead fetches each input texel inside the main loop, where it is consumed. The main loop is also unrolled (via #pragma unroll). Both changes are intended to improve performance.

Differential Revision: [D64767314](https://our.internmc.facebook.com/intern/diff/D64767314/)
ghstack-source-id: 252923517
Pull Request resolved: #6765
1 parent 1166669 commit 31241ef
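
The restructuring described in the commit message can be modelled on the CPU. The following is a minimal C++ sketch, not the shader itself: `fetch_input` is a hypothetical stand-in for `texelFetch(t_in, ...)`, `accumulate` is a placeholder for the per-texel math, and `TILE_SIZE = 2` is an assumption taken from the "2x2 tile size" comment in the shader. It only illustrates that staging all texels in an array first and fetching each texel inside the loop should produce the same sums.

```cpp
#include <array>
#include <cassert>

constexpr int TILE_SIZE = 2;  // assumed; the shader comment mentions a 2x2 tile
constexpr int N = TILE_SIZE * TILE_SIZE;
using Vec4 = std::array<float, 4>;

// Hypothetical stand-in for texelFetch(t_in, ...): a deterministic fake texel.
Vec4 fetch_input(int pos, int z4) {
  const float base = static_cast<float>(pos * 10 + z4);
  return {base, base + 1.0f, base + 2.0f, base + 3.0f};
}

// Placeholder for the per-texel work done in the loop body.
void accumulate(Vec4& sum, const Vec4& in_tex) {
  for (int c = 0; c < 4; ++c) sum[c] += in_tex[c];
}

// Old structure: fetch every texel into an array, then consume the array.
std::array<Vec4, N> staged(int z4) {
  std::array<Vec4, N> sum{};  // zero-initialized accumulators
  Vec4 in_tex[N];
  for (int i = 0; i < N; ++i) in_tex[i] = fetch_input(i, z4);
  for (int i = 0; i < N; ++i) accumulate(sum[i], in_tex[i]);
  return sum;
}

// New structure: fetch each texel in the loop that consumes it.
std::array<Vec4, N> fused(int z4) {
  std::array<Vec4, N> sum{};  // zero-initialized accumulators
  for (int i = 0; i < N; ++i) {
    const Vec4 in_tex = fetch_input(i, z4);
    accumulate(sum[i], in_tex);
  }
  return sum;
}
```

Calling `staged(z4)` and `fused(z4)` with the same `z4` should return equal arrays. The sketch says nothing about GPU performance, which depends on the driver and compiler.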

1 file changed: +6 -8 lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 6 additions & 8 deletions
@@ -82,17 +82,15 @@ void main() {
     // During prepacking, the weight tensor has been permuted so that the
     // channel (IC) dim is along the x-axis, and the batch (OC) dim is along
     // the z-axis.
-    vec4 in_tex[TILE_SIZE * TILE_SIZE];
     const vec4 ktex_0 = texelFetch(t_kernel, u16vec2(z + 0, gpos.z), 0);
     const vec4 ktex_1 = texelFetch(t_kernel, u16vec2(z + 1, gpos.z), 0);
     const vec4 ktex_2 = texelFetch(t_kernel, u16vec2(z + 2, gpos.z), 0);
     const vec4 ktex_3 = texelFetch(t_kernel, u16vec2(z + 3, gpos.z), 0);

-    for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
-      in_tex[i] = texelFetch(t_in, u16vec3(ipos[i], z4), 0);
-    }

+    #pragma unroll
     for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
+      const vec4 in_tex = texelFetch(t_in, u16vec3(ipos[i], z4), 0);
       // For 2x2 tile size algorithm works as follows.
       // To explain the calculations below, the contents of one in_tex and the
       // group of 4 texels loaded from t_kernel are shown:
@@ -126,10 +124,10 @@ void main() {
       //
       // which is what is expressed in the following calculations. This is done
       // for each output position.
-      sum[i] = fma(in_tex[i].xxxx, ktex_0, sum[i]);
-      sum[i] = fma(in_tex[i].yyyy, ktex_1, sum[i]);
-      sum[i] = fma(in_tex[i].zzzz, ktex_2, sum[i]);
-      sum[i] = fma(in_tex[i].wwww, ktex_3, sum[i]);
+      sum[i] = fma(in_tex.xxxx, ktex_0, sum[i]);
+      sum[i] = fma(in_tex.yyyy, ktex_1, sum[i]);
+      sum[i] = fma(in_tex.zzzz, ktex_2, sum[i]);
+      sum[i] = fma(in_tex.wwww, ktex_3, sum[i]);
     }
   }
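
The four `fma` lines in the loop body can be read as a 4-element vector times a 4x4 matrix: each input component is broadcast (`.xxxx`, `.yyyy`, ...) and scales one kernel texel, and the results are accumulated into `sum[i]`. The C++ sketch below is a CPU-side model of that reading, not code from the repository; the helper names (`fma_broadcast`, `accumulate_shader_style`, `accumulate_reference`) are hypothetical, and the kernel texels are assumed to be passed as the rows `ktex[0]` to `ktex[3]`.

```cpp
#include <array>
#include <cassert>
#include <cmath>

using Vec4 = std::array<float, 4>;

// Models fma(in_tex.cccc, ktex, sum): one input component s is broadcast
// across all four lanes, multiplied by the kernel texel, and added to acc.
Vec4 fma_broadcast(float s, const Vec4& k, const Vec4& acc) {
  Vec4 out{};
  for (int j = 0; j < 4; ++j) out[j] = std::fma(s, k[j], acc[j]);
  return out;
}

// Mirrors the four fma lines: components x, y, z, w pair with ktex_0..ktex_3.
Vec4 accumulate_shader_style(const Vec4& in_tex,
                             const std::array<Vec4, 4>& ktex,
                             Vec4 sum) {
  for (int c = 0; c < 4; ++c) sum = fma_broadcast(in_tex[c], ktex[c], sum);
  return sum;
}

// Reference form: out[j] = sum[j] + sum over c of in_tex[c] * ktex[c][j].
Vec4 accumulate_reference(const Vec4& in_tex,
                          const std::array<Vec4, 4>& ktex,
                          const Vec4& sum) {
  Vec4 out = sum;
  for (int j = 0; j < 4; ++j) {
    for (int c = 0; c < 4; ++c) out[j] += in_tex[c] * ktex[c][j];
  }
  return out;
}
```

For inputs whose products and sums are exactly representable in `float` (small integers, for example), both functions should return identical results. In general the fused form may differ in the last bits, because `std::fma` rounds once per step.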
