Commit c3d2ccd
morelos
Update on "[ET-VK][Ops] torchao.quantize_affine vulkan impl and shader and cleanup"
# Changes
* Implement `torchao.quantize_affine` operator in Vulkan backend with comprehensive texture and buffer storage support
* Add block-wise quantization mode in `quantize_texture.glsl` and `quantize_buffer.glsl` shaders for configurable tensor block quantization
* Introduce comprehensive test suite in `affine_test.cpp` with multi-dimensional tensor validation and reference implementation
* Extend quantization infrastructure in `Quantize.cpp` to handle affine transformations with configurable block sizes and quantization parameters
BE: Improved the documentation in the shader logic which is more detailed and clear
NOTE: I delegated the quantize_affine and future affine operators through a new custom test file denoted as `affine_test.cpp` as the other quantization testing framework was getting a little large, and it makes more sense to separate the namespace between torchao and quantized_decomposed. I believe the _decomposed namespace is getting phased out in favor of this affine operator so deprecation will be easier in the future.
# Motivation
The existing Vulkan quantization infrastructure lacked support for the `torchao.quantize_affine` operator, which is essential for enabling dynamic quantization efficiently. The `quantize_affine` operator provides flexible block-wise quantization that allows different scale and zero-point values for tensor blocks, enabling:
* **Block-wise Quantization**: Applies quantization parameters to configurable tensor blocks rather than entire tensors, improving quantization accuracy for heterogeneous data distributions
* **Affine Transformation**: Uses the formula `qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max)` for precise floating-point to integer mapping
# Operator Description
The `quantize_affine` operator converts floating-point tensor values to n-bit integer representations using pre-computed quantization parameters (scale and zero_point) applied to configurable tensor blocks. Block-wise quantization divides tensors into blocks and applies separate quantization parameters to each block, allowing fine-grained control over quantization precision.
The quantization formula is: `qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max)`
**Storage Requirements**: Scale and zero_point tensors must use buffer storage with width-packed layout. Input/output tensors support both buffer and texture storage with standard axis mapping.
# Block-wise Quantization Implementation
Block-wise quantization enables fine-grained quantization by dividing tensors into blocks and applying separate quantization parameters to each block. The implementation uses several key data structures computed in `Quantize.cpp`:
* **`block_size_vec`**: WHCN-ordered block dimensions converted from PyTorch NCHW layout (e.g., [3,3,2,1] for 3×3×2×1 blocks)
* **`tensor_size_whcn`**: Input tensor dimensions converted to WHCN layout using `utils::make_whcn_ivec4()`
* **`num_blocks_vec`**: Number of blocks per dimension calculated as `tensor_size_whcn / block_size_vec`
* **`block_stride_vec`**: Pre-computed linear strides for block grid indexing `{1, #W, #W*#H, #W*#H*#C}` to enable efficient block ID calculation
The block coordinate calculation uses: `bcoord = tidx / blockSize` where `tidx` is the tensor coordinate in WHCN layout, then the linear block ID is computed as: `block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w`
# Shader Algorithm Overview
## Texture Storage Implementation (`quantize_texture.glsl`)
**Workgroup Configuration**:
- **Global WG Size**: Default sizing based on texture dimensions
- **Local WG Size**: Default with special handling for batch dimension quantization (Z dimension set to 1 for proper workgroup dispatching when `global_workgroup_size[2] > 1`)
**Block-wise Mode Algorithm**:
The shader processes 3D texture positions where each position represents a texel containing 4 width-packed components. For each texel at position `pos`, it calculates a base tensor index `base_tidx = ivec4(pos.x * 4, pos.y, pos.z, 0)` to account for width-packing.
For each of the 4 components in the texel, it computes the actual tensor coordinate: `tidx = ivec4(base_tidx.x + i, base_tidx.y, (foldedZ % C_total), (foldedZ / C_total))` where `foldedZ = pos.z` handles batch-channel folding in 4D tensors and `C_total = numBlocks.z * blockSize.z` represents the total channel dimension.
The block coordinate is calculated using integer division: `bcoord = tidx / blockSize`, then the linear block ID uses pre-computed strides: `block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w`.
Each component is quantized using its corresponding block's parameters: `qvalue = quantize_val(value, t_scale[block_id], t_zero_point[block_id])` and written to the output texel.
## Buffer Storage Implementation (`quantize_buffer.glsl`)
**Workgroup Configuration**:
- **Global WG Size**: Default sizing based on buffer element count
- **Local WG Size**: Default sizing without special constraints
**Block-wise Mode Algorithm**:
The shader processes linear buffer indices using `gl_GlobalInvocationID.x` as the output buffer index. It converts this to tensor coordinates using `bufi_to_tidx(out_bufi, t_out_strides, out_dim_order)` which handles the buffer-to-tensor index mapping with proper stride calculations.
For each element, it computes the block coordinate directly: `bcoord = out_tidx / blockSize` where `out_tidx` is the 4D tensor coordinate in WHCN layout. The linear block ID calculation uses the same pre-computed stride approach: `block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w`.
The element value is loaded using the corresponding input buffer index: `value = t_in[in_bufi]` where `in_bufi = tidx_to_bufi(out_tidx, t_in_strides)`. Quantization applies the block-specific parameters: `qvalue = quantize_val(value, t_scale[block_id], t_zero_point[block_id])`.
**Future Improvements**: Dynamic workgroup sizing based on block dimensions, there is likely a better method to making it better than what it is currently.
Differential Revision: [D78302195](https://our.internmc.facebook.com/intern/diff/D78302195/)
cc SS-JIA manuelcandales cbilgin
[ghstack-poisoned]File tree
37 files changed
+2893
-69
lines changed- backends
- nxp
- qualcomm
- tests
- vulkan
- quantizer
- runtime/graph
- ops
- glsl
- impl
- test/op_tests
- docs/source
- examples
- apple/coreml/llama
- models/llama
- qualcomm
- oss_scripts/t5
- runner
- scripts
- extension/llm/export
- scripts
- tools/cmake/preset
37 files changed
+2893
-69
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
48 | 48 | | |
49 | 49 | | |
50 | 50 | | |
51 | | - | |
52 | | - | |
53 | 51 | | |
54 | 52 | | |
55 | 53 | | |
| |||
82 | 80 | | |
83 | 81 | | |
84 | 82 | | |
| 83 | + | |
85 | 84 | | |
86 | 85 | | |
87 | 86 | | |
| |||
97 | 96 | | |
98 | 97 | | |
99 | 98 | | |
100 | | - | |
101 | | - | |
102 | | - | |
103 | | - | |
104 | | - | |
105 | 99 | | |
106 | 100 | | |
107 | 101 | | |
| |||
750 | 744 | | |
751 | 745 | | |
752 | 746 | | |
| 747 | + | |
| 748 | + | |
| 749 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
8 | 8 | | |
9 | 9 | | |
10 | 10 | | |
11 | | - | |
| 11 | + | |
12 | 12 | | |
13 | 13 | | |
14 | 14 | | |
| |||
25 | 25 | | |
26 | 26 | | |
27 | 27 | | |
28 | | - | |
| 28 | + | |
29 | 29 | | |
30 | 30 | | |
31 | 31 | | |
| |||
42 | 42 | | |
43 | 43 | | |
44 | 44 | | |
45 | | - | |
| 45 | + | |
46 | 46 | | |
47 | 47 | | |
48 | 48 | | |
| |||
59 | 59 | | |
60 | 60 | | |
61 | 61 | | |
62 | | - | |
| 62 | + | |
63 | 63 | | |
64 | 64 | | |
65 | 65 | | |
| |||
88 | 88 | | |
89 | 89 | | |
90 | 90 | | |
91 | | - | |
92 | | - | |
93 | | - | |
| 91 | + | |
94 | 92 | | |
95 | 93 | | |
96 | 94 | | |
97 | 95 | | |
98 | 96 | | |
99 | 97 | | |
100 | 98 | | |
101 | | - | |
102 | | - | |
103 | | - | |
104 | | - | |
105 | | - | |
| 99 | + | |
106 | 100 | | |
107 | 101 | | |
108 | 102 | | |
109 | 103 | | |
110 | | - | |
111 | | - | |
112 | | - | |
113 | | - | |
| 104 | + | |
| 105 | + | |
114 | 106 | | |
115 | 107 | | |
116 | 108 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
174 | 174 | | |
175 | 175 | | |
176 | 176 | | |
177 | | - | |
| 177 | + | |
| 178 | + | |
178 | 179 | | |
179 | 180 | | |
180 | 181 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
178 | 178 | | |
179 | 179 | | |
180 | 180 | | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
181 | 186 | | |
182 | 187 | | |
183 | 188 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3384 | 3384 | | |
3385 | 3385 | | |
3386 | 3386 | | |
| 3387 | + | |
| 3388 | + | |
| 3389 | + | |
| 3390 | + | |
| 3391 | + | |
| 3392 | + | |
| 3393 | + | |
| 3394 | + | |
| 3395 | + | |
| 3396 | + | |
| 3397 | + | |
| 3398 | + | |
| 3399 | + | |
| 3400 | + | |
| 3401 | + | |
| 3402 | + | |
| 3403 | + | |
| 3404 | + | |
| 3405 | + | |
| 3406 | + | |
| 3407 | + | |
| 3408 | + | |
| 3409 | + | |
| 3410 | + | |
| 3411 | + | |
| 3412 | + | |
| 3413 | + | |
| 3414 | + | |
| 3415 | + | |
| 3416 | + | |
| 3417 | + | |
| 3418 | + | |
3387 | 3419 | | |
3388 | 3420 | | |
3389 | 3421 | | |
| |||
5022 | 5054 | | |
5023 | 5055 | | |
5024 | 5056 | | |
| 5057 | + | |
| 5058 | + | |
| 5059 | + | |
| 5060 | + | |
| 5061 | + | |
| 5062 | + | |
| 5063 | + | |
| 5064 | + | |
| 5065 | + | |
| 5066 | + | |
| 5067 | + | |
| 5068 | + | |
| 5069 | + | |
| 5070 | + | |
| 5071 | + | |
| 5072 | + | |
| 5073 | + | |
| 5074 | + | |
| 5075 | + | |
| 5076 | + | |
| 5077 | + | |
| 5078 | + | |
| 5079 | + | |
| 5080 | + | |
| 5081 | + | |
| 5082 | + | |
| 5083 | + | |
| 5084 | + | |
| 5085 | + | |
| 5086 | + | |
| 5087 | + | |
| 5088 | + | |
| 5089 | + | |
| 5090 | + | |
5025 | 5091 | | |
5026 | 5092 | | |
5027 | 5093 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
183 | 183 | | |
184 | 184 | | |
185 | 185 | | |
| 186 | + | |
186 | 187 | | |
187 | 188 | | |
188 | 189 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
693 | 693 | | |
694 | 694 | | |
695 | 695 | | |
| 696 | + | |
696 | 697 | | |
697 | 698 | | |
698 | 699 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3 | 3 | | |
4 | 4 | | |
5 | 5 | | |
6 | | - | |
7 | | - | |
8 | | - | |
9 | | - | |
| 6 | + | |
| 7 | + | |
10 | 8 | | |
| 9 | + | |
11 | 10 | | |
12 | | - | |
13 | 11 | | |
14 | 12 | | |
15 | 13 | | |
16 | 14 | | |
17 | | - | |
18 | | - | |
19 | | - | |
20 | | - | |
| 15 | + | |
| 16 | + | |
21 | 17 | | |
22 | | - | |
23 | 18 | | |
24 | | - | |
25 | 19 | | |
26 | 20 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
273 | 273 | | |
274 | 274 | | |
275 | 275 | | |
| 276 | + | |
| 277 | + | |
| 278 | + | |
| 279 | + | |
| 280 | + | |
| 281 | + | |
| 282 | + | |
| 283 | + | |
276 | 284 | | |
277 | 285 | | |
278 | 286 | | |
| |||
Lines changed: 55 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
0 commit comments