Commit 7076257
morelos
Update on "[ET-VK][Ops] torchao.choose_qparams_affine vulkan impl and shader (buffer only) and cleanup"
# Changes
* Implement `torchao.choose_qparams_affine` operator in Vulkan backend with comprehensive buffer storage support
* Add block-wise quantization parameter computation in `choose_qparams_buffer.glsl` shader for configurable tensor block analysis
* Extend quantization parameter infrastructure in `ChooseQParams.cpp` to handle affine transformations with configurable block sizes and multiple mapping types
* Support three quantization mapping strategies: ASYMMETRIC, SYMMETRIC, and SYMMETRIC_NO_CLIPPING_ERR for optimal parameter selection
* Consolidate the logic for choosing scale and zero point between the affine and regular `quantized_decomposed` cases
BE: Improve the documentation of the shader logic so that it is more detailed and clearer
# Motivation
The existing Vulkan quantization infrastructure lacked support for the `torchao.choose_qparams_affine` operator, which is essential for computing optimal quantization parameters in dynamic quantization workflows. The `choose_qparams_affine` operator provides flexible block-wise parameter computation that analyzes statistical distributions within tensor blocks, enabling:
* **Block-wise Parameter Computation**: Analyzes configurable tensor blocks to compute optimal scale and zero-point values, improving quantization accuracy for heterogeneous data distributions
* **Multiple Mapping Types**: Supports ASYMMETRIC, SYMMETRIC, and SYMMETRIC_NO_CLIPPING_ERR quantization strategies for different precision-performance trade-offs
# Operator Description
The `choose_qparams_affine` operator computes quantization parameters (scale and zero_point) from the minimum and maximum values found in floating-point tensor blocks. The tensor is divided into blocks, and each block is analyzed independently to determine the quantization mapping used by subsequent quantization operations.
The parameter calculation varies by mapping type:
- **ASYMMETRIC**: `scale = (max - min) / (quant_max - quant_min)`, `zero_point = quant_min - round(min / scale)`
- **SYMMETRIC**: `scale = max_abs / ((quant_max - quant_min) / 2)`, `zero_point = midpoint`
- **SYMMETRIC_NO_CLIPPING_ERR**: `scale = max(abs(min)/abs(quant_min), max/quant_max)`, `zero_point = midpoint`
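The three formulas above can be sketched in Python as follows. This is a minimal illustration, not the shader code: the function name mirrors `calc_scale_zp` from the shader description, while the `MappingType` enum, the `eps` clamp on the scale, the widening of the range to include zero, the zero-point clamp, and the midpoint formula `(quant_max + quant_min + 1) // 2` are assumptions that may differ from the Vulkan implementation.

```python
from enum import IntEnum


class MappingType(IntEnum):
    # Integer encoding as listed in the description (0, 1, 2).
    ASYMMETRIC = 0
    SYMMETRIC = 1
    SYMMETRIC_NO_CLIPPING_ERR = 2


def calc_scale_zp(lo, hi, quant_min, quant_max, mapping_type, eps=1e-5):
    """Return (scale, zero_point) for one block whose values span [lo, hi]."""
    # Assumption: widen the range so that 0.0 is always representable.
    lo = min(lo, 0.0)
    hi = max(hi, 0.0)
    # Assumption: integer midpoint of the quantized range.
    midpoint = (quant_max + quant_min + 1) // 2

    if mapping_type == MappingType.ASYMMETRIC:
        scale = max((hi - lo) / (quant_max - quant_min), eps)
        # Python's round() rounds halves to even; the shader may round differently.
        zero_point = quant_min - round(lo / scale)
        # Assumption: keep the zero point inside the quantized range.
        zero_point = max(quant_min, min(quant_max, zero_point))
        return scale, zero_point

    if mapping_type == MappingType.SYMMETRIC:
        max_abs = max(abs(lo), abs(hi))
        scale = max_abs / ((quant_max - quant_min) / 2)
    else:
        # SYMMETRIC_NO_CLIPPING_ERR: assumes quant_min < 0 < quant_max.
        scale = max(abs(lo) / abs(quant_min), hi / quant_max)
    return max(scale, eps), midpoint
```

For a block spanning `[-1.0, 3.0]` quantized to int8 (`quant_min=-128`, `quant_max=127`), the asymmetric branch gives a scale of `4/255` and a zero point of `-64`, while both symmetric branches give a zero point of `0` under the midpoint assumption.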
**Storage Requirements**: Input tensors must be floating-point (kFloat) with width-packed layout. Output scale/zero_point tensors use buffer storage.
NOTE: Texture storage implementation is not supported due to complexity of block-wise coordinate mapping in 3D texture space. This will likely be necessary for better efficiency in the future.
# Block-wise Parameter Computation Implementation
Block-wise parameter computation enables fine-grained quantization analysis by dividing tensors into blocks and computing separate scale/zero_point parameters for each block. The implementation uses several key data structures computed in `ChooseQParams.cpp`:
* **`block_size_vec`**: WHCN-ordered block dimensions converted from PyTorch NCHW layout (e.g., [3,3,2,1] for 3×3×2×1 blocks)
* **`tensor_size_whcn`**: Input tensor dimensions converted to WHCN layout using `utils::make_whcn_ivec4()`
* **`num_blocks_vec`**: Number of blocks per dimension calculated as `ceil(tensor_size_whcn / block_size_vec)` to handle non-divisible dimensions
* **`block_stride_vec`**: Pre-computed linear strides for block grid indexing `{1, #W, #W*#H, #W*#H*#C}`, where `#W`, `#H`, and `#C` are the block counts along each dimension, to enable efficient block ID calculation
* **`mapping_type`**: Integer encoding of quantization strategy (0=ASYMMETRIC, 1=SYMMETRIC, 2=SYMMETRIC_NO_CLIPPING_ERR)
The block coordinate calculation uses: `block_coord = block_id_to_coord(block_id)` which converts linear block IDs back to 4D WHCN coordinates, then computes element ranges: `t0 = block_coord * blockSize` and `tEnd = t0 + blockSize` for nested loop iteration.
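The block grid quantities described above can be sketched in Python. The helper names `ceil_div`, `block_grid`, and `block_element_range` are hypothetical; `block_id_to_coord` reuses the name from the description, but its body here is only an illustration of the modular arithmetic. Clamping the end coordinate to the tensor extent for partial edge blocks is an assumption that the description does not state.

```python
def ceil_div(a, b):
    """Integer ceiling division for positive integers."""
    return -(-a // b)


def block_grid(tensor_size_whcn, block_size_vec):
    """Return (num_blocks_vec, block_stride_vec) for a WHCN-ordered tensor."""
    num_blocks = [ceil_div(t, b) for t, b in zip(tensor_size_whcn, block_size_vec)]
    strides = [
        1,
        num_blocks[0],
        num_blocks[0] * num_blocks[1],
        num_blocks[0] * num_blocks[1] * num_blocks[2],
    ]
    return num_blocks, strides


def block_id_to_coord(block_id, num_blocks):
    """Convert a linear block ID to a 4D (w, h, c, n) block coordinate."""
    coord = []
    for count in num_blocks:
        coord.append(block_id % count)
        block_id //= count
    return coord


def block_element_range(block_coord, block_size_vec, tensor_size_whcn):
    """Return the inclusive start and exclusive end of a block in tensor space."""
    t0 = [c * b for c, b in zip(block_coord, block_size_vec)]
    # Assumption: clamp the end so partial edge blocks stay inside the tensor.
    t_end = [min(s + b, t) for s, b, t in zip(t0, block_size_vec, tensor_size_whcn)]
    return t0, t_end
```

With a WHCN tensor of size `[7, 6, 4, 1]` and blocks of `[3, 3, 2, 1]`, the grid has `[3, 2, 2, 1]` blocks with strides `[1, 3, 6, 12]`, and block ID 7 maps to block coordinate `[1, 0, 1, 0]`, covering elements from `[3, 0, 2, 0]` up to (but excluding) `[6, 3, 4, 1]`.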
# Shader Algorithm Overview
## Buffer Storage Implementation (`choose_qparams_buffer.glsl`)
**Workgroup Configuration**:
- **Global WG Size**: `{nBlocks, 1u, 1u}` where `nBlocks` is the total number of blocks, i.e. the product over all dimensions of `ceil(tensor_size / block_size)`
- **Local WG Size**: `{1u, 1u, 1u}` (single thread per block for simplicity, though could be optimized for larger blocks)
**Block-wise Mode Algorithm**:
The shader processes tensor blocks with a strided outer loop over block IDs and a nested scan inside each block. Each thread may be assigned multiple blocks using strided access: `for (uint block_id = gl_GlobalInvocationID.x; block_id < TOTAL_BLOCKS; block_id += STRIDE)` where `STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x`.
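The strided assignment can be illustrated with a small Python model of the loop. The function name `blocks_for_invocation` and its parameters are hypothetical stand-ins for the GLSL built-ins `gl_GlobalInvocationID.x`, `gl_WorkGroupSize.x`, and `gl_NumWorkGroups.x`.

```python
def blocks_for_invocation(global_id, total_blocks, local_size, num_workgroups):
    """List the block IDs visited by one invocation of the strided loop."""
    # Total number of invocations launched along x.
    stride = local_size * num_workgroups
    return list(range(global_id, total_blocks, stride))
```

For example, with 10 blocks and 4 workgroups of local size 1, invocation 1 visits blocks 1, 5, and 9, and every block is visited by exactly one invocation. If the dispatch launches one invocation per block, the stride equals the number of blocks and each invocation handles a single block.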
For each assigned block, the algorithm performs several key steps:
**1. Block Coordinate Conversion**:
The `block_id_to_coord(block_id)` function converts linear block IDs to 4D WHCN coordinates using modular arithmetic.
**2. Element Range Calculation**: Computes the inclusive start coordinate `t0 = bc * blockSize` and exclusive end coordinate `tEnd = t0 + blockSize` to define the block's element boundaries in tensor space.
**3. Nested Loop Min/Max Scan**: Uses four nested loops to iterate through all elements within the block:
`for (int n = t0.w; n < tEnd.w; ++n) for (int c = t0.z; c < tEnd.z; ++c) for (int h = t0.y; h < tEnd.y; ++h) for (int w = t0.x; w < tEnd.x; ++w)`
Each element is accessed using `tidx_to_bufi(ivec4(w,h,c,n), t_in_strides)` to convert 4D tensor coordinates to linear buffer indices with proper stride handling.
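A minimal Python sketch of the scan, under stated assumptions: the input is modeled as a flat list, `strides` is given in WHCN order, and `tidx_to_bufi` here is an illustrative stand-in for the shader helper of the same name rather than its actual definition. The name `block_min_max` is hypothetical.

```python
def tidx_to_bufi(tidx, strides):
    """Map a (w, h, c, n) tensor index to a linear buffer index."""
    return sum(i * s for i, s in zip(tidx, strides))


def block_min_max(buf, strides, t0, t_end):
    """Scan the elements in [t0, t_end) and return their (min, max)."""
    lo, hi = float("inf"), float("-inf")
    for n in range(t0[3], t_end[3]):
        for c in range(t0[2], t_end[2]):
            for h in range(t0[1], t_end[1]):
                for w in range(t0[0], t_end[0]):
                    v = buf[tidx_to_bufi((w, h, c, n), strides)]
                    lo = min(lo, v)
                    hi = max(hi, v)
    return lo, hi
```

For a contiguous tensor with W=4 and H=2 holding the values 0 through 7 (WHCN strides `(1, 4, 8, 8)`), the block from `(2, 0, 0, 0)` to `(4, 2, 1, 1)` touches buffer indices 2, 3, 6, and 7, so the scan returns a minimum of 2.0 and a maximum of 7.0.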
**4. Parameter Calculation**: Calls `calc_scale_zp(lo, hi, quant_min, quant_max, mapping_type, eps, scale, zp)` which implements the three mapping strategies:
* **ASYMMETRIC (mapping_type=0)**: Maps full range [min, max] to [quant_min, quant_max] preserving data distribution
* **SYMMETRIC (mapping_type=1)**: Centers around zero using `max_abs = max(abs(min), abs(max))` for balanced quantization
* **SYMMETRIC_NO_CLIPPING_ERR (mapping_type=2)**: Computes separate scales for positive/negative ranges and uses the maximum to prevent clipping
**Future Improvements**: Implement workgroup-level reduction for large blocks, optimize memory access patterns for better cache utilization, and explore texture storage implementation with simplified block alignment constraints.
Differential Revision: [D78436638](https://our.internmc.facebook.com/intern/diff/D78436638/)
cc SS-JIA manuelcandales cbilgin
[ghstack-poisoned]
1 parent 4b61f7e commit 7076257
1 file changed, +9 -32 lines changed