Commit 0b1b6c9
morelos
Update base for Update on "[ET-VK][Ops] torchao.choose_qparams_affine vulkan impl and shader (buffer only) and cleanup"
# Changes
* Implement `torchao.choose_qparams_affine` operator in Vulkan backend with comprehensive buffer storage support
* Add block-wise quantization parameter computation in `choose_qparams_buffer.glsl` shader for configurable tensor block analysis
* Extend quantization parameter infrastructure in `ChooseQParams.cpp` to handle affine transformations with configurable block sizes and multiple mapping types
* Support three quantization mapping strategies: ASYMMETRIC, SYMMETRIC, and SYMMETRIC_NO_CLIPPING_ERR for optimal parameter selection
* Consolidated the logic for choosing scale and zero point between affine cases and regular quantized_decomposed cases.
BE: Improved the documentation in the shader logic which is more detailed and clear
# Motivation
The existing Vulkan quantization infrastructure lacked support for the `torchao.choose_qparams_affine` operator, which is essential for computing optimal quantization parameters in dynamic quantization workflows. The `choose_qparams_affine` operator provides flexible block-wise parameter computation that analyzes statistical distributions within tensor blocks, enabling:
* **Block-wise Parameter Computation**: Analyzes configurable tensor blocks to compute optimal scale and zero-point values, improving quantization accuracy for heterogeneous data distributions
* **Multiple Mapping Types**: Supports ASYMMETRIC, SYMMETRIC, and SYMMETRIC_NO_CLIPPING_ERR quantization strategies for different precision-performance trade-offs
# Operator Description
The `choose_qparams_affine` operator computes optimal quantization parameters (scale and zero_point) from floating-point tensor blocks using statistical analysis of data distributions. Block-wise parameter computation divides tensors into blocks and analyzes each block independently to determine the best quantization mapping for subsequent quantization operations.
The parameter calculation varies by mapping type:
- **ASYMMETRIC**: `scale = (max - min) / (quant_max - quant_min)`, `zero_point = quant_min - round(min / scale)`
- **SYMMETRIC**: `scale = max_abs / ((quant_max - quant_min) / 2)`, `zero_point = midpoint`
- **SYMMETRIC_NO_CLIPPING_ERR**: `scale = max(abs(min)/abs(quant_min), max/quant_max)`, `zero_point = midpoint`
**Storage Requirements**: Input tensors must be floating-point (kFloat) with width-packed layout. Output scale/zero_point tensors use buffer storage.
NOTE: Texture storage implementation is not supported due to complexity of block-wise coordinate mapping in 3D texture space. This will likely be necessary for better efficiency in the future.
# Block-wise Parameter Computation Implementation
Block-wise parameter computation enables fine-grained quantization analysis by dividing tensors into blocks and computing separate scale/zero_point parameters for each block. The implementation uses several key data structures computed in `ChooseQParams.cpp`:
* **`block_size_vec`**: WHCN-ordered block dimensions converted from PyTorch NCHW layout (e.g., [3,3,2,1] for 3×3×2×1 blocks)
* **`tensor_size_whcn`**: Input tensor dimensions converted to WHCN layout using `utils::make_whcn_ivec4()`
* **`num_blocks_vec`**: Number of blocks per dimension calculated as `ceil(tensor_size_whcn / block_size_vec)` to handle non-divisible dimensions
* **`block_stride_vec`**: Pre-computed linear strides for block grid indexing `{1, #W, #W*#H, #W*#H*#C}` to enable efficient block ID calculation
* **`mapping_type`**: Integer encoding of quantization strategy (0=ASYMMETRIC, 1=SYMMETRIC, 2=SYMMETRIC_NO_CLIPPING_ERR)
The block coordinate calculation uses: `block_coord = block_id_to_coord(block_id)` which converts linear block IDs back to 4D WHCN coordinates, then computes element ranges: `t0 = block_coord * blockSize` and `tEnd = t0 + blockSize` for nested loop iteration.
# Shader Algorithm Overview
## Buffer Storage Implementation (`choose_qparams_buffer.glsl`)
**Workgroup Configuration**:
- **Global WG Size**: `{nBlocks, 1u, 1u}` where `nBlocks = total number of blocks` computed from `ceil(tensor_size / block_size)` for each dimension
- **Local WG Size**: `{1u, 1u, 1u}` (single thread per block for simplicity, though could be optimized for larger blocks)
**Block-wise Mode Algorithm**:
The shader uses a sophisticated multi-level nested approach to process tensor blocks efficiently. Each thread is assigned multiple blocks using strided access: `for (uint block_id = gl_GlobalInvocationID.x; block_id < TOTAL_BLOCKS; block_id += STRIDE)` where `STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x`.
For each assigned block, the algorithm performs several key steps:
**1. Block Coordinate Conversion**:
The `block_id_to_coord(block_id)` function converts linear block IDs to 4D WHCN coordinates using modular arithmetic.
**2. Element Range Calculation**: Computes the inclusive start coordinate `t0 = bc * blockSize` and exclusive end coordinate `tEnd = t0 + blockSize` to define the block's element boundaries in tensor space.
**3. Nested Loop Min/Max Scan**: Uses four nested loops to iterate through all elements within the block:
`for (int n = t0.w; n < tEnd.w; ++n) for (int c = t0.z; c < tEnd.z; ++c) for (int h = t0.y; h < tEnd.y; ++h) for (int w = t0.x; w < tEnd.x; ++w)`
Each element is accessed using `tidx_to_bufi(ivec4(w,h,c,n), t_in_strides)` to convert 4D tensor coordinates to linear buffer indices with proper stride handling.
**4. Parameter Calculation**: Calls `calc_scale_zp(lo, hi, quant_min, quant_max, mapping_type, eps, scale, zp)` which implements the three mapping strategies:
* **ASYMMETRIC (mapping_type=0)**: Maps full range [min, max] to [quant_min, quant_max] preserving data distribution
* **SYMMETRIC (mapping_type=1)**: Centers around zero using `max_abs = max(abs(min), abs(max))` for balanced quantization
* **SYMMETRIC_NO_CLIPPING_ERR (mapping_type=2)**: Computes separate scales for positive/negative ranges and uses the maximum to prevent clipping
**Future Improvements**: Implement workgroup-level reduction for large blocks, optimize memory access patterns for better cache utilization, and explore texture storage implementation with simplified block alignment constraints.
Differential Revision: [D78436638](https://our.internmc.facebook.com/intern/diff/D78436638/)
cc SS-JIA manuelcandales cbilgin
[ghstack-poisoned]1 parent 0840efc commit 0b1b6c9
File tree
2 files changed
+97
-58
lines changed- backends/vulkan
- _passes
2 files changed
+97
-58
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
215 | 215 | | |
216 | 216 | | |
217 | 217 | | |
218 | | - | |
219 | | - | |
220 | | - | |
221 | | - | |
222 | | - | |
223 | | - | |
224 | | - | |
225 | | - | |
226 | | - | |
227 | | - | |
228 | | - | |
229 | | - | |
230 | | - | |
231 | | - | |
232 | | - | |
233 | | - | |
234 | | - | |
235 | | - | |
236 | | - | |
237 | | - | |
238 | | - | |
239 | | - | |
240 | | - | |
241 | | - | |
242 | | - | |
243 | | - | |
244 | | - | |
245 | | - | |
246 | | - | |
247 | | - | |
248 | | - | |
249 | | - | |
250 | | - | |
251 | | - | |
252 | | - | |
253 | | - | |
254 | | - | |
255 | | - | |
256 | | - | |
257 | | - | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
258 | 226 | | |
259 | | - | |
260 | | - | |
261 | | - | |
262 | | - | |
263 | | - | |
264 | | - | |
265 | | - | |
266 | | - | |
267 | 227 | | |
268 | | - | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
269 | 232 | | |
270 | 233 | | |
271 | 234 | | |
| |||
287 | 250 | | |
288 | 251 | | |
289 | 252 | | |
290 | | - | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
291 | 263 | | |
292 | 264 | | |
293 | 265 | | |
| |||
299 | 271 | | |
300 | 272 | | |
301 | 273 | | |
302 | | - | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
| 278 | + | |
303 | 279 | | |
304 | 280 | | |
| 281 | + | |
305 | 282 | | |
306 | | - | |
307 | | - | |
308 | | - | |
309 | | - | |
310 | | - | |
311 | | - | |
312 | 283 | | |
| 284 | + | |
| 285 | + | |
| 286 | + | |
| 287 | + | |
313 | 288 | | |
314 | 289 | | |
315 | 290 | | |
| 291 | + | |
316 | 292 | | |
317 | 293 | | |
318 | 294 | | |
| |||
322 | 298 | | |
323 | 299 | | |
324 | 300 | | |
| 301 | + | |
| 302 | + | |
| 303 | + | |
| 304 | + | |
| 305 | + | |
| 306 | + | |
| 307 | + | |
| 308 | + | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
| 314 | + | |
| 315 | + | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
| 321 | + | |
| 322 | + | |
| 323 | + | |
| 324 | + | |
| 325 | + | |
| 326 | + | |
| 327 | + | |
| 328 | + | |
| 329 | + | |
| 330 | + | |
| 331 | + | |
| 332 | + | |
| 333 | + | |
| 334 | + | |
| 335 | + | |
| 336 | + | |
| 337 | + | |
| 338 | + | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
| 342 | + | |
| 343 | + | |
| 344 | + | |
| 345 | + | |
| 346 | + | |
| 347 | + | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
| 357 | + | |
| 358 | + | |
| 359 | + | |
| 360 | + | |
| 361 | + | |
| 362 | + | |
| 363 | + | |
| 364 | + | |
325 | 365 | | |
326 | 366 | | |
327 | 367 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
258 | 258 | | |
259 | 259 | | |
260 | 260 | | |
261 | | - | |
262 | 261 | | |
263 | 262 | | |
264 | 263 | | |
| |||
0 commit comments