
Commit 6eb44b5

morelos committed
[ET-VK][Ops] torchao.choose_qparams_affine vulkan impl and shader
Pull Request resolved: #12577

# Changes

* Implement the `torchao.choose_qparams_affine` operator in the Vulkan backend with comprehensive buffer storage support
* Add block-wise quantization parameter computation in the `choose_qparams_buffer.glsl` shader for configurable tensor block analysis
* Extend the quantization parameter infrastructure in `ChooseQParams.cpp` to handle affine transformations with configurable block sizes and multiple mapping types
* Support three quantization mapping strategies (ASYMMETRIC, SYMMETRIC, and SYMMETRIC_NO_CLIPPING_ERR) for optimal parameter selection
* Consolidate the scale and zero-point selection logic shared between the affine cases and the regular `quantized_decomposed` cases

BE: Improved the documentation in the shader logic to be more detailed and clear.

# Motivation

The existing Vulkan quantization infrastructure lacked support for the `torchao.choose_qparams_affine` operator, which is essential for computing optimal quantization parameters in dynamic quantization workflows. The `choose_qparams_affine` operator provides flexible block-wise parameter computation that analyzes statistical distributions within tensor blocks, enabling:

* **Block-wise Parameter Computation**: Analyzes configurable tensor blocks to compute optimal scale and zero-point values, improving quantization accuracy for heterogeneous data distributions
* **Multiple Mapping Types**: Supports ASYMMETRIC, SYMMETRIC, and SYMMETRIC_NO_CLIPPING_ERR quantization strategies for different precision-performance trade-offs

# Operator Description

The `choose_qparams_affine` operator computes optimal quantization parameters (scale and zero_point) from floating-point tensor blocks using statistical analysis of data distributions. Block-wise parameter computation divides tensors into blocks and analyzes each block independently to determine the best quantization mapping for subsequent quantization operations.

The parameter calculation varies by mapping type:

- **ASYMMETRIC**: `scale = (max - min) / (quant_max - quant_min)`, `zero_point = quant_min - round(min / scale)`
- **SYMMETRIC**: `scale = max_abs / ((quant_max - quant_min) / 2)`, `zero_point = midpoint`
- **SYMMETRIC_NO_CLIPPING_ERR**: `scale = max(abs(min)/abs(quant_min), max/quant_max)`, `zero_point = midpoint`

**Storage Requirements**: Input tensors must be floating-point (kFloat) with a width-packed layout. Output scale/zero_point tensors use buffer storage.

NOTE: A texture storage implementation is not supported due to the complexity of block-wise coordinate mapping in 3D texture space. It will likely be necessary for better efficiency in the future.
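To make the three mapping strategies above concrete, here is a minimal Python sketch that mirrors the formulas. It is illustrative only and not part of this change; the helper name `choose_qparams_ref` and the float32 `eps` default are assumptions.

```python
def choose_qparams_ref(min_val, max_val, quant_min, quant_max,
                       mapping_type, eps=1.1920928955078125e-07):
    """Illustrative reference for the three mapping strategies."""
    # Always include zero in the representable range.
    lo, hi = min(min_val, 0.0), max(max_val, 0.0)

    if mapping_type == 0:  # ASYMMETRIC
        scale = max((hi - lo) / (quant_max - quant_min), eps)
        # zero_point = quant_min - round(min / scale), clamped to [qmin, qmax]
        zero_point = int(min(max(quant_min - round(lo / scale), quant_min), quant_max))
    else:
        if mapping_type == 1:  # SYMMETRIC
            max_abs = max(abs(lo), abs(hi))
            scale = max_abs / ((quant_max - quant_min) / 2)
        else:  # SYMMETRIC_NO_CLIPPING_ERR
            # Separate scales for the negative and positive halves; take the max.
            scale = max(abs(lo) / max(abs(quant_min), 1),
                        hi / max(quant_max, 1))
        scale = max(scale, eps)
        zero_point = (quant_max + quant_min + 1) // 2  # midpoint
    return scale, zero_point

# e.g. int8 asymmetric parameters for data spanning [-0.5, 1.5]:
# choose_qparams_ref(-0.5, 1.5, -128, 127, mapping_type=0)  # -> (~0.00784, -64)
```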
# Block-wise Parameter Computation Implementation

Block-wise parameter computation enables fine-grained quantization analysis by dividing tensors into blocks and computing separate scale/zero_point parameters for each block.

The implementation uses several key data structures computed in `ChooseQParams.cpp`:

* **`block_size_vec`**: WHCN-ordered block dimensions converted from the PyTorch NCHW layout (e.g., [3,3,2,1] for 3×3×2×1 blocks)
* **`tensor_size_whcn`**: Input tensor dimensions converted to WHCN layout using `utils::make_whcn_ivec4()`
* **`num_blocks_vec`**: Number of blocks per dimension, calculated as `ceil(tensor_size_whcn / block_size_vec)` to handle non-divisible dimensions
* **`block_stride_vec`**: Pre-computed linear strides for block grid indexing, `{1, #W, #W*#H, #W*#H*#C}`, enabling efficient block ID calculation
* **`mapping_type`**: Integer encoding of the quantization strategy (0=ASYMMETRIC, 1=SYMMETRIC, 2=SYMMETRIC_NO_CLIPPING_ERR)

The block coordinate calculation uses `block_coord = block_id_to_coord(block_id)`, which converts linear block IDs back to 4D WHCN coordinates, and then computes the element ranges `t0 = block_coord * blockSize` and `tEnd = t0 + blockSize` for nested loop iteration.

# Shader Algorithm Overview

## Buffer Storage Implementation (`choose_qparams_buffer.glsl`)

**Workgroup Configuration**:

- **Global WG Size**: `{nBlocks, 1u, 1u}`, where `nBlocks` is the total number of blocks, computed from `ceil(tensor_size / block_size)` in each dimension
- **Local WG Size**: `{1u, 1u, 1u}` (a single thread per block for simplicity, though this could be optimized for larger blocks)

**Block-wise Mode Algorithm**:

The shader uses a multi-level nested approach to process tensor blocks efficiently. Each thread is assigned multiple blocks using strided access: `for (uint block_id = gl_GlobalInvocationID.x; block_id < TOTAL_BLOCKS; block_id += STRIDE)`, where `STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x`.

For each assigned block, the algorithm performs several key steps (a NumPy sketch of the full block-wise scan appears at the end of this page):

**1. Block Coordinate Conversion**: The `block_id_to_coord(block_id)` function converts linear block IDs to 4D WHCN coordinates using modular arithmetic.

**2. Element Range Calculation**: Computes the inclusive start coordinate `t0 = bc * blockSize` and the exclusive end coordinate `tEnd = t0 + blockSize` to define the block's element boundaries in tensor space.

**3. Nested Loop Min/Max Scan**: Uses four nested loops to iterate through all elements within the block: `for (int n = t0.w; n < tEnd.w; ++n) for (int c = t0.z; c < tEnd.z; ++c) for (int h = t0.y; h < tEnd.y; ++h) for (int w = t0.x; w < tEnd.x; ++w)`. Each element is accessed using `tidx_to_bufi(ivec4(w,h,c,n), t_in_strides)` to convert 4D tensor coordinates to linear buffer indices with proper stride handling.

**4. Parameter Calculation**: Calls `calc_scale_zp(lo, hi, quant_min, quant_max, mapping_type, eps, scale, zp)`, which implements the three mapping strategies:

* **ASYMMETRIC (mapping_type=0)**: Maps the full range [min, max] to [quant_min, quant_max], preserving the data distribution
* **SYMMETRIC (mapping_type=1)**: Centers around zero using `max_abs = max(abs(min), abs(max))` for balanced quantization
* **SYMMETRIC_NO_CLIPPING_ERR (mapping_type=2)**: Computes separate scales for the positive and negative ranges and uses the maximum to prevent clipping

**Future Improvements**: Implement workgroup-level reduction for large blocks, optimize memory access patterns for better cache utilization, and explore a texture storage implementation with simplified block alignment constraints.

ghstack-source-id: 299473615
@exported-using-ghexport

Differential Revision: [D78436638](https://our.internmc.facebook.com/intern/diff/D78436638/)
1 parent bd6bbe0 commit 6eb44b5

File tree

8 files changed: +1239 -372 lines changed

backends/vulkan/op_registry.py (23 additions, 36 deletions)

```diff
@@ -245,9 +245,9 @@ def register_ephemeral_op(features: OpFeatures):
 
 @update_features(
     [
-        exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
         exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
         exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
+        exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
         exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
@@ -276,14 +276,32 @@ def register_quantization_op(features: OpFeatures):
     [
         exir_ops.edge.torchao.quantize_affine.default,
         exir_ops.edge.torchao.dequantize_affine.default,
+    ]
+)
+def register_affine_quantization_op(features: OpFeatures):
+    features.texture_impl = TextureImplFeatures(
+        uses_axis_map=False,
+        valid_packed_dims={PackedDim.WIDTH},
+    )
+    features.buffer_impl = True
+    features.resize_fn = True
+    features.optimal_storage = VkStorageType.TEXTURE_3D
+    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
+    features.handles_own_prepacking = True
+
+    return features
+
+
+@update_features(
+    [
         exir_ops.edge.torchao.choose_qparams_affine.default,
     ]
 )
-def register_torchao_quantization_op(features: OpFeatures):
-    # TorchAO quantization operators - default to per-tensor behavior
-    # Same features as standard quantization ops
+def register_choose_qparams_affine_op(features: OpFeatures):
+    # Currently only created a rudimentary buffer implementation for choose_qparams_affine
+    # since the reduction logic for blocks in texture3d is not trivial to implement in vulkan.
     features.texture_impl = TextureImplFeatures(
-        uses_axis_map=True,
+        uses_axis_map=False,
         valid_packed_dims={
             PackedDim.WIDTH,
         },
@@ -292,37 +310,6 @@ def register_torchao_quantization_op(features: OpFeatures):
     features.resize_fn = True
     features.optimal_storage = VkStorageType.BUFFER
 
-    def check_torchao_quantization_node(node: torch.fx.Node) -> bool:
-        # Only per-tensor quantization is supported by the Vulkan backend.
-        if len(node.args) < 2:
-            return False
-
-        block_size = node.args[1]
-
-        if not isinstance(block_size, (list, tuple)):
-            return False
-
-        input_arg = node.args[0]
-        if not isinstance(input_arg, torch.fx.Node):
-            return False
-
-        input_tensor = input_arg.meta.get("val", None)
-        if not isinstance(input_tensor, FakeTensor):
-            return False
-
-        input_shape = list(input_tensor.shape)
-
-        if len(block_size) != len(input_shape):
-            return False
-
-        # Check if block_size matches input_shape exactly (per-tensor quantization)
-        for i in range(len(block_size)):
-            if block_size[i] != input_shape[i]:
-                return False
-
-        return True
-
-    features.check_node_fn = check_torchao_quantization_node
     return features
 
 
```
backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh (54 additions, 46 deletions)

```diff
@@ -9,59 +9,67 @@
 #ifndef CHOOSE_QPARAMS_GLSLH
 #define CHOOSE_QPARAMS_GLSLH
 
-// Calculate scale and zero point from min and max values
-void calculate_scale_and_zero_point(
-    float min_val,
-    float max_val,
-    int qmin,
-    int qmax,
-    float eps_threshold,
-    out float scale_val,
-    out int zero_point_val) {
-  // ensure we have zero included in our range
-  min_val = min(min_val, 0.0);
-  max_val = max(max_val, 0.0);
+// mapping_type : 0 = ASYM, 1 = SYM, 2 = SYM_NO_CLIP
+void calc_scale_zp(
+    float lo, float hi,
+    int qmin, int qmax,
+    int mapping_type,
+    float eps,
+    out float scale, out int zp) {
+  // Handle case where lo and hi are +/-INF (no valid values found)
+  if (isinf(lo) || isinf(hi)) {
+    lo = 0.0;
+    hi = 0.0;
+  }
 
-  scale_val = (max_val - min_val) / float(qmax - qmin);
+  float minv = min(lo, 0.0);
+  float maxv = max(hi, 0.0);
 
-  // Handle zero or very small scale
-  if (scale_val == 0.0 || isinf(1.0 / scale_val)) {
-    scale_val = 0.1;
-  }
+  if (mapping_type == 0) { // asymmetric
+    scale = (maxv - minv) / float(qmax - qmin);
+
+    // Handle zero or very small scale
+    if (scale == 0.0 || isinf(1.0/scale)) {
+      scale = eps;
+    }
 
-  // Cut off small scale using the provided eps threshold
-  if (scale_val < eps_threshold) {
-    float org_scale = scale_val;
-    scale_val = eps_threshold;
+    if (scale < eps) {
+      float org_scale = scale;
+      scale = eps;
 
-    // Adjust min and max based on new scale
-    if (min_val == 0.0) {
-      max_val = eps_threshold * float(qmax - qmin);
-    } else if (max_val == 0.0) {
-      min_val = -eps_threshold * float(qmax - qmin);
-    } else {
-      float amplifier = eps_threshold / org_scale;
-      min_val *= amplifier;
-      max_val *= amplifier;
+      // Adjust min and max based on new scale to maintain proper quantization range
+      if (minv == 0.0) {
+        maxv = eps * float(qmax - qmin);
+      } else if (maxv == 0.0) {
+        minv = -eps * float(qmax - qmin);
+      } else {
+        float amplifier = eps / org_scale;
+        minv *= amplifier;
+        maxv *= amplifier;
+      }
+    }
+
+    // Calculate zero_point (matching reference implementation)
+    float initial_zero_point = float(qmin) - round(minv / scale);
+    zp = int(clamp(initial_zero_point, float(qmin), float(qmax)));
+  } else { // symmetric -- centred
+    float scale_sym;
+    if (mapping_type == 1) { // SYM
+      float M = max(abs(minv), abs(maxv));
+      scale_sym = M / (float(qmax - qmin) * 0.5);
+    } else { // SYM_NO_CLIP
+      float smin = abs(minv) / max(abs(float(qmin)), 1.0); // Avoid division by zero
+      float smax = maxv / max(float(qmax), 1.0); // Avoid division by zero
+      scale_sym = max(smin, smax);
     }
-  }
 
-  // Calculate zero point
-  float zero_point_from_min = float(qmin) - min_val / scale_val;
-  float zero_point_from_max = float(qmax) - max_val / scale_val;
-  float zero_point_from_min_error = abs(float(qmin)) - abs(min_val / scale_val);
-  float zero_point_from_max_error = abs(float(qmax)) - abs(max_val / scale_val);
-  float initial_zero_point = zero_point_from_min_error < zero_point_from_max_error
-      ? zero_point_from_min
-      : zero_point_from_max;
+    // Handle zero or very small scale
+    if (scale_sym == 0.0 || isinf(1.0/scale_sym)) {
+      scale_sym = eps;
+    }
 
-  // Nudge zero point to integer
-  if (initial_zero_point < float(qmin)) {
-    zero_point_val = qmin;
-  } else if (initial_zero_point > float(qmax)) {
-    zero_point_val = qmax;
-  } else {
-    zero_point_val = int(round(initial_zero_point));
+    scale = max(scale_sym, eps);
+    zp = int((qmax + qmin + 1) >> 1); // mid-point – always fits
   }
 }
 
```
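To tie the pieces together, here is a minimal NumPy sketch of the block-wise decomposition and per-block min/max scan that the commit message describes for `choose_qparams_buffer.glsl` (that shader is not shown in this excerpt). This is illustrative only: the function and variable names are hypothetical, the clamp for partial edge blocks is an assumption, and each block's (lo, hi) pair would feed `calc_scale_zp` above.

```python
import numpy as np

def blockwise_min_max(x_whcn, block_size):
    """Illustrative per-block min/max scan over a 4D WHCN tensor."""
    tensor_size = np.array(x_whcn.shape)      # tensor_size_whcn
    bs = np.array(block_size)                 # block_size_vec
    num_blocks = -(-tensor_size // bs)        # ceil division -> num_blocks_vec
    # block_stride_vec: {1, #W, #W*#H, #W*#H*#C}
    strides = np.array([1,
                        num_blocks[0],
                        num_blocks[0] * num_blocks[1],
                        num_blocks[0] * num_blocks[1] * num_blocks[2]])

    def block_id_to_coord(block_id):
        # Linear block ID -> 4D WHCN block coordinate via modular arithmetic.
        return np.array([(block_id // strides[d]) % num_blocks[d] for d in range(4)])

    results = []
    for block_id in range(int(num_blocks.prod())):
        t0 = block_id_to_coord(block_id) * bs        # inclusive start coordinate
        t_end = np.minimum(t0 + bs, tensor_size)     # exclusive end, clamped (assumption)
        block = x_whcn[t0[0]:t_end[0], t0[1]:t_end[1],
                       t0[2]:t_end[2], t0[3]:t_end[3]]
        results.append((block.min(), block.max()))   # lo/hi for calc_scale_zp
    return results

# e.g. a 6x6x2x1 WHCN tensor scanned in 3x3x2x1 blocks yields 4 (lo, hi) pairs:
print(blockwise_min_max(np.random.rand(6, 6, 2, 1), (3, 3, 2, 1)))
```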