Skip to content

Commit 038ff0b

Browse files
authored
[ET-VK] Fast path for choose_qparams (#14045)
The current implementations of `choose_qparams` are too slow to be practically usable. As a temporary workaround to unblock LLM optimizations, this diff/PR introduces a fast path for computing per-channel quantization parameters for 2D matrices in the form of the choose_qparams_per_row shader. Differential Revision: [D81800024](https://our.internmc.facebook.com/intern/diff/D81800024/)
1 parent 8da72c3 commit 038ff0b

File tree

8 files changed

+848
-8
lines changed

8 files changed

+848
-8
lines changed

backends/vulkan/op_registry.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,6 @@ def register_affine_quantization_op():
172172

173173
@update_features(
174174
[
175-
exir_ops.edge.torchao.choose_qparams_affine.default,
176175
exir_ops.edge.quantized_decomposed.choose_qparams.tensor,
177176
exir_ops.edge.quantized_decomposed.choose_qparams_per_token_asymmetric.default,
178177
]
@@ -184,6 +183,20 @@ def register_torchao_quantization_op():
184183
)
185184

186185

186+
@update_features(
187+
exir_ops.edge.torchao.choose_qparams_affine.default,
188+
)
189+
def register_torchao_choose_qparams_affine():
190+
return OpFeatures(
191+
inputs_storage=utils.CONTIGUOUS_ANY,
192+
outputs_storage=[
193+
utils.CONTIGUOUS_BUFFER, # scales
194+
utils.CONTIGUOUS_BUFFER, # zero_points
195+
],
196+
supports_resize=True,
197+
)
198+
199+
187200
@update_features(
188201
[
189202
exir_ops.edge.aten.add.Tensor,

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,16 @@ bool ComputeGraph::is_contiguous_buffer_tensor(const ValueRef idx) const {
332332
return is_contiguous(idx);
333333
}
334334

335+
bool ComputeGraph::is_contiguous_texture_tensor(const ValueRef idx) const {
336+
if (!val_is_tensor(idx)) {
337+
return false;
338+
}
339+
if (is_buffer_storage(idx)) {
340+
return false;
341+
}
342+
return has_standard_axis_map(idx) && packed_dim_of(idx) == 0;
343+
}
344+
335345
bool ComputeGraph::is_standard_channels_packed_texture_tensor(
336346
const ValueRef idx) const {
337347
if (!val_is_tensor(idx)) {
@@ -343,15 +353,50 @@ bool ComputeGraph::is_standard_channels_packed_texture_tensor(
343353
return has_standard_axis_map(idx) && packed_dim_of(idx) == 2;
344354
}
345355

346-
bool ComputeGraph::is_standard_width_packed_texture_tensor(
356+
bool ComputeGraph::is_2d_matrix(const ValueRef idx) const {
357+
std::vector<int64_t> sizes = sizes_of(idx);
358+
const size_t ndim = sizes.size();
359+
if (sizes.size() < 2) {
360+
return false;
361+
}
362+
if (sizes.size() == 2) {
363+
return true;
364+
}
365+
366+
// Check that outermost dims have size of 1
367+
for (int d = 0; d < ndim - 2; d++) {
368+
if (sizes[d] != 1) {
369+
return false;
370+
}
371+
}
372+
373+
return true;
374+
}
375+
376+
bool ComputeGraph::is_vectorizable_contiguous_2d_matrix(
347377
const ValueRef idx) const {
348-
if (!val_is_tensor(idx)) {
378+
if (!is_2d_matrix(idx)) {
349379
return false;
350380
}
351381
if (is_buffer_storage(idx)) {
382+
return is_contiguous_buffer_tensor(idx) &&
383+
size_at<int32_t>(-1, idx) % 4 == 0;
384+
}
385+
return is_contiguous_texture_tensor(idx);
386+
}
387+
388+
bool ComputeGraph::is_vectorizable_width_packed_tensor(
389+
const ValueRef idx) const {
390+
// Not a tensor - return false
391+
if (!val_is_tensor(idx)) {
352392
return false;
353393
}
354-
return has_standard_axis_map(idx) && packed_dim_of(idx) == 0;
394+
if (is_buffer_storage(idx)) {
395+
return is_contiguous_buffer_tensor(idx) &&
396+
size_at<int32_t>(-1, idx) % 4 == 0;
397+
}
398+
399+
return is_standard_channels_packed_texture_tensor(idx);
355400
}
356401

357402
ValueRef ComputeGraph::add_tensor(

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,18 +382,40 @@ class ComputeGraph final {
382382
* 1. The value at `idx` is a tensor
383383
* 2. The tensor at `idx` has texture storage
384384
* 3. The texture backed tensor at `idx` has a standard axis mapping
385-
* 4. The texture backed tensor at `idx` is channels packed
385+
* 4. The texture backed tensor at `idx` is width packed
386386
*/
387-
bool is_standard_channels_packed_texture_tensor(const ValueRef idx) const;
387+
bool is_contiguous_texture_tensor(const ValueRef idx) const;
388388

389389
/*
390390
* Checks that the following is true:
391391
* 1. The value at `idx` is a tensor
392392
* 2. The tensor at `idx` has texture storage
393393
* 3. The texture backed tensor at `idx` has a standard axis mapping
394-
* 4. The texture backed tensor at `idx` is width packed
394+
* 4. The texture backed tensor at `idx` is channels packed
395+
*/
396+
bool is_standard_channels_packed_texture_tensor(const ValueRef idx) const;
397+
398+
/*
399+
* Checks that the value at `idx` is either a 2D tensor, or if the tensor has
400+
* more than 2 dims, the outermost dims have size of 1, i.e. can be squeezed
401+
* to be a 2D tensor.
402+
*/
403+
bool is_2d_matrix(const ValueRef idx) const;
404+
405+
/*
406+
* Same as the above, but also requires that the tensor is a contiguous
407+
* buffer with a width divisible by 4 or a standard width packed texture.
408+
*/
409+
bool is_vectorizable_contiguous_2d_matrix(const ValueRef idx) const;
410+
411+
/*
412+
* Checks that the following is true:
413+
* 1. The value at `idx` is a tensor
414+
* 2. The tensor at `idx` is width packed
415+
* 3. The tensor at `idx` has a standard axis mapping or is a contiguous
416+
* buffer
395417
*/
396-
bool is_standard_width_packed_texture_tensor(const ValueRef idx) const;
418+
bool is_vectorizable_width_packed_tensor(const ValueRef idx) const;
397419

398420
inline bool val_is_view_of(const ValueRef maybe_view, const ValueRef base)
399421
const {
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
13+
#define T ${texel_load_component_type(DTYPE, STORAGE)}
14+
15+
#define NUM_OUTPUTS_PER_WG ${NUM_OUTPUTS_PER_WG}
16+
#define NUM_WORKERS_PER_OUTPUT ${NUM_WORKERS_PER_OUTPUT}
17+
18+
// Maximum total threads in a work group
19+
#define MAX_THREADS 256
20+
21+
${define_active_storage_type(STORAGE)}
22+
${define_required_extensions("int8")}
23+
24+
#extension GL_EXT_control_flow_attributes : require
25+
26+
layout(std430) buffer;
27+
28+
#include "common.glslh"
29+
30+
${layout_declare_tensor(B, "w", "t_scales", "float", "buffer")}
31+
${layout_declare_tensor(B, "w", "t_zps", "int", "buffer")}
32+
${layout_declare_tensor(B, "r", "t_input", DTYPE, STORAGE, is_scalar_array=False)}
33+
34+
${layout_declare_ubo(B, "ivec4", "input_sizes")}
35+
36+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
37+
38+
layout(push_constant) uniform PushConstants {
39+
int quant_min;
40+
int quant_max;
41+
};
42+
43+
// Shared memory for cooperative min/max finding
44+
shared T shared_min[NUM_OUTPUTS_PER_WG][NUM_WORKERS_PER_OUTPUT];
45+
shared T shared_max[NUM_OUTPUTS_PER_WG][NUM_WORKERS_PER_OUTPUT];
46+
47+
const float SMALL_SCALE_THRESHOLD = 6.1e-5;
48+
49+
void calculate_scale_and_zero_point(
50+
float min_val,
51+
float max_val,
52+
int qmin,
53+
int qmax,
54+
out float scale,
55+
out int8_t zero_point) {
56+
57+
// Extend the [min, max] interval to ensure it contains 0
58+
min_val = min(min_val, 0.0);
59+
max_val = max(max_val, 0.0);
60+
61+
// Calculate scale
62+
scale = (max_val - min_val) / float(qmax - qmin);
63+
64+
// Handle special cases for scale
65+
if (scale == 0.0 || isinf(1.0 / scale)) {
66+
scale = 0.1;
67+
}
68+
69+
// Cut off small scale
70+
if (scale < SMALL_SCALE_THRESHOLD) {
71+
float org_scale = scale;
72+
scale = SMALL_SCALE_THRESHOLD;
73+
// Adjust the min and max based on the new scale
74+
if (min_val == 0.0) {
75+
max_val = SMALL_SCALE_THRESHOLD * float(qmax - qmin);
76+
} else if (max_val == 0.0) {
77+
min_val = -SMALL_SCALE_THRESHOLD * float(qmax - qmin);
78+
} else {
79+
float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
80+
min_val *= amplifier;
81+
max_val *= amplifier;
82+
}
83+
}
84+
85+
// Zero-point computation
86+
float zero_point_from_min = float(qmin) - min_val / scale;
87+
float zero_point_from_max = float(qmax) - max_val / scale;
88+
float zero_point_from_min_error = abs(float(qmin)) - abs(min_val / scale);
89+
float zero_point_from_max_error = abs(float(qmax)) - abs(max_val / scale);
90+
91+
float initial_zero_point = zero_point_from_min_error < zero_point_from_max_error
92+
? zero_point_from_min
93+
: zero_point_from_max;
94+
95+
// Nudge zero point to be an integer
96+
int nudged_zero_point;
97+
if (initial_zero_point < float(qmin)) {
98+
nudged_zero_point = qmin;
99+
} else if (initial_zero_point > float(qmax)) {
100+
nudged_zero_point = qmax;
101+
} else {
102+
nudged_zero_point = int(round(initial_zero_point));
103+
}
104+
105+
zero_point = int8_t(nudged_zero_point);
106+
}
107+
108+
#ifdef USING_BUFFER
109+
110+
VEC4_T load_input_x4(const int x4, const int y, const int ntexels_x) {
111+
return t_input[(y * ntexels_x) + x4];
112+
}
113+
114+
#else // USING_TEXTURE
115+
116+
VEC4_T load_input_x4(const int x4, const int y, const int ntexels_x) {
117+
return texelFetch(t_input, ivec3(x4, y, 0), 0);
118+
}
119+
120+
#endif // USING_BUFFER
121+
122+
void main() {
123+
const int worker_id = int(gl_LocalInvocationID.x);
124+
const int output_id = int(gl_LocalInvocationID.y);
125+
126+
const int output_y = int(gl_GlobalInvocationID.y);
127+
128+
if (output_y >= input_sizes.y) {
129+
return;
130+
}
131+
132+
// Input is 2D tensor (height x width), width-packed
133+
// Each channel corresponds to a row in the tensor
134+
const int X4 = div_4(input_sizes.x);
135+
136+
// Initialize thread-local min/max
137+
float local_min = 1e30;
138+
float local_max = -1e30;
139+
140+
// Each thread processes elements along their assigned output_id with stride
141+
// NUM_WORKERS_PER_OUTPUT
142+
for (int x4 = worker_id; x4 < X4; x4 += NUM_WORKERS_PER_OUTPUT) {
143+
VEC4_T in_texel = load_input_x4(x4, output_y, X4);
144+
for (int i = 0; i < 4; i++) {
145+
local_min = min(local_min, in_texel[i]);
146+
local_max = max(local_max, in_texel[i]);
147+
}
148+
}
149+
150+
// Store thread-local results in shared memory
151+
shared_min[output_id][worker_id] = local_min;
152+
shared_max[output_id][worker_id] = local_max;
153+
154+
memoryBarrierShared();
155+
barrier();
156+
157+
// Tree reduction to compute the overall result
158+
for (int i = NUM_WORKERS_PER_OUTPUT / 2; i > 0; i >>= 1) {
159+
if (worker_id < i) {
160+
shared_min[output_id][worker_id] = min(
161+
shared_min[output_id][worker_id],
162+
shared_min[output_id][worker_id + i]);
163+
shared_max[output_id][worker_id] = max(
164+
shared_max[output_id][worker_id],
165+
shared_max[output_id][worker_id + i]);
166+
}
167+
memoryBarrierShared();
168+
barrier();
169+
}
170+
171+
// Only first thread will write out result
172+
if (worker_id == 0) {
173+
local_min = shared_min[output_id][0];
174+
local_max = shared_max[output_id][0];
175+
176+
float scale;
177+
int8_t zero_point;
178+
calculate_scale_and_zero_point(
179+
local_min, local_max, quant_min, quant_max, scale, zero_point);
180+
181+
t_scales[output_y] = scale;
182+
t_zps[output_y] = zero_point;
183+
}
184+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
choose_qparams_per_row:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
STORAGE: texture3d
11+
NUM_OUTPUTS_PER_WG: 1
12+
NUM_WORKERS_PER_OUTPUT: 64
13+
generate_variant_forall:
14+
STORAGE:
15+
- VALUE: texture3d
16+
- VALUE: buffer
17+
DTYPE:
18+
- VALUE: float
19+
shader_variants:
20+
- NAME: choose_qparams_per_row_o1w64
21+
- NAME: choose_qparams_per_row_o4w16
22+
NUM_OUTPUTS_PER_WG: 4
23+
NUM_WORKERS_PER_OUTPUT: 16

0 commit comments

Comments
 (0)