@@ -88,10 +88,18 @@ void main() {
8888 ipos[i] = pos[i] * stride - padding;
8989 }
9090
91- vec4 sum[TILE_SIZE_X * TILE_SIZE_Y];
92- sum[0 ] = texelFetch(t_bias, ivec2 (gpos.z, 0 ), 0 );
93- for (int i = 1 ; i < TILE_SIZE_X * TILE_SIZE_Y; ++ i) {
94- sum[i] = sum[0 ];
91+ // Final output array where each element is a tensor value.
92+ // Tuple of consecutive 4 elements represents a single output texel.
93+ float sum[TILE_SIZE_X * TILE_SIZE_Y * 4 ];
94+
95+ const vec4 bias = texelFetch(t_bias, ivec2 (gpos.z, 0 ), 0 );
96+
97+ // Initialize the output array with the bias value
98+ for (int i = 0 ; i < TILE_SIZE_X * TILE_SIZE_Y * 4 ; i += 4 ) {
99+ sum[i] = bias.x;
100+ sum[i + 1 ] = bias.y;
101+ sum[i + 2 ] = bias.z;
102+ sum[i + 3 ] = bias.w;
95103 }
96104
97105 int z4 = 0 ;
@@ -100,14 +108,26 @@ void main() {
100108 // During prepacking, the weight tensor has been permuted so that the
101109 // channel (IC) dim is along the x-axis, and the batch (OC) dim is along
102110 // the z-axis.
103- const vec4 ktex_0 = texelFetchOffset(t_kernel, ivec2 (z, gpos.z), 0 , ivec2 (0 , 0 ));
104- const vec4 ktex_1 = texelFetchOffset(t_kernel, ivec2 (z, gpos.z), 0 , ivec2 (1 , 0 ));
105- const vec4 ktex_2 = texelFetchOffset(t_kernel, ivec2 (z, gpos.z), 0 , ivec2 (2 , 0 ));
106- const vec4 ktex_3 = texelFetchOffset(t_kernel, ivec2 (z, gpos.z), 0 , ivec2 (3 , 0 ));
111+ float kernel_values[4 * 4 ]; // 4 channels, 4 elements per channel
112+
113+ // Load kernel values from texels to array
114+ for (int i = 0 ; i < 4 ; ++ i) {
115+ const vec4 k_tex = texelFetch(t_kernel, ivec2 (z + i, gpos.z), 0 );
116+ kernel_values[i * 4 + 0 ] = k_tex.x;
117+ kernel_values[i * 4 + 1 ] = k_tex.y;
118+ kernel_values[i * 4 + 2 ] = k_tex.z;
119+ kernel_values[i * 4 + 3 ] = k_tex.w;
120+ }
107121
108- #pragma unroll
109122 for (int i = 0 ; i < TILE_SIZE_X * TILE_SIZE_Y; ++ i) {
110123 const vec4 in_tex = texelFetch(t_in, ivec3 (ipos[i], z4), 0 );
124+ // Load the input texel into an array
125+ float tex_values[4 ];
126+ tex_values[0 ] = in_tex.x;
127+ tex_values[1 ] = in_tex.y;
128+ tex_values[2 ] = in_tex.z;
129+ tex_values[3 ] = in_tex.w;
130+
111131 // For 2x2 tile size algorithm works as follows.
112132 // To explain the calculations below, the contents of one in_tex and the
113133 // group of 4 texels loaded from t_kernel are shown:
@@ -141,18 +161,20 @@ void main() {
141161 //
142162 // which is what is expressed in the following calculations. This is done
143163 // for each output position.
144- sum[i] = fma(in_tex.xxxx, ktex_0, sum[i]);
145- sum[i] = fma(in_tex.yyyy, ktex_1, sum[i]);
146- sum[i] = fma(in_tex.zzzz, ktex_2, sum[i]);
147- sum[i] = fma(in_tex.wwww, ktex_3, sum[i]);
164+ for (int j = 0 ; j < 4 ; ++ j) {
165+ sum[i * 4 + j] = tex_values[0 ] * kernel_values[0 + j] + sum[i * 4 + j];
166+ sum[i * 4 + j] = tex_values[1 ] * kernel_values[4 + j] + sum[i * 4 + j];
167+ sum[i * 4 + j] = tex_values[2 ] * kernel_values[8 + j] + sum[i * 4 + j];
168+ sum[i * 4 + j] = tex_values[3 ] * kernel_values[12 + j] + sum[i * 4 + j];
169+ }
148170 }
149171 }
150172
151173 for (int i = 0 ; i < TILE_SIZE_X * TILE_SIZE_Y; ++ i) {
152174 const uint index = (shared_mem_stride * i) + gl_LocalInvocationIndex;
153175 const ivec3 pos = pos_shared[offset_pos_index(index)];
154176 if (all (lessThan (pos, out_limits.xyz))) {
155- imageStore(t_out, pos, op(sum[i] , out_min, out_max));
177+ imageStore(t_out, pos, op(vec4 ( sum[i * 4 ], sum[i * 4 + 1 ], sum[i * 4 + 2 ], sum[i * 4 + 3 ]) , out_min, out_max));
156178 }
157179 }
158180}
0 commit comments