@@ -60,9 +60,9 @@ const lowp int out_packed_dim = unhash_packed_dim(out_layout);
6060// First iteration of reduce will have 32 threads sum up 64 elements.
6161// Second iteration will have 32 threads sum up 16 elements from previous iteration and so on.
6262// Thus thread utilization starts at 100%.
63- #define SHARED_MEMORY_FACTOR 2
63+ #define SHARED_MEMORY_FACTOR 1
6464
65- #define offset_pos_index(index) ((index) + ((index) >> 2 ))
65+ #define offset_pos_index(index) ((index) + ((index) >> 3 ))
6666
6767shared VEC4_T shared_input[offset_pos_index(MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR)];
6868
@@ -154,14 +154,13 @@ void reduce_non_packed_dim() {
154154 if (all (lessThan (in_pos, out_limits))) {
155155 in_val = load_texel(t_in, in_pos);
156156 }
157- shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
157+ mean + = in_val;
158158 }
159-
160- reduce_input(width_stride, shared_idx_offset);
161- mean += shared_input[offset_pos_index(shared_idx_offset)];
162159 }
163160
164- mean /= width;
161+ shared_input[offset_pos_index(shared_idx)] = mean;
162+ reduce_input(width_stride, shared_idx_offset);
163+ mean = shared_input[offset_pos_index(shared_idx_offset)] / width;
165164
166165 memoryBarrierShared();
167166 barrier();
@@ -178,14 +177,13 @@ void reduce_non_packed_dim() {
178177 }
179178
180179 const VEC4_T delta = in_val - mean;
181- shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta;
180+ var + = delta * delta;
182181 }
183-
184- reduce_input(width_stride, shared_idx_offset);
185- var += shared_input[offset_pos_index(shared_idx_offset)];
186182 }
187183
188- var /= width;
184+ shared_input[offset_pos_index(shared_idx)] = var;
185+ reduce_input(width_stride, shared_idx_offset);
186+ var = shared_input[offset_pos_index(shared_idx_offset)] / width;
189187
190188 VEC4_T rstd = pow (var + epsilon, VEC4_T(- 0.5 ));
191189 VEC4_T offset = - rstd * mean;
@@ -226,6 +224,7 @@ void reduce_packed_dim() {
226224
227225 const int in_pos_x_limit = out_limits[in_axis_map.x];
228226
227+ VEC4_T accum = VEC4_T(0 );
229228 // Loop over the width in stride increments
230229 for (int width_offset = 0 ; width_offset <= last_packed_width_index; width_offset += width_stride) {
231230 // Read input in shared memory
@@ -244,20 +243,20 @@ void reduce_packed_dim() {
244243 in_val.z = mix (in_val.z, T(0 ), remain_inv > 1 );
245244 in_val.w = mix (in_val.w, T(0 ), remain_inv > 0 );
246245 }
247-
248- shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
246+ accum += in_val;
249247 }
250-
251- reduce_input(width_stride, shared_idx_offset);
252- const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
253- mean += val.x + val.y + val.z + val.w;
254248 }
255249
256- mean /= width;
250+ shared_input[offset_pos_index(shared_idx)] = accum;
251+ reduce_input(width_stride, shared_idx_offset);
252+ VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
253+ mean = (val.x + val.y + val.z + val.w) / width;
257254
258255 memoryBarrierShared();
259256 barrier();
260257
258+ VEC4_T delta2 = VEC4_T(0 );
259+
261260 // Loop over the width in stride increments
262261 for (int width_offset = 0 ; width_offset <= last_packed_width_index; width_offset += width_stride) {
263262 // Read input in shared memory
@@ -278,16 +277,14 @@ void reduce_packed_dim() {
278277 }
279278
280279 const VEC4_T delta = in_val - mean;
281- const VEC4_T delta2 = delta * delta;
282- shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2;
280+ delta2 += delta * delta;
283281 }
284-
285- reduce_input(width_stride, shared_idx_offset);
286- const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
287- var += val.x + val.y + val.z + val.w;
288282 }
289283
290- var /= width;
284+ shared_input[offset_pos_index(shared_idx)] = delta2;
285+ reduce_input(width_stride, shared_idx_offset);
286+ val = shared_input[offset_pos_index(shared_idx_offset)];
287+ var = (val.x + val.y + val.z + val.w) / width;
291288
292289 T rstd = pow (var + epsilon, T(- 0.5 ));
293290 T offset = - rstd * mean;
0 commit comments