Commit 7667af4
morelos
[ET-VK][Ops] torchao.quantize_affine vulkan impl and shader and cleanup
Pull Request resolved: #12575
# Changes
* Implement `torchao.quantize_affine` operator in Vulkan backend with comprehensive texture and buffer storage support
* Add block-wise quantization mode in `quantize_texture.glsl` and `quantize_buffer.glsl` shaders for configurable tensor block quantization
* Introduce comprehensive test suite in `affine_test.cpp` with multi-dimensional tensor validation and reference implementation
* Extend quantization infrastructure in `Quantize.cpp` to handle affine transformations with configurable block sizes and quantization parameters
BE: Improved the documentation of the shader logic, making it more detailed and clearer.
NOTE: I placed the tests for `quantize_affine` (and future affine operators) in a new test file, `affine_test.cpp`, because the other quantization testing framework was getting large and it makes more sense to separate the torchao and quantized_decomposed namespaces. I believe the _decomposed namespace is being phased out in favor of this affine operator, so this separation should make deprecation easier in the future.
# Motivation
The existing Vulkan quantization infrastructure lacked support for the `torchao.quantize_affine` operator, which is essential for enabling dynamic quantization efficiently. The `quantize_affine` operator provides flexible block-wise quantization that allows different scale and zero-point values for tensor blocks, enabling:
* **Block-wise Quantization**: Applies quantization parameters to configurable tensor blocks rather than entire tensors, improving quantization accuracy for heterogeneous data distributions
* **Affine Transformation**: Uses the formula `qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max)` for precise floating-point to integer mapping
# Operator Description
The `quantize_affine` operator converts floating-point tensor values to n-bit integer representations using pre-computed quantization parameters (scale and zero_point) applied to configurable tensor blocks. Block-wise quantization divides tensors into blocks and applies separate quantization parameters to each block, allowing fine-grained control over quantization precision.
The quantization formula is: `qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max)`
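As a minimal CPU-side sketch of the formula above (not the shader's actual `quantize_val`), the helper below applies the same steps in order. The name `quantize_val_ref` is hypothetical, and the rounding mode is an assumption: `std::nearbyint` under the default floating-point environment rounds halves to even, which may differ from the shader's `round()` on exact ties.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical reference helper, not part of the commit.
// Implements: qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max)
// Assumption: std::nearbyint (round-half-to-even by default) stands in for round().
inline int32_t quantize_val_ref(float value, float scale, int32_t zero_point,
                                int32_t quant_min, int32_t quant_max) {
  assert(scale > 0.0f && quant_min <= quant_max);
  const float rounded = std::nearbyint(value / scale);
  // Add the zero point in 64-bit so the sum cannot overflow before clamping.
  const int64_t shifted = static_cast<int64_t>(rounded) + zero_point;
  const int64_t clamped =
      std::max<int64_t>(quant_min, std::min<int64_t>(shifted, quant_max));
  return static_cast<int32_t>(clamped);
}
```

For example, with `scale = 0.5` and `zero_point = 0`, an input of `1.0` maps to `2`, while values far outside the representable range saturate at `quant_min` or `quant_max`.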
**Storage Requirements**: Scale and zero_point tensors must use buffer storage with width-packed layout. Input/output tensors support both buffer and texture storage with standard axis mapping.
# Block-wise Quantization Implementation
Block-wise quantization enables fine-grained quantization by dividing tensors into blocks and applying separate quantization parameters to each block. The implementation uses several key data structures computed in `Quantize.cpp`:
* **`block_size_vec`**: WHCN-ordered block dimensions converted from PyTorch NCHW layout (e.g., [3,3,2,1] for 3×3×2×1 blocks)
* **`tensor_size_whcn`**: Input tensor dimensions converted to WHCN layout using `utils::make_whcn_ivec4()`
* **`num_blocks_vec`**: Number of blocks per dimension calculated as `tensor_size_whcn / block_size_vec`
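* **`block_stride_vec`**: Pre-computed linear strides for block grid indexing, `{1, #W, #W*#H, #W*#H*#C}`, where `#W`, `#H`, and `#C` are the block counts along each dimension (the first three entries of `num_blocks_vec`), to enable efficient block ID calculation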
The block coordinate calculation uses: `bcoord = tidx / blockSize` where `tidx` is the tensor coordinate in WHCN layout, then the linear block ID is computed as: `block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w`
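The block bookkeeping described above can be sketched on the CPU as follows. This is an illustrative sketch rather than the code in `Quantize.cpp`: the `IVec4` alias and the function names are hypothetical, and it assumes every tensor dimension is an exact multiple of the corresponding block dimension.

```cpp
#include <array>
#include <cassert>

using IVec4 = std::array<int, 4>;  // WHCN order: {W, H, C, N}

// Blocks per dimension: tensor_size_whcn / block_size_vec (element-wise).
// Assumption: each dimension divides evenly by its block size.
inline IVec4 num_blocks(const IVec4& tensor_size_whcn, const IVec4& block_size) {
  IVec4 out{};
  for (int d = 0; d < 4; ++d) {
    assert(block_size[d] > 0 && tensor_size_whcn[d] % block_size[d] == 0);
    out[d] = tensor_size_whcn[d] / block_size[d];
  }
  return out;
}

// Linear strides over the block grid: {1, #W, #W*#H, #W*#H*#C}.
inline IVec4 block_strides(const IVec4& nb) {
  return {1, nb[0], nb[0] * nb[1], nb[0] * nb[1] * nb[2]};
}

// bcoord = tidx / blockSize, then block_id = dot(bcoord, blockStride).
inline int block_id(const IVec4& tidx, const IVec4& block_size,
                    const IVec4& stride) {
  int id = 0;
  for (int d = 0; d < 4; ++d) {
    id += (tidx[d] / block_size[d]) * stride[d];
  }
  return id;
}
```

With a WHCN tensor size of `{6, 6, 4, 2}` and the `{3, 3, 2, 1}` block from the example above, the grid has `{2, 2, 2, 2}` blocks and strides `{1, 2, 4, 8}`; the element at `{4, 1, 3, 1}` has block coordinate `{1, 0, 1, 1}` and therefore block ID `1 + 0 + 4 + 8 = 13`.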
# Shader Algorithm Overview
## Texture Storage Implementation (`quantize_texture.glsl`)
**Workgroup Configuration**:
- **Global WG Size**: Default sizing based on texture dimensions
- **Local WG Size**: Default with special handling for batch dimension quantization (Z dimension set to 1 for proper workgroup dispatching when `global_workgroup_size[2] > 1`)
**Block-wise Mode Algorithm**:
The shader processes 3D texture positions where each position represents a texel containing 4 width-packed components. For each texel at position `pos`, it calculates a base tensor index `base_tidx = ivec4(pos.x * 4, pos.y, pos.z, 0)` to account for width-packing.
For each of the 4 components in the texel, it computes the actual tensor coordinate: `tidx = ivec4(base_tidx.x + i, base_tidx.y, (foldedZ % C_total), (foldedZ / C_total))` where `foldedZ = pos.z` handles batch-channel folding in 4D tensors and `C_total = numBlocks.z * blockSize.z` represents the total channel dimension.
The block coordinate is calculated using integer division: `bcoord = tidx / blockSize`, then the linear block ID uses pre-computed strides: `block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w`.
Each component is quantized using its corresponding block's parameters: `qvalue = quantize_val(value, t_scale[block_id], t_zero_point[block_id])` and written to the output texel.
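To make the texel-to-tensor mapping concrete, here is a small CPU emulation of the coordinate step only. It is a sketch under the assumptions stated in the text (width-packed texels, batch and channel folded into `pos.z`); `IVec3`, `IVec4`, and `texel_tidx` are hypothetical names, not identifiers from `quantize_texture.glsl`.

```cpp
#include <array>
#include <cassert>

using IVec3 = std::array<int, 3>;  // texture position {x, y, z}
using IVec4 = std::array<int, 4>;  // tensor coordinate in WHCN order

// Returns the WHCN tensor coordinate of each of the 4 width-packed
// components stored in the texel at `pos`.
inline std::array<IVec4, 4> texel_tidx(const IVec3& pos,
                                       const IVec4& num_blocks,
                                       const IVec4& block_size) {
  // C_total = numBlocks.z * blockSize.z is the full channel extent.
  const int c_total = num_blocks[2] * block_size[2];
  assert(c_total > 0);
  const int folded_z = pos[2];  // batch and channel folded together
  std::array<IVec4, 4> out{};
  for (int i = 0; i < 4; ++i) {
    out[i] = {pos[0] * 4 + i,       // W: unpack the i-th component
              pos[1],               // H
              folded_z % c_total,   // C
              folded_z / c_total};  // N
  }
  return out;
}
```

For instance, with `C_total = 4` the texel at `pos = (1, 2, 5)` covers W indices 4 through 7 at `H = 2`, `C = 1`, `N = 1`. Each of these coordinates would then be divided by `blockSize` to find its block, so the four components of a single texel can fall into different blocks when the block width is not a multiple of 4.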
## Buffer Storage Implementation (`quantize_buffer.glsl`)
**Workgroup Configuration**:
- **Global WG Size**: Default sizing based on buffer element count
- **Local WG Size**: Default sizing without special constraints
**Block-wise Mode Algorithm**:
The shader processes linear buffer indices using `gl_GlobalInvocationID.x` as the output buffer index. It converts this to tensor coordinates using `bufi_to_tidx(out_bufi, t_out_strides, out_dim_order)` which handles the buffer-to-tensor index mapping with proper stride calculations.
For each element, it computes the block coordinate directly: `bcoord = out_tidx / blockSize` where `out_tidx` is the 4D tensor coordinate in WHCN layout. The linear block ID calculation uses the same pre-computed stride approach: `block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w`.
The element value is loaded using the corresponding input buffer index: `value = t_in[in_bufi]` where `in_bufi = tidx_to_bufi(out_tidx, t_in_strides)`. Quantization applies the block-specific parameters: `qvalue = quantize_val(value, t_scale[block_id], t_zero_point[block_id])`.
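The per-element flow of the buffer path can be emulated on the CPU as a reference loop. This is a sketch, not the shader: `linear_to_whcn` is a hypothetical stand-in for `bufi_to_tidx` that assumes a contiguous buffer with W as the fastest-moving dimension (so input and output indices coincide), block sizes are assumed to divide the tensor sizes evenly, and `std::nearbyint` is assumed as the rounding step.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

using IVec4 = std::array<int, 4>;  // WHCN order

// Hypothetical stand-in for bufi_to_tidx; assumes a contiguous layout
// in which W varies fastest, then H, C, N.
inline IVec4 linear_to_whcn(int bufi, const IVec4& sizes) {
  IVec4 tidx{};
  for (int d = 0; d < 4; ++d) {
    tidx[d] = bufi % sizes[d];
    bufi /= sizes[d];
  }
  return tidx;
}

// Block-wise affine quantization of a flat buffer (one loop iteration
// plays the role of one shader invocation).
inline std::vector<int32_t> quantize_affine_ref(
    const std::vector<float>& in, const IVec4& sizes, const IVec4& block_size,
    const std::vector<float>& scale, const std::vector<int32_t>& zero_point,
    int32_t quant_min, int32_t quant_max) {
  const IVec4 nb = {sizes[0] / block_size[0], sizes[1] / block_size[1],
                    sizes[2] / block_size[2], sizes[3] / block_size[3]};
  const IVec4 stride = {1, nb[0], nb[0] * nb[1], nb[0] * nb[1] * nb[2]};
  assert(static_cast<int>(scale.size()) == nb[0] * nb[1] * nb[2] * nb[3]);
  assert(scale.size() == zero_point.size());

  std::vector<int32_t> out(in.size());
  for (int bufi = 0; bufi < static_cast<int>(in.size()); ++bufi) {
    const IVec4 tidx = linear_to_whcn(bufi, sizes);
    int id = 0;  // block_id = dot(tidx / blockSize, blockStride)
    for (int d = 0; d < 4; ++d) {
      id += (tidx[d] / block_size[d]) * stride[d];
    }
    const int64_t q =
        static_cast<int64_t>(std::nearbyint(in[bufi] / scale[id])) +
        zero_point[id];
    out[bufi] = static_cast<int32_t>(
        std::max<int64_t>(quant_min, std::min<int64_t>(q, quant_max)));
  }
  return out;
}
```

As a usage example, a 4-element tensor split into two blocks of width 2, with scales `{0.5, 1.0}` and zero points `{0, 10}`, quantizes `{1, 2, 3, 4}` to `{2, 4, 13, 14}`: the first two elements use the first block's parameters and the last two use the second block's.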
**Future Improvements**: Workgroup sizing could be made dynamic based on block dimensions; there is likely a better approach than the current default sizing.
ghstack-source-id: 299473617
@exported-using-ghexport
Differential Revision: [D78302195](https://our.internmc.facebook.com/intern/diff/D78302195/)

1 parent 272393c commit 7667af4
File tree: 7 files changed (+798, -149)
- backends/vulkan
  - runtime/graph/ops
    - glsl
    - impl
  - test/op_tests

Per-file line changes, in the order shown: 83 additions & 59 deletions; 2 additions & 0 deletions; 94 additions & 63 deletions; 2 additions & 0 deletions.