
Commit 232f5f5

morelos committed
Update on "[ET-VK][Ops] choose_qparams op shaders and impl"
Creates the choose_qparams per_tensor and per_token logic shaders and implementation, which are wired into the testing framework.

Differential Revision: [D76436933](https://our.internmc.facebook.com/intern/diff/D76436933/)

[ghstack-poisoned]
1 parent 1320734 commit 232f5f5
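For context, the op derives an affine quantization scale and zero point from the observed min/max of the input. Below is a minimal CPU-side sketch of that math, assuming the usual qmin/qmax conventions; the function name and the simple zero-point rule are illustrative, not the exact CPU reference implementation.

```cpp
#include <algorithm>
#include <cmath>

// Illustrative sketch (not the exact CPU reference): pick scale/zero_point so
// that q = round(x / scale) + zero_point maps [min, max] onto [qmin, qmax]
// and zero is exactly representable.
void choose_qparams_sketch(
    float min_val,
    float max_val,
    int qmin,
    int qmax,
    float& scale,
    int& zero_point) {
  // Extend the range to include 0.0, as the shader does.
  min_val = std::min(min_val, 0.0f);
  max_val = std::max(max_val, 0.0f);

  scale = (max_val - min_val) / static_cast<float>(qmax - qmin);
  if (scale == 0.0f) {
    scale = 1.0f;  // all-zero input; the guard value here is an assumption
  }

  // min_val <= 0 <= max_val guarantees this lands inside [qmin, qmax].
  zero_point = static_cast<int>(std::round(qmin - min_val / scale));
}
```

For example, min = -2.0 and max = 3.0 with int8 bounds (-128, 127) gives scale = 5/255 ≈ 0.0196 and zero_point = round(-128 + 2.0/0.0196) = -26, so 0.0 quantizes exactly to -26.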

2 files changed, +37 -116 lines

backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glsl

Lines changed: 14 additions & 38 deletions
```diff
@@ -52,10 +52,9 @@ $else:
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
-// Constants for reduction
 #define NWORKERS 64
 
-// Constant for small scale threshold
+// equivalent of the eps defined in the cpu implementation
 #define SMALL_SCALE_THRESHOLD 6.1e-5
 
 // Shared memory for reduction - must match local work group size
@@ -70,11 +69,10 @@ void calculate_scale_and_zero_point(
     int qmax,
     out float scale_val,
     out int zero_point_val) {
-  // Ensure the range includes zero
+  // ensure we have zero included in our range
   min_val = min(min_val, 0.0);
   max_val = max(max_val, 0.0);
 
-  // Calculate scale
   scale_val = (max_val - min_val) / float(qmax - qmin);
 
   // Handle zero or very small scale
```
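The SMALL_SCALE_THRESHOLD constant referenced above plays the role of the CPU eps: 6.1e-5 is roughly the smallest normal float16 value, so scales below it risk degenerate quantized grids. A sketch of the guard, assuming a simple clamp (the shader's exact recovery logic for tiny ranges may differ):

```cpp
#include <algorithm>

constexpr float kSmallScaleThreshold = 6.1e-5f;  // matches the shader constant

// Assumed behavior: clamp near-zero scales up to the threshold so that
// downstream 1/scale arithmetic stays well conditioned.
inline float guard_small_scale(float scale) {
  return std::max(scale, kSmallScaleThreshold);
}
```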
```diff
@@ -122,16 +120,14 @@ void calculate_scale_and_zero_point(
 
 $if MODE == "per_tensor":
   void main() {
-    // Single-Pass Hierarchical Reduction for per-tensor min/max
     uint global_id = gl_GlobalInvocationID.x;
     uint local_id = gl_LocalInvocationID.x;
     uint group_id = gl_WorkGroupID.x;
     uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
 
-    // Calculate total number of elements in the input tensor
     uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w);
 
-    // Phase 1: Each thread processes multiple elements with stride
+    // Each thread processes multiple elements with stride
     float thread_min = 1.0/0.0; // +infinity
     float thread_max = -1.0/0.0; // -infinity
     bool found_valid = false;
@@ -150,7 +146,7 @@ $if MODE == "per_tensor":
       }
     }
 
-    // Phase 2: Intra-group reduction using shared memory
+    // Intra-group reduction using shared memory
     shared_min[local_id] = thread_min;
     shared_max[local_id] = thread_max;
     barrier();
@@ -161,7 +157,6 @@ $if MODE == "per_tensor":
         float other_min = shared_min[local_id + stride];
         float other_max = shared_max[local_id + stride];
 
-        // Handle infinity values properly
         if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) {
           shared_min[local_id] = other_min;
         }
```
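The hunks above implement a power-of-two tree reduction in shared memory, with ±infinity marking lanes that saw no valid element. A serial C++ analogue of the same loop (illustrative, not part of the commit):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Serial analogue of the shared-memory tree reduction, assuming a
// power-of-two worker count (64 in the shader). +inf/-inf mark workers that
// never saw a valid element and are only overwritten by finite values.
void tree_reduce_min_max(std::vector<float>& mins, std::vector<float>& maxs) {
  for (std::size_t stride = mins.size() / 2; stride > 0; stride /= 2) {
    for (std::size_t i = 0; i < stride; ++i) {
      const float other_min = mins[i + stride];
      const float other_max = maxs[i + stride];
      if (!std::isinf(other_min) &&
          (std::isinf(mins[i]) || other_min < mins[i])) {
        mins[i] = other_min;
      }
      if (!std::isinf(other_max) &&
          (std::isinf(maxs[i]) || other_max > maxs[i])) {
        maxs[i] = other_max;
      }
    }
    // The shader calls barrier() here; the serial version needs no sync.
  }
  // mins[0] / maxs[0] now hold the reduction result.
}
```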
```diff
@@ -172,30 +167,26 @@ $if MODE == "per_tensor":
       barrier();
     }
 
-    // Phase 3: Final result calculation (single workgroup only)
+    // Final result calculation (single workgroup only)
     if (local_id == 0) {
       float global_min = shared_min[0];
       float global_max = shared_max[0];
 
-      // Calculate final scale and zero_point
       float scale_val;
       int zero_point_val;
       calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val);
 
-      // Write final results
       t_scale[0] = scale_val;
       t_zero_point[0] = zero_point_val;
     }
   }
 $else:
   void main() {
-    // Per-token hierarchical reduction implementation with multiple tokens per workgroup
     uint global_id = gl_GlobalInvocationID.x;
     uint local_id = gl_LocalInvocationID.x;
     uint group_id = gl_WorkGroupID.x;
     uint total_workgroups = gl_NumWorkGroups.x;
 
-    // Calculate total number of elements in the input tensor
     uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w);
     uint token_size = total_elements / uint(num_tokens);
 
@@ -218,7 +209,7 @@ $else:
     uint token_start = token_id * token_size;
     uint token_end = token_start + token_size;
 
-    // Phase 1: Each thread processes multiple elements within the token with stride
+    // Each thread processes multiple elements within the token with stride
     float thread_min = 1.0/0.0; // +infinity
     float thread_max = -1.0/0.0; // -infinity
     bool found_valid = false;
@@ -238,7 +229,7 @@ $else:
       }
     }
 
-    // Phase 2: Intra-group reduction using shared memory
+    // Intra-group reduction using shared memory
     shared_min[local_id] = thread_min;
     shared_max[local_id] = thread_max;
     barrier();
@@ -249,7 +240,6 @@ $else:
         float other_min = shared_min[local_id + stride];
         float other_max = shared_max[local_id + stride];
 
-        // Handle infinity values properly
         if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) {
           shared_min[local_id] = other_min;
         }
@@ -260,17 +250,15 @@ $else:
       barrier();
     }
 
-    // Phase 3: Final calculation for this token
+    // Final calculation for this token
     if (local_id == 0) {
       float token_min = shared_min[0];
       float token_max = shared_max[0];
 
-      // Calculate scale and zero_point for this token
       float scale_val;
       int zero_point_val;
       calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val);
 
-      // Write results for this token
       t_scale[token_id] = scale_val;
       t_zero_point[token_id] = zero_point_val;
     }
```
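In the per-token path, a token is a slice along the innermost dimension; every leading dimension contributes to the token count, and each workgroup walks tokens with a workgroup-count stride. A small sketch of that bookkeeping (helper name is illustrative):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Token bookkeeping used by the per-token mode: every dimension except the
// innermost contributes to the token count.
int64_t count_tokens(const std::vector<int64_t>& sizes) {
  int64_t num_tokens = 1;
  for (std::size_t i = 0; i + 1 < sizes.size(); ++i) {
    num_tokens *= sizes[i];
  }
  return num_tokens;
}
// Example: sizes {2, 3, 8} -> 6 tokens of token_size = 8; workgroup g then
// handles tokens g, g + total_workgroups, g + 2 * total_workgroups, ...
```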
```diff
@@ -284,16 +272,14 @@ $else:
 
 $if MODE == "per_tensor":
   void main() {
-    // Multi-workgroup texture-based per-tensor quantization parameter calculation
     uint global_id = gl_GlobalInvocationID.x;
     uint local_id = gl_LocalInvocationID.x;
     uint group_id = gl_WorkGroupID.x;
     uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
 
-    // Calculate total number of texels in the input tensor
     uint total_texels = uint(t_in_limits.x * t_in_limits.y * t_in_limits.z);
 
-    // Phase 1: Each thread processes multiple texels with stride
+    // Each thread processes multiple texels with stride
     float thread_min = 1.0/0.0; // +infinity
     float thread_max = -1.0/0.0; // -infinity
     bool found_valid = false;
@@ -307,7 +293,6 @@ $if MODE == "per_tensor":
       uint x = remainder % uint(t_in_limits.x);
       ivec3 texel_pos = ivec3(int(x), int(y), int(z));
 
-      // Load texel data (4 float values)
      FVEC4_T texel_data = load_texel(t_in, texel_pos);
 
      // For texture storage, we assume width-packed (packed_dim = 0)
```
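The index arithmetic in the last hunk decomposes a flat texel id into an (x, y, z) texture position with x fastest-moving, matching the width-packed layout the shader assumes. The same math in a hypothetical C++ helper:

```cpp
// Hypothetical helper mirroring the texel index math above; only the
// x = remainder % limit_x step appears verbatim in the diff, the rest
// follows the standard row-major decomposition.
struct TexelPos {
  unsigned x, y, z;
};

TexelPos unflatten_texel(unsigned texel_id, unsigned limit_x, unsigned limit_y) {
  const unsigned plane = limit_x * limit_y;
  TexelPos pos;
  pos.z = texel_id / plane;
  const unsigned remainder = texel_id % plane;
  pos.y = remainder / limit_x;
  pos.x = remainder % limit_x;
  return pos;
}
```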
```diff
@@ -369,7 +354,7 @@ $if MODE == "per_tensor":
       }
     }
 
-    // Phase 2: Intra-workgroup reduction using shared memory
+    // Intra-workgroup reduction using shared memory
     shared_min[local_id] = thread_min;
     shared_max[local_id] = thread_max;
     barrier();
@@ -380,7 +365,6 @@ $if MODE == "per_tensor":
         float other_min = shared_min[local_id + stride];
         float other_max = shared_max[local_id + stride];
 
-        // Handle infinity values properly
         if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) {
           shared_min[local_id] = other_min;
         }
@@ -391,31 +375,26 @@ $if MODE == "per_tensor":
       barrier();
     }
 
-    // Phase 3: Final result calculation (single workgroup only for reliability)
+    // Final result calculation (single workgroup only for reliability)
     if (local_id == 0 && group_id == 0) {
       float global_min = shared_min[0];
       float global_max = shared_max[0];
 
-      // Calculate final scale and zero_point
       float scale_val;
       int zero_point_val;
       calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val);
 
-      // Write final results to output textures
       write_texel(t_scale, ivec3(0, 0, 0), vec4(scale_val, 0.0, 0.0, 0.0));
       write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(zero_point_val, 0, 0, 0));
     }
   }
 $else:
   void main() {
-    // Texture-based per-token quantization parameter calculation
     // Each token is processed by multiple workgroups for parallel reduction
-
     uint local_id = gl_LocalInvocationID.x;
     uint group_id = gl_WorkGroupID.x;
     uint total_workgroups = gl_NumWorkGroups.x;
 
-    // Calculate total number of texels in the input tensor
     uint total_texels = uint(t_in_limits.x * t_in_limits.y * t_in_limits.z);
 
     // Calculate texels per token (assuming last dimension contains the token data)
```
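The texels-per-token computation referenced in the closing comment is elided from this hunk; under the width-packed assumption each vec4 texel covers four elements of the innermost (token) dimension, so a plausible formula (an assumption, not shown in the diff) is:

```cpp
#include <cstdint>

// Assumed mapping: width-packed vec4 texels, innermost dimension holds the
// token data, with the tail texel rounded up.
inline uint32_t texels_per_token(uint32_t token_size) {
  return (token_size + 3u) / 4u;  // ceil(token_size / 4)
}
```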
```diff
@@ -435,7 +414,7 @@ $else:
     uint token_start_texel = token_id * texels_per_token;
     uint token_end_texel = token_start_texel + texels_per_token;
 
-    // Phase 1: Each thread processes multiple texels within the token
+    // Each thread processes multiple texels within the token
     float thread_min = 1.0/0.0; // +infinity
     float thread_max = -1.0/0.0; // -infinity
     bool found_valid = false;
@@ -449,7 +428,6 @@ $else:
       uint x = remainder % uint(t_in_limits.x);
       ivec3 texel_pos = ivec3(int(x), int(y), int(z));
 
-      // Load texel data (4 float values)
       FVEC4_T texel_data = load_texel(t_in, texel_pos);
 
       // For texture storage, we assume width-packed (packed_dim = 0)
@@ -511,7 +489,7 @@ $else:
       }
     }
 
-    // Phase 2: Intra-workgroup reduction using shared memory
+    // Intra-workgroup reduction using shared memory
     shared_min[local_id] = thread_min;
     shared_max[local_id] = thread_max;
     barrier();
@@ -533,12 +511,11 @@ $else:
       barrier();
     }
 
-    // Phase 3: Final calculation for this token
+    // Final calculation for this token
     if (local_id == 0) {
       float token_min = shared_min[0];
       float token_max = shared_max[0];
 
-      // Calculate scale and zero_point for this token
       float scale_val;
       int zero_point_val;
       calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val);
@@ -551,7 +528,6 @@ $else:
       uint out_x = out_remainder % uint(t_scale_limits.x);
       ivec3 out_pos = ivec3(int(out_x), int(out_y), int(out_z));
 
-      // Write results for this token
       write_texel(t_scale, out_pos, vec4(scale_val, 0.0, 0.0, 0.0));
       write_texel(t_zero_point, out_pos, ivec4(zero_point_val, 0, 0, 0));
     }
```

backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp

Lines changed: 23 additions & 78 deletions
```diff
@@ -54,27 +54,9 @@ utils::uvec3 choose_qparams_global_wg_size(
   (void)shader;
   (void)resize_args;
 
-  // For global reduction, we need to process the entire input tensor
-  const ValueRef input = args.at(0).refs.at(0);
-
-  if (graph->is_buffer_storage(input)) {
-    const uint32_t local_threads = 64; // From choose_qparams_local_wg_size
-
-    // For per-tensor quantization, use SINGLE WORKGROUP approach to avoid
-    // complex multi-workgroup synchronization issues that cause race
-    // conditions. A single workgroup with 64 threads can efficiently process
-    // large tensors by having each thread process multiple elements with
-    // stride.
-
-    // Return single workgroup with 64 threads
-    return {local_threads, 1u, 1u};
-  } else {
-    // For texture storage, use single workgroup approach for reliability
-    const uint32_t local_threads = 64; // From choose_qparams_local_wg_size
+  const uint32_t local_threads = 64; // From choose_qparams_local_wg_size
 
-    // Return single workgroup with 64 threads
-    return {local_threads, 1u, 1u};
-  }
+  return {local_threads, 1u, 1u};
 }
 
 utils::uvec3 choose_qparams_local_wg_size(
```
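This simplification collapses the buffer and texture branches: both now dispatch a single 64-thread workgroup for the per-tensor reduction, relying on the strided loop in the shader to cover the whole tensor. Roughly (the struct is illustrative, not the actual `utils::uvec3` definition):

```cpp
#include <cstdint>

struct WgSize {
  uint32_t x, y, z;
};

// Per-tensor dispatch after this change: one workgroup of 64 threads,
// regardless of storage type. Thread t then visits elements
// t, t + 64, t + 128, ... so any tensor size is still covered.
constexpr WgSize kChooseQParamsGlobalWg{64u, 1u, 1u};
```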
```diff
@@ -90,14 +72,10 @@ utils::uvec3 choose_qparams_local_wg_size(
   const ValueRef input = args.at(0).refs.at(0);
 
   if (graph->is_buffer_storage(input)) {
-    // For hierarchical reduction, use 64 threads per work group for better
-    // efficiency This provides better GPU utilization while still being
-    // manageable for shared memory
-
     const uint32_t local_threads = 64;
+
     return {local_threads, 1u, 1u};
   } else {
-    // For texture storage, use default local workgroup size
     return graph->create_local_wg_size(global_workgroup_size);
   }
 }
```
```diff
@@ -112,57 +90,27 @@ utils::uvec3 choose_qparams_per_token_global_wg_size(
 
   const ValueRef input = args.at(0).refs.at(0);
 
-  if (graph->is_buffer_storage(input)) {
-    // For per-token reduction, we need one workgroup per token
-    // Calculate number of tokens (product of all dimensions except the last
-    // one)
-    int64_t num_tokens = 1;
-    const auto input_sizes = graph->sizes_of(input);
-    for (size_t i = 0; i < input_sizes.size() - 1; i++) {
-      num_tokens *= input_sizes[i];
-    }
-
-    // GPU hardware limits: Most GPUs support max ~65535 workgroups per
-    // dimension
-    const uint32_t max_workgroups = 65535;
-    const uint32_t local_x = 64u; // From choose_qparams_per_token_local_wg_size
-
-    // Clamp number of workgroups to hardware limits
-    uint32_t clamped_workgroups =
-        std::min(static_cast<uint32_t>(num_tokens), max_workgroups);
-
-    // If we have more tokens than workgroups, each workgroup will process
-    // multiple tokens
-
-    // Calculate total threads needed
-    const uint32_t total_threads_x = clamped_workgroups * local_x;
-    const uint32_t total_threads_y = 1u;
-    const uint32_t total_threads_z = 1u;
-
-    return {total_threads_x, total_threads_y, total_threads_z};
-  } else {
-    // For texture storage, calculate number of tokens
-    int64_t num_tokens = 1;
-    const auto input_sizes = graph->sizes_of(input);
-    for (size_t i = 0; i < input_sizes.size() - 1; i++) {
-      num_tokens *= input_sizes[i];
-    }
-
-    // For texture storage, clamp to reasonable limits for performance
-    // Large token counts (>1024) can cause very slow execution
-    const uint32_t max_reasonable_tokens = 1024;
-    const uint32_t local_x = 64u; // From choose_qparams_per_token_local_wg_size
-
-    uint32_t clamped_workgroups =
-        std::min(static_cast<uint32_t>(num_tokens), max_reasonable_tokens);
-
-    // Calculate total threads needed
-    const uint32_t total_threads_x = clamped_workgroups * local_x;
-    const uint32_t total_threads_y = 1u;
-    const uint32_t total_threads_z = 1u;
-
-    return {total_threads_x, total_threads_y, total_threads_z};
+  // For per-token reduction, we need one workgroup per token
+  // Calculate number of tokens (product of all dimensions except the last
+  // one)
+  int64_t num_tokens = 1;
+  const auto input_sizes = graph->sizes_of(input);
+  for (size_t i = 0; i < input_sizes.size() - 1; i++) {
+    num_tokens *= input_sizes[i];
   }
+
+  const uint32_t max_workgroups = 65535;
+  const uint32_t local_x = 64u; // From choose_qparams_per_token_local_wg_size
+
+  // Clamp number of workgroups to avoid being slow
+  uint32_t clamped_workgroups =
+      std::min(static_cast<uint32_t>(num_tokens), max_workgroups);
+
+  // If we have more tokens than workgroups, each workgroup will process
+  // multiple tokens
+  const uint32_t total_threads_x = clamped_workgroups * local_x;
+
+  return {total_threads_x, 1u, 1u};
 }
 
 utils::uvec3 choose_qparams_per_token_local_wg_size(
```
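The unified per-token sizing keeps the conservative 65535 workgroup clamp for both storage types and returns the global size in threads. A standalone sketch of the arithmetic with a worked example:

```cpp
#include <algorithm>
#include <cstdint>

// Mirrors the sizing logic above: clamp the workgroup count to 65535 and
// express the global size in threads (workgroups * 64).
uint32_t per_token_global_x(int64_t num_tokens) {
  const uint32_t max_workgroups = 65535u;
  const uint32_t local_x = 64u;
  const uint32_t clamped_workgroups =
      std::min(static_cast<uint32_t>(num_tokens), max_workgroups);
  return clamped_workgroups * local_x;
}
// e.g. a {2, 3, 8} input has 6 tokens -> 6 * 64 = 384 threads in x.
```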
```diff
@@ -178,13 +126,10 @@ utils::uvec3 choose_qparams_per_token_local_wg_size(
   const ValueRef input = args.at(0).refs.at(0);
 
   if (graph->is_buffer_storage(input)) {
-    // For per-token reduction, each workgroup processes one token
-    // Use 64 threads per work group to match shared memory allocation
     const uint32_t local_threads = 64;
 
     return {local_threads, 1u, 1u};
   } else {
-    // For texture storage, use default local workgroup size
     return graph->create_local_wg_size(global_workgroup_size);
   }
 }
```
