Commit 587db7d
morelos
Update on "[ET-VK][Ops] torchao.quantize_affine vulkan impl and shader and cleanup"
# Changes
* Implement `torchao.quantize_affine` operator in Vulkan backend with comprehensive texture and buffer storage support
* Add block-wise quantization mode in `quantize_texture.glsl` and `quantize_buffer.glsl` shaders for configurable tensor block quantization
* Introduce comprehensive test suite in `affine_test.cpp` with multi-dimensional tensor validation and reference implementation
* Extend quantization infrastructure in `Quantize.cpp` to handle affine transformations with configurable block sizes and quantization parameters
BE: Improved the documentation in the shader logic which is more detailed and clear
NOTE: I delegated the quantize_affine and future affine operators through a new custom test file denoted as `affine_test.cpp` as the other quantization testing framework was getting a little large, and it makes more sense to separate the namespace between torchao and quantized_decomposed. I believe the _decomposed namespace is getting phased out in favor of this affine operator so deprecation will be easier in the future.
# Motivation
The existing Vulkan quantization infrastructure lacked support for the `torchao.quantize_affine` operator, which is essential for enabling dynamic quantization efficiently. The `quantize_affine` operator provides flexible block-wise quantization that allows different scale and zero-point values for tensor blocks, enabling:
* **Block-wise Quantization**: Applies quantization parameters to configurable tensor blocks rather than entire tensors, improving quantization accuracy for heterogeneous data distributions
* **Affine Transformation**: Uses the formula `qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max)` for precise floating-point to integer mapping
# Operator Description
The `quantize_affine` operator converts floating-point tensor values to n-bit integer representations using pre-computed quantization parameters (scale and zero_point) applied to configurable tensor blocks. Block-wise quantization divides tensors into blocks and applies separate quantization parameters to each block, allowing fine-grained control over quantization precision.
The quantization formula is: `qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max)`
**Storage Requirements**: Scale and zero_point tensors must use buffer storage with width-packed layout. Input/output tensors support both buffer and texture storage with standard axis mapping.
# Block-wise Quantization Implementation
Block-wise quantization enables fine-grained quantization by dividing tensors into blocks and applying separate quantization parameters to each block. The implementation uses several key data structures computed in `Quantize.cpp`:
* **`block_size_vec`**: WHCN-ordered block dimensions converted from PyTorch NCHW layout (e.g., [3,3,2,1] for 3×3×2×1 blocks)
* **`tensor_size_whcn`**: Input tensor dimensions converted to WHCN layout using `utils::make_whcn_ivec4()`
* **`num_blocks_vec`**: Number of blocks per dimension calculated as `tensor_size_whcn / block_size_vec`
* **`block_stride_vec`**: Pre-computed linear strides for block grid indexing `{1, #W, #W*#H, #W*#H*#C}` to enable efficient block ID calculation
The block coordinate calculation uses: `bcoord = tidx / blockSize` where `tidx` is the tensor coordinate in WHCN layout, then the linear block ID is computed as: `block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w`
# Shader Algorithm Overview
## Texture Storage Implementation (`quantize_texture.glsl`)
**Workgroup Configuration**:
- **Global WG Size**: Default sizing based on texture dimensions
- **Local WG Size**: Default with special handling for batch dimension quantization (Z dimension set to 1 for proper workgroup dispatching when `global_workgroup_size[2] > 1`)
**Block-wise Mode Algorithm**:
The shader processes 3D texture positions where each position represents a texel containing 4 width-packed components. For each texel at position `pos`, it calculates a base tensor index `base_tidx = ivec4(pos.x * 4, pos.y, pos.z, 0)` to account for width-packing.
For each of the 4 components in the texel, it computes the actual tensor coordinate: `tidx = ivec4(base_tidx.x + i, base_tidx.y, (foldedZ % C_total), (foldedZ / C_total))` where `foldedZ = pos.z` handles batch-channel folding in 4D tensors and `C_total = numBlocks.z * blockSize.z` represents the total channel dimension.
The block coordinate is calculated using integer division: `bcoord = tidx / blockSize`, then the linear block ID uses pre-computed strides: `block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w`.
Each component is quantized using its corresponding block's parameters: `qvalue = quantize_val(value, t_scale[block_id], t_zero_point[block_id])` and written to the output texel.
## Buffer Storage Implementation (`quantize_buffer.glsl`)
**Workgroup Configuration**:
- **Global WG Size**: Default sizing based on buffer element count
- **Local WG Size**: Default sizing without special constraints
**Block-wise Mode Algorithm**:
The shader processes linear buffer indices using `gl_GlobalInvocationID.x` as the output buffer index. It converts this to tensor coordinates using `bufi_to_tidx(out_bufi, t_out_strides, out_dim_order)` which handles the buffer-to-tensor index mapping with proper stride calculations.
For each element, it computes the block coordinate directly: `bcoord = out_tidx / blockSize` where `out_tidx` is the 4D tensor coordinate in WHCN layout. The linear block ID calculation uses the same pre-computed stride approach: `block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w`.
The element value is loaded using the corresponding input buffer index: `value = t_in[in_bufi]` where `in_bufi = tidx_to_bufi(out_tidx, t_in_strides)`. Quantization applies the block-specific parameters: `qvalue = quantize_val(value, t_scale[block_id], t_zero_point[block_id])`.
**Future Improvements**: Dynamic workgroup sizing based on block dimensions, there is likely a better method to making it better than what it is currently.
Differential Revision: [D78302195](https://our.internmc.facebook.com/intern/diff/D78302195/)
cc SS-JIA manuelcandales cbilgin
[ghstack-poisoned]File tree
42 files changed
+1088
-105
lines changed- .github/workflows
- backends
- arm
- _passes
- test
- ops
- tester
- cadence
- aot
- tests
- hifi/operators
- tests
- nxp/tests
- ir/converter/node_converter
- vulkan/runtime
- api
- graph
- ops/glsl
- vk_api
- xnnpack
- partition/config
- test
- ops
- passes
- examples/models/llama
- exir/program
- extension
- apple/ExecuTorch/Exported
- llm
- runner/io_manager
- test
- scripts
- tools/cmake
Some content is hidden
Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
42 files changed
+1088
-105
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
302 | 302 | | |
303 | 303 | | |
304 | 304 | | |
| 305 | + | |
| 306 | + | |
| 307 | + | |
| 308 | + | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
| 314 | + | |
| 315 | + | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
| 321 | + | |
| 322 | + | |
| 323 | + | |
| 324 | + | |
| 325 | + | |
| 326 | + | |
| 327 | + | |
| 328 | + | |
| 329 | + | |
| 330 | + | |
| 331 | + | |
| 332 | + | |
| 333 | + | |
| 334 | + | |
| 335 | + | |
305 | 336 | | |
306 | 337 | | |
307 | 338 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
15 | 15 | | |
16 | 16 | | |
17 | 17 | | |
18 | | - | |
| 18 | + | |
| 19 | + | |
19 | 20 | | |
20 | 21 | | |
21 | 22 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
85 | 85 | | |
86 | 86 | | |
87 | 87 | | |
| 88 | + | |
| 89 | + | |
88 | 90 | | |
89 | 91 | | |
90 | 92 | | |
91 | | - | |
92 | | - | |
93 | | - | |
94 | 93 | | |
95 | 94 | | |
96 | 95 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
205 | 205 | | |
206 | 206 | | |
207 | 207 | | |
208 | | - | |
209 | 208 | | |
210 | 209 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
43 | 43 | | |
44 | 44 | | |
45 | 45 | | |
| 46 | + | |
46 | 47 | | |
47 | 48 | | |
48 | 49 | | |
| |||
332 | 333 | | |
333 | 334 | | |
334 | 335 | | |
| 336 | + | |
| 337 | + | |
335 | 338 | | |
336 | 339 | | |
337 | 340 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
861 | 861 | | |
862 | 862 | | |
863 | 863 | | |
| 864 | + | |
| 865 | + | |
| 866 | + | |
864 | 867 | | |
865 | 868 | | |
866 | 869 | | |
867 | 870 | | |
868 | 871 | | |
869 | 872 | | |
870 | | - | |
871 | | - | |
872 | | - | |
873 | | - | |
874 | | - | |
875 | | - | |
876 | 873 | | |
877 | 874 | | |
878 | 875 | | |
| |||
887 | 884 | | |
888 | 885 | | |
889 | 886 | | |
| 887 | + | |
890 | 888 | | |
891 | 889 | | |
892 | 890 | | |
| |||
900 | 898 | | |
901 | 899 | | |
902 | 900 | | |
903 | | - | |
| 901 | + | |
904 | 902 | | |
905 | 903 | | |
906 | 904 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
539 | 539 | | |
540 | 540 | | |
541 | 541 | | |
| 542 | + | |
542 | 543 | | |
543 | 544 | | |
| 545 | + | |
544 | 546 | | |
545 | 547 | | |
546 | 548 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
19 | 19 | | |
20 | 20 | | |
21 | 21 | | |
22 | | - | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
23 | 26 | | |
24 | 27 | | |
25 | 28 | | |
| |||
95 | 98 | | |
96 | 99 | | |
97 | 100 | | |
98 | | - | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
99 | 104 | | |
100 | 105 | | |
101 | 106 | | |
| |||
169 | 174 | | |
170 | 175 | | |
171 | 176 | | |
172 | | - | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
173 | 180 | | |
174 | 181 | | |
175 | 182 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
12 | 12 | | |
13 | 13 | | |
14 | 14 | | |
15 | | - | |
16 | 15 | | |
17 | 16 | | |
18 | 17 | | |
| |||
30 | 29 | | |
31 | 30 | | |
32 | 31 | | |
33 | | - | |
34 | 32 | | |
35 | 33 | | |
36 | 34 | | |
| |||
178 | 176 | | |
179 | 177 | | |
180 | 178 | | |
181 | | - | |
182 | | - | |
183 | | - | |
184 | | - | |
185 | | - | |
186 | | - | |
187 | | - | |
188 | | - | |
189 | | - | |
190 | | - | |
191 | | - | |
192 | | - | |
193 | | - | |
194 | | - | |
195 | | - | |
196 | | - | |
197 | | - | |
198 | | - | |
199 | | - | |
200 | | - | |
201 | | - | |
202 | | - | |
203 | | - | |
204 | | - | |
205 | | - | |
206 | | - | |
207 | | - | |
208 | | - | |
209 | | - | |
210 | | - | |
211 | | - | |
212 | | - | |
213 | | - | |
214 | | - | |
215 | | - | |
216 | | - | |
217 | | - | |
218 | 179 | | |
219 | 180 | | |
220 | 181 | | |
| |||
0 commit comments