@@ -43,14 +43,71 @@ ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
 const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
 const lowp int out_packed_dim = unhash_packed_dim(out_layout);
 
-#define SHARED_MEMORY_FACTOR 2
 #define MAX_WORKGROUP_SIZE 64
 
+// SHARED_MEMORY_FACTOR scales the shared memory allocation; it should be 1 or a larger power of 2.
+//
+// Increasing the factor allows more data to be staged in shared memory and increases thread utilization during reduction.
+// Why? Because when performing the reduction, the number of active threads halves in each iteration.
+// A larger scaling factor therefore increases thread occupancy and utilizes the GPU better.
+// e.g.
+// If the local thread size in the x dimension is 32 and SHARED_MEMORY_FACTOR is 1, 32 elements are loaded into shared memory.
+// The first iteration of the reduction has 16 threads sum up 32 elements.
+// The second iteration has 8 threads sum up the 16 elements from the previous iteration, and so on.
+// So thread utilization starts at 50%.
+//
+// By contrast, if the local thread size in the x dimension is 32 and SHARED_MEMORY_FACTOR is 2, 64 elements are loaded into shared memory.
+// The first iteration of the reduction has 32 threads sum up 64 elements.
+// The second iteration has 16 threads sum up the 32 elements from the previous iteration, and so on.
+// Thus thread utilization starts at 100%.
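+// Note: a larger factor also uses proportionally more shared memory per workgroup, so the chosen value presumably has to stay within the device's shared memory limits.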
+#define SHARED_MEMORY_FACTOR 2
+
 #define offset_pos_index(index) ((index) + ((index) >> 2))
 
 shared VEC4_T shared_input[offset_pos_index(MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR)];
 
-// function to reduce input data in workgroup's x dimension
+// Function to reduce input data in the workgroup's x dimension
+//
+// The implementation resembles a tree reduction, as depicted below:
+// | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | 2 | 3 | 2 | 7 | 0 | 11 | 0 | 2 | current_stride -> 1
+// | / | / | / | / | / | / | / | /
+// | / | / | / | / | / | / | / | /
+// | / | / | / | / | / | / | / | /
+// | 11 | 1 | 9 | 1 | 2 | 2 | 8 | 5 | 5 | 3 | 9 | 7 | 11 | 11 | 2 | 2 | current_stride -> 2
+// | / | / | / | /
+// | / | / | / | /
+// | / | / | / | /
+// | 20 | 1 | 9 | 1 | 10 | 2 | 8 | 5 | 14 | 3 | 9 | 7 | 13 | 11 | 2 | 2 | current_stride -> 4
+// | / | /
+// | / | /
+// | / | /
+// | / | /
+// | / | /
+// | 30 | 1 | 9 | 1 | 10 | 2 | 8 | 5 | 27 | 3 | 9 | 7 | 13 | 11 | 2 | 2 | current_stride -> 8
+// | /
+// | /
+// | /
+// | /
+// | /
+// | /
+// | /
+// | /
+// | /
+// | 57 | 1 | 9 | 1 | 10 | 2 | 8 | 5 | 27 | 3 | 9 | 7 | 13 | 11 | 2 | 2 | current_stride -> 16
+//
+// Threads access the shared index in the following pattern (shared index = thread index * 2 * current_stride):
+// Thread       | 0 | 1 | 2 | 3 | 4 | 5  | 6  | 7  | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | current_stride -> 1
+// Shared Index | 0 | 2 | 4 | 6 | 8 | 10 | 12 | 14 | X | X | X  | X  | X  | X  | X  | X  | index *= 1
+//
+// Thread       | 0 | 1 | 2 | 3  | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | current_stride -> 2
+// Shared Index | 0 | 4 | 8 | 12 | X | X | X | X | X | X | X  | X  | X  | X  | X  | X  | index *= 2
+//
+// Thread       | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | current_stride -> 4
+// Shared Index | 0 | 8 | X | X | X | X | X | X | X | X | X  | X  | X  | X  | X  | X  | index *= 4
+//
+// Thread       | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | current_stride -> 8
+// Shared Index | 0 | X | X | X | X | X | X | X | X | X | X  | X  | X  | X  | X  | X  | index *= 8
+
 void reduce_input(const int width_stride, const int shared_idx_offset) {
   // wait for all shared memory writes to finish
   memoryBarrierShared();
@@ -70,10 +127,9 @@ void reduce_input(const int width_stride, const int shared_idx_offset) {
   }
 }
 
-void main() {
+void reduce_non_packed_dim() {
   const ivec3 lpos = ivec3(gl_GlobalInvocationID);
   const int width = int(sizes.x);
-
   ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
 
   // width batch read stride
@@ -85,148 +141,177 @@ void main() {
   // local memory index for this thread
   const int shared_idx = shared_idx_offset + int(gl_LocalInvocationID.x);
 
-  // if packed dimension width
-  if (in_packed_dim != W_DIM) {
-    VEC4_T mean = VEC4_T(0);
-    VEC4_T var = VEC4_T(0);
-
-    // Loop over the width in stride increments
-    for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
-      // Read input in shared memory
-      for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
-        in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
-
-        VEC4_T in_val = VEC4_T(0);
-        if (all(lessThan(in_pos, out_limits))) {
-          in_val = load_texel(t_in, in_pos);
-        }
-        shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
-      }
+  VEC4_T mean = VEC4_T(0);
+  VEC4_T var = VEC4_T(0);
 
-      reduce_input(width_stride, shared_idx_offset);
-      mean += shared_input[offset_pos_index(shared_idx_offset)];
+  // Loop over the width in stride increments
+  for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
+    // Read input in shared memory
+    for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+      in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+
+      VEC4_T in_val = VEC4_T(0);
+      if (all(lessThan(in_pos, out_limits))) {
+        in_val = load_texel(t_in, in_pos);
+      }
+      shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
     }
 
-    mean /= width;
+    reduce_input(width_stride, shared_idx_offset);
+    mean += shared_input[offset_pos_index(shared_idx_offset)];
+  }
+
+  mean /= width;
 
-    // Loop over the width in stride increments
-    for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
-      // Read input in shared memory
-      for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
-        in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+  memoryBarrierShared();
+  barrier();
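+  // Presumably this barrier ensures all threads have finished reading the reduced partial sums above before shared_input is overwritten by the variance pass below.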
 
-        VEC4_T in_val = mean;
-        if (all(lessThan(in_pos, out_limits))) {
-          in_val = load_texel(t_in, in_pos);
-        }
+  // Loop over the width in stride increments
+  for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
+    // Read input in shared memory
+    for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+      in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
 
-        const VEC4_T delta = in_val - mean;
-        shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta;
+      VEC4_T in_val = mean;
+      if (all(lessThan(in_pos, out_limits))) {
+        in_val = load_texel(t_in, in_pos);
       }
 
-      reduce_input(width_stride, shared_idx_offset);
-      var += shared_input[offset_pos_index(shared_idx_offset)];
+      const VEC4_T delta = in_val - mean;
+      shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta;
     }
 
-    var /= width;
+    reduce_input(width_stride, shared_idx_offset);
+    var += shared_input[offset_pos_index(shared_idx_offset)];
+  }
 
-    VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
-    VEC4_T offset = -rstd * mean;
+  var /= width;
 
-    VEC4_T v = load_texel(t_in, lpos);
-    VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0)).xxxx;
-    VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0)).xxxx;
-    VEC4_T outtex = (v * rstd + offset) * weight + bias;
-    if (all(lessThan(lpos, out_limits))) {
-      write_texel_lpos(t_out, ivec3(lpos.x, lpos.y, lpos.z), outtex, out_axis_map);
-    }
+  VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
+  VEC4_T offset = -rstd * mean;
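+  // i.e. rstd = 1 / sqrt(var + epsilon), so (v * rstd + offset) below equals (v - mean) * rstd.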
 
-    if (gl_GlobalInvocationID.x == 0) {
-      write_texel(t_mean, lpos, mean);
-      write_texel(t_rstd, lpos, rstd);
-    }
-  } else {
-    const int last_packed_width_index = divup4(width) - 1;
-    T mean = T(0);
-    T var = T(0);
-    const int remain = width & 3;
-
-    const int in_pos_x_limit = out_limits[in_axis_map.x];
-
-    // Loop over the width in stride increments
-    for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
-      // Read input in shared memory
-      for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
-        const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
-        in_pos[in_axis_map.x] = in_pos_x;
-
-        VEC4_T in_val = VEC4_T(0);
-        if (in_pos_x < in_pos_x_limit) {
-          in_val = load_texel(t_in, in_pos);
-        }
-
-        if (in_pos_x == last_packed_width_index && remain != 0) {
-          const int remain_inv = 4 - remain;
-          in_val.y = mix(in_val.y, T(0), remain_inv > 2);
-          in_val.z = mix(in_val.z, T(0), remain_inv > 1);
-          in_val.w = mix(in_val.w, T(0), remain_inv > 0);
-        }
-
-        shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
+  VEC4_T v = load_texel(t_in, lpos);
+  VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0)).xxxx;
+  VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0)).xxxx;
+  VEC4_T outtex = (v * rstd + offset) * weight + bias;
+
+  if (all(lessThan(lpos, out_limits))) {
+    write_texel_lpos(t_out, lpos, outtex, out_axis_map);
+  }
+
+  if (gl_GlobalInvocationID.x == 0) {
+    write_texel(t_mean, lpos, mean);
+    write_texel(t_rstd, lpos, rstd);
+  }
+}
+
+void reduce_packed_dim() {
+  const ivec3 lpos = ivec3(gl_GlobalInvocationID);
+  const int width = int(sizes.x);
+  ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
+
+  // width batch read stride
+  const int width_stride = int(gl_WorkGroupSize.x) * SHARED_MEMORY_FACTOR;
+
+  // local memory starting offset for this thread
+  const int shared_idx_offset = width_stride * int(gl_WorkGroupSize.y * gl_LocalInvocationID.z + gl_LocalInvocationID.y);
+
+  // local memory index for this thread
+  const int shared_idx = shared_idx_offset + int(gl_LocalInvocationID.x);
+
+  const int last_packed_width_index = divup4(width) - 1;
+  T mean = T(0);
+  T var = T(0);
+  const int remain = width & 3;
+
+  const int in_pos_x_limit = out_limits[in_axis_map.x];
+
+  // Loop over the width in stride increments
+  for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
+    // Read input in shared memory
+    for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+      const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+      in_pos[in_axis_map.x] = in_pos_x;
+
+      VEC4_T in_val = VEC4_T(0);
+      if (in_pos_x < in_pos_x_limit) {
+        in_val = load_texel(t_in, in_pos);
       }
 
-      reduce_input(width_stride, shared_idx_offset);
-      const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
-      mean += val.x + val.y + val.z + val.w;
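+      // For the last texel, the mix() calls below presumably zero out the lanes beyond the tensor width so the padding does not contribute to the sum.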
+      if (in_pos_x == last_packed_width_index && remain != 0) {
+        const int remain_inv = 4 - remain;
+        in_val.y = mix(in_val.y, T(0), remain_inv > 2);
+        in_val.z = mix(in_val.z, T(0), remain_inv > 1);
+        in_val.w = mix(in_val.w, T(0), remain_inv > 0);
+      }
+
+      shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
     }
 
-    mean /= width;
-
-    // Loop over the width in stride increments
-    for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
-      // Read input in shared memory
-      for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
-        const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
-        in_pos[in_axis_map.x] = in_pos_x;
-
-        VEC4_T in_val = VEC4_T(mean);
-        if (in_pos_x < in_pos_x_limit) {
-          in_val = load_texel(t_in, in_pos);
-        }
-
-        if (in_pos_x == last_packed_width_index && remain != 0) {
-          const int remain_inv = 4 - remain;
-          in_val.y = mix(in_val.y, mean.x, remain_inv > 2);
-          in_val.z = mix(in_val.z, mean.x, remain_inv > 1);
-          in_val.w = mix(in_val.w, mean.x, remain_inv > 0);
-        }
-
-        const VEC4_T delta = in_val - mean;
-        const VEC4_T delta2 = delta * delta;
-        shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2;
+    reduce_input(width_stride, shared_idx_offset);
+    const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
+    mean += val.x + val.y + val.z + val.w;
+  }
+
+  mean /= width;
+
+  memoryBarrierShared();
+  barrier();
+
+  // Loop over the width in stride increments
+  for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
+    // Read input in shared memory
+    for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+      const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+      in_pos[in_axis_map.x] = in_pos_x;
+
+      VEC4_T in_val = VEC4_T(mean);
+      if (in_pos_x < in_pos_x_limit) {
+        in_val = load_texel(t_in, in_pos);
+      }
+
+      if (in_pos_x == last_packed_width_index && remain != 0) {
+        const int remain_inv = 4 - remain;
+        in_val.y = mix(in_val.y, mean.x, remain_inv > 2);
+        in_val.z = mix(in_val.z, mean.x, remain_inv > 1);
+        in_val.w = mix(in_val.w, mean.x, remain_inv > 0);
       }
 
-      reduce_input(width_stride, shared_idx_offset);
-      const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
-      var += val.x + val.y + val.z + val.w;
+      const VEC4_T delta = in_val - mean;
+      const VEC4_T delta2 = delta * delta;
+      shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2;
     }
 
-    var /= width;
+    reduce_input(width_stride, shared_idx_offset);
+    const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
+    var += val.x + val.y + val.z + val.w;
+  }
+
+  var /= width;
 
-    T rstd = pow(var + epsilon, T(-0.5));
-    T offset = -rstd * mean;
+  T rstd = pow(var + epsilon, T(-0.5));
+  T offset = -rstd * mean;
 
-    VEC4_T v = load_texel(t_in, lpos);
-    VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0));
-    VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0));
-    VEC4_T outtex = (v * rstd + offset) * weight + bias;
-    if (all(lessThan(lpos, out_limits))) {
-      write_texel_lpos(t_out, ivec3(lpos.x, lpos.y, lpos.z), outtex, out_axis_map);
-    }
+  VEC4_T v = load_texel(t_in, lpos);
+  VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0));
+  VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0));
+  VEC4_T outtex = (v * rstd + offset) * weight + bias;
 
-    if (gl_GlobalInvocationID.x == 0) {
-      write_texel(t_mean, lpos, VEC4_T(mean));
-      write_texel(t_rstd, lpos, VEC4_T(rstd));
-    }
+  if (all(lessThan(lpos, out_limits))) {
+    write_texel_lpos(t_out, lpos, outtex, out_axis_map);
+  }
+
+  if (gl_GlobalInvocationID.x == 0) {
+    write_texel(t_mean, lpos, VEC4_T(mean));
+    write_texel(t_rstd, lpos, VEC4_T(rstd));
+  }
+}
+
+void main() {
+  // If the packed dimension is not the width dimension, use the non-packed reduction.
+  if (in_packed_dim != W_DIM) {
+    reduce_non_packed_dim();
+  } else {
+    reduce_packed_dim();
   }
 }