@@ -52,10 +52,9 @@ $else:
5252
5353layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
5454
55- // Constants for reduction
5655#define NWORKERS 64
5756
58- // Constant for small scale threshold
57+ // equivalent of the eps defined in the cpu implemnetation
5958#define SMALL_SCALE_THRESHOLD 6.1e-5
6059
6160// Shared memory for reduction - must match local work group size
@@ -70,11 +69,10 @@ void calculate_scale_and_zero_point(
7069 int qmax,
7170 out float scale_val,
7271 out int zero_point_val) {
73- // Ensure the range includes zero
72+ // ensure we have zero included in our range
7473 min_val = min (min_val, 0.0 );
7574 max_val = max (max_val, 0.0 );
7675
77- // Calculate scale
7876 scale_val = (max_val - min_val) / float (qmax - qmin);
7977
8078 // Handle zero or very small scale
@@ -122,16 +120,14 @@ void calculate_scale_and_zero_point(
122120
123121$if MODE == "per_tensor":
124122 void main() {
125- // Single-Pass Hierarchical Reduction for per-tensor min/max
126123 uint global_id = gl_GlobalInvocationID.x;
127124 uint local_id = gl_LocalInvocationID.x;
128125 uint group_id = gl_WorkGroupID.x;
129126 uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
130127
131- // Calculate total number of elements in the input tensor
132128 uint total_elements = uint (t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w);
133129
134- // Phase 1: Each thread processes multiple elements with stride
130+ // Each thread processes multiple elements with stride
135131 float thread_min = 1.0 / 0.0 ; // +infinity
136132 float thread_max = - 1.0 / 0.0 ; // -infinity
137133 bool found_valid = false;
@@ -150,7 +146,7 @@ $if MODE == "per_tensor":
150146 }
151147 }
152148
153- // Phase 2: Intra-group reduction using shared memory
149+ // Intra-group reduction using shared memory
154150 shared_min[local_id] = thread_min;
155151 shared_max[local_id] = thread_max;
156152 barrier();
@@ -161,7 +157,6 @@ $if MODE == "per_tensor":
161157 float other_min = shared_min[local_id + stride];
162158 float other_max = shared_max[local_id + stride];
163159
164- // Handle infinity values properly
165160 if (! isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) {
166161 shared_min[local_id] = other_min;
167162 }
@@ -172,30 +167,26 @@ $if MODE == "per_tensor":
172167 barrier();
173168 }
174169
175- // Phase 3: Final result calculation (single workgroup only)
170+ // Final result calculation (single workgroup only)
176171 if (local_id == 0 ) {
177172 float global_min = shared_min[0 ];
178173 float global_max = shared_max[0 ];
179174
180- // Calculate final scale and zero_point
181175 float scale_val;
182176 int zero_point_val;
183177 calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val);
184178
185- // Write final results
186179 t_scale[0 ] = scale_val;
187180 t_zero_point[0 ] = zero_point_val;
188181 }
189182 }
190183$else :
191184 void main() {
192- // Per-token hierarchical reduction implementation with multiple tokens per workgroup
193185 uint global_id = gl_GlobalInvocationID.x;
194186 uint local_id = gl_LocalInvocationID.x;
195187 uint group_id = gl_WorkGroupID.x;
196188 uint total_workgroups = gl_NumWorkGroups.x;
197189
198- // Calculate total number of elements in the input tensor
199190 uint total_elements = uint (t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w);
200191 uint token_size = total_elements / uint (num_tokens);
201192
@@ -218,7 +209,7 @@ $else:
218209 uint token_start = token_id * token_size;
219210 uint token_end = token_start + token_size;
220211
221- // Phase 1: Each thread processes multiple elements within the token with stride
212+ // Each thread processes multiple elements within the token with stride
222213 float thread_min = 1.0 / 0.0 ; // +infinity
223214 float thread_max = - 1.0 / 0.0 ; // -infinity
224215 bool found_valid = false;
@@ -238,7 +229,7 @@ $else:
238229 }
239230 }
240231
241- // Phase 2: Intra-group reduction using shared memory
232+ // Intra-group reduction using shared memory
242233 shared_min[local_id] = thread_min;
243234 shared_max[local_id] = thread_max;
244235 barrier();
@@ -249,7 +240,6 @@ $else:
249240 float other_min = shared_min[local_id + stride];
250241 float other_max = shared_max[local_id + stride];
251242
252- // Handle infinity values properly
253243 if (! isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) {
254244 shared_min[local_id] = other_min;
255245 }
@@ -260,17 +250,15 @@ $else:
260250 barrier();
261251 }
262252
263- // Phase 3: Final calculation for this token
253+ // Final calculation for this token
264254 if (local_id == 0 ) {
265255 float token_min = shared_min[0 ];
266256 float token_max = shared_max[0 ];
267257
268- // Calculate scale and zero_point for this token
269258 float scale_val;
270259 int zero_point_val;
271260 calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val);
272261
273- // Write results for this token
274262 t_scale[token_id] = scale_val;
275263 t_zero_point[token_id] = zero_point_val;
276264 }
@@ -284,16 +272,14 @@ $else:
284272
285273$if MODE == "per_tensor":
286274 void main() {
287- // Multi-workgroup texture-based per-tensor quantization parameter calculation
288275 uint global_id = gl_GlobalInvocationID.x;
289276 uint local_id = gl_LocalInvocationID.x;
290277 uint group_id = gl_WorkGroupID.x;
291278 uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
292279
293- // Calculate total number of texels in the input tensor
294280 uint total_texels = uint (t_in_limits.x * t_in_limits.y * t_in_limits.z);
295281
296- // Phase 1: Each thread processes multiple texels with stride
282+ // Each thread processes multiple texels with stride
297283 float thread_min = 1.0 / 0.0 ; // +infinity
298284 float thread_max = - 1.0 / 0.0 ; // -infinity
299285 bool found_valid = false;
@@ -307,7 +293,6 @@ $if MODE == "per_tensor":
307293 uint x = remainder % uint (t_in_limits.x);
308294 ivec3 texel_pos = ivec3 (int (x), int (y), int (z));
309295
310- // Load texel data (4 float values)
311296 FVEC4_T texel_data = load_texel(t_in, texel_pos);
312297
313298 // For texture storage, we assume width-packed (packed_dim = 0)
@@ -369,7 +354,7 @@ $if MODE == "per_tensor":
369354 }
370355 }
371356
372- // Phase 2: Intra-workgroup reduction using shared memory
357+ // Intra-workgroup reduction using shared memory
373358 shared_min[local_id] = thread_min;
374359 shared_max[local_id] = thread_max;
375360 barrier();
@@ -380,7 +365,6 @@ $if MODE == "per_tensor":
380365 float other_min = shared_min[local_id + stride];
381366 float other_max = shared_max[local_id + stride];
382367
383- // Handle infinity values properly
384368 if (! isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) {
385369 shared_min[local_id] = other_min;
386370 }
@@ -391,31 +375,26 @@ $if MODE == "per_tensor":
391375 barrier();
392376 }
393377
394- // Phase 3: Final result calculation (single workgroup only for reliability)
378+ // Final result calculation (single workgroup only for reliability)
395379 if (local_id == 0 && group_id == 0 ) {
396380 float global_min = shared_min[0 ];
397381 float global_max = shared_max[0 ];
398382
399- // Calculate final scale and zero_point
400383 float scale_val;
401384 int zero_point_val;
402385 calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val);
403386
404- // Write final results to output textures
405387 write_texel(t_scale, ivec3 (0 , 0 , 0 ), vec4 (scale_val, 0.0 , 0.0 , 0.0 ));
406388 write_texel(t_zero_point, ivec3 (0 , 0 , 0 ), ivec4 (zero_point_val, 0 , 0 , 0 ));
407389 }
408390 }
409391$else :
410392 void main() {
411- // Texture-based per-token quantization parameter calculation
412393 // Each token is processed by multiple workgroups for parallel reduction
413-
414394 uint local_id = gl_LocalInvocationID.x;
415395 uint group_id = gl_WorkGroupID.x;
416396 uint total_workgroups = gl_NumWorkGroups.x;
417397
418- // Calculate total number of texels in the input tensor
419398 uint total_texels = uint (t_in_limits.x * t_in_limits.y * t_in_limits.z);
420399
421400 // Calculate texels per token (assuming last dimension contains the token data)
@@ -435,7 +414,7 @@ $else:
435414 uint token_start_texel = token_id * texels_per_token;
436415 uint token_end_texel = token_start_texel + texels_per_token;
437416
438- // Phase 1: Each thread processes multiple texels within the token
417+ // Each thread processes multiple texels within the token
439418 float thread_min = 1.0 / 0.0 ; // +infinity
440419 float thread_max = - 1.0 / 0.0 ; // -infinity
441420 bool found_valid = false;
@@ -449,7 +428,6 @@ $else:
449428 uint x = remainder % uint (t_in_limits.x);
450429 ivec3 texel_pos = ivec3 (int (x), int (y), int (z));
451430
452- // Load texel data (4 float values)
453431 FVEC4_T texel_data = load_texel(t_in, texel_pos);
454432
455433 // For texture storage, we assume width-packed (packed_dim = 0)
@@ -511,7 +489,7 @@ $else:
511489 }
512490 }
513491
514- // Phase 2: Intra-workgroup reduction using shared memory
492+ // Intra-workgroup reduction using shared memory
515493 shared_min[local_id] = thread_min;
516494 shared_max[local_id] = thread_max;
517495 barrier();
@@ -533,12 +511,11 @@ $else:
533511 barrier();
534512 }
535513
536- // Phase 3: Final calculation for this token
514+ // Final calculation for this token
537515 if (local_id == 0 ) {
538516 float token_min = shared_min[0 ];
539517 float token_max = shared_max[0 ];
540518
541- // Calculate scale and zero_point for this token
542519 float scale_val;
543520 int zero_point_val;
544521 calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val);
@@ -551,7 +528,6 @@ $else:
551528 uint out_x = out_remainder % uint (t_scale_limits.x);
552529 ivec3 out_pos = ivec3 (int (out_x), int (out_y), int (out_z));
553530
554- // Write results for this token
555531 write_texel(t_scale, out_pos, vec4 (scale_val, 0.0 , 0.0 , 0.0 ));
556532 write_texel(t_zero_point, out_pos, ivec4 (zero_point_val, 0 , 0 , 0 ));
557533 }
0 commit comments