
Commit 0b1b6c9

Authored and committed by morelos
Update base for Update on "[ET-VK][Ops] torchao.choose_qparams_affine vulkan impl and shader (buffer only) and cleanup"
# Changes

* Implement the `torchao.choose_qparams_affine` operator in the Vulkan backend with comprehensive buffer storage support
* Add block-wise quantization parameter computation in the `choose_qparams_buffer.glsl` shader for configurable tensor block analysis
* Extend the quantization parameter infrastructure in `ChooseQParams.cpp` to handle affine transformations with configurable block sizes and multiple mapping types
* Support three quantization mapping strategies (ASYMMETRIC, SYMMETRIC, and SYMMETRIC_NO_CLIPPING_ERR) for optimal parameter selection
* Consolidate the logic for choosing scale and zero point between the affine and regular quantized_decomposed cases
* BE: Improve the shader documentation to be more detailed and clear

# Motivation

The existing Vulkan quantization infrastructure lacked support for the `torchao.choose_qparams_affine` operator, which is essential for computing optimal quantization parameters in dynamic quantization workflows. The `choose_qparams_affine` operator provides flexible block-wise parameter computation that analyzes statistical distributions within tensor blocks, enabling:

* **Block-wise Parameter Computation**: Analyzes configurable tensor blocks to compute optimal scale and zero-point values, improving quantization accuracy for heterogeneous data distributions
* **Multiple Mapping Types**: Supports ASYMMETRIC, SYMMETRIC, and SYMMETRIC_NO_CLIPPING_ERR quantization strategies for different precision-performance trade-offs

# Operator Description

The `choose_qparams_affine` operator computes optimal quantization parameters (scale and zero_point) from floating-point tensor blocks using statistical analysis of data distributions. Block-wise parameter computation divides tensors into blocks and analyzes each block independently to determine the best quantization mapping for subsequent quantization operations.

The parameter calculation varies by mapping type:

- **ASYMMETRIC**: `scale = (max - min) / (quant_max - quant_min)`, `zero_point = quant_min - round(min / scale)`
- **SYMMETRIC**: `scale = max_abs / ((quant_max - quant_min) / 2)`, `zero_point = midpoint`
- **SYMMETRIC_NO_CLIPPING_ERR**: `scale = max(abs(min)/abs(quant_min), max/quant_max)`, `zero_point = midpoint`

**Storage Requirements**: Input tensors must be floating-point (kFloat) with width-packed layout. Output scale/zero_point tensors use buffer storage.

NOTE: A texture storage implementation is not supported due to the complexity of block-wise coordinate mapping in 3D texture space; it will likely be needed for better efficiency in the future.
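To make the three mapping types concrete, here is a minimal Python sketch of the parameter math listed above. It is a reference model, not the shader code: the function name `calc_scale_zp_reference`, the zero-inclusion clamping, the zero-point clamping, and the midpoint convention for the symmetric zero point are illustrative assumptions.

```python
def calc_scale_zp_reference(
    lo: float,
    hi: float,
    quant_min: int,
    quant_max: int,
    mapping_type: int,
    eps: float = 1e-5,
) -> tuple[float, int]:
    """Compute (scale, zero_point) for one block given its observed min/max."""
    # Include zero in the observed range so 0.0 stays exactly representable
    # (an assumption modeled on common choose_qparams behavior).
    lo = min(lo, 0.0)
    hi = max(hi, 0.0)

    if mapping_type == 0:  # ASYMMETRIC: map [lo, hi] onto [quant_min, quant_max]
        scale = max((hi - lo) / (quant_max - quant_min), eps)
        zero_point = quant_min - round(lo / scale)
        # Clamping into the quantized range is an assumption for robustness.
        zero_point = max(quant_min, min(quant_max, zero_point))
    elif mapping_type == 1:  # SYMMETRIC: balanced range around zero
        max_abs = max(abs(lo), abs(hi))
        scale = max(max_abs / ((quant_max - quant_min) / 2), eps)
        zero_point = (quant_min + quant_max + 1) // 2  # assumed midpoint
    else:  # SYMMETRIC_NO_CLIPPING_ERR: per-side scales, take the larger
        smin = abs(lo) / abs(quant_min) if quant_min != 0 else 0.0
        smax = hi / quant_max if quant_max != 0 else 0.0
        scale = max(smin, smax, eps)
        zero_point = (quant_min + quant_max + 1) // 2  # assumed midpoint
    return scale, zero_point


# Example: an int8 asymmetric block spanning [-1.0, 3.0]
print(calc_scale_zp_reference(-1.0, 3.0, -128, 127, mapping_type=0))
```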
# Block-wise Parameter Computation Implementation

Block-wise parameter computation enables fine-grained quantization analysis by dividing tensors into blocks and computing separate scale/zero_point parameters for each block.

The implementation uses several key data structures computed in `ChooseQParams.cpp`:

* **`block_size_vec`**: WHCN-ordered block dimensions converted from the PyTorch NCHW layout (e.g., [3,3,2,1] for 3×3×2×1 blocks)
* **`tensor_size_whcn`**: Input tensor dimensions converted to WHCN layout using `utils::make_whcn_ivec4()`
* **`num_blocks_vec`**: Number of blocks per dimension, calculated as `ceil(tensor_size_whcn / block_size_vec)` to handle non-divisible dimensions
* **`block_stride_vec`**: Pre-computed linear strides for block grid indexing, `{1, #W, #W*#H, #W*#H*#C}`, enabling efficient block ID calculation
* **`mapping_type`**: Integer encoding of the quantization strategy (0=ASYMMETRIC, 1=SYMMETRIC, 2=SYMMETRIC_NO_CLIPPING_ERR)

The block coordinate calculation uses `block_coord = block_id_to_coord(block_id)`, which converts linear block IDs back to 4D WHCN coordinates, then computes the element ranges `t0 = block_coord * blockSize` and `tEnd = t0 + blockSize` for nested loop iteration.

# Shader Algorithm Overview

## Buffer Storage Implementation (`choose_qparams_buffer.glsl`)

**Workgroup Configuration**:

- **Global WG Size**: `{nBlocks, 1u, 1u}`, where `nBlocks` is the total number of blocks, computed from `ceil(tensor_size / block_size)` for each dimension
- **Local WG Size**: `{1u, 1u, 1u}` (a single thread per block for simplicity, though this could be optimized for larger blocks)

**Block-wise Mode Algorithm**:

The shader processes tensor blocks with a strided, multi-level nested approach. Each thread is assigned multiple blocks using strided access: `for (uint block_id = gl_GlobalInvocationID.x; block_id < TOTAL_BLOCKS; block_id += STRIDE)`, where `STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x`.

For each assigned block, the algorithm performs several key steps (a host-side Python sketch mirroring these steps appears at the end of this section):

**1. Block Coordinate Conversion**: The `block_id_to_coord(block_id)` function converts linear block IDs to 4D WHCN coordinates using modular arithmetic.

**2. Element Range Calculation**: Computes the inclusive start coordinate `t0 = bc * blockSize` and the exclusive end coordinate `tEnd = t0 + blockSize` to define the block's element boundaries in tensor space.

**3. Nested Loop Min/Max Scan**: Uses four nested loops to iterate through all elements within the block: `for (int n = t0.w; n < tEnd.w; ++n) for (int c = t0.z; c < tEnd.z; ++c) for (int h = t0.y; h < tEnd.y; ++h) for (int w = t0.x; w < tEnd.x; ++w)`. Each element is accessed using `tidx_to_bufi(ivec4(w,h,c,n), t_in_strides)` to convert 4D tensor coordinates to linear buffer indices with proper stride handling.

**4. Parameter Calculation**: Calls `calc_scale_zp(lo, hi, quant_min, quant_max, mapping_type, eps, scale, zp)`, which implements the three mapping strategies:

* **ASYMMETRIC (mapping_type=0)**: Maps the full range [min, max] to [quant_min, quant_max], preserving the data distribution
* **SYMMETRIC (mapping_type=1)**: Centers around zero using `max_abs = max(abs(min), abs(max))` for balanced quantization
* **SYMMETRIC_NO_CLIPPING_ERR (mapping_type=2)**: Computes separate scales for the positive and negative ranges and uses the maximum to prevent clipping

**Future Improvements**: Implement workgroup-level reduction for large blocks, optimize memory access patterns for better cache utilization, and explore a texture storage implementation with simplified block alignment constraints.
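To make the block-wise scan easier to follow end to end, the following host-side Python sketch mirrors the block-grid bookkeeping and nested min/max loop described above. Every helper name, the (W, H, C, N) tuple convention, and the clamping of `tEnd` at the tensor edge are assumptions modeled on this description rather than the actual GLSL.

```python
import math


def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def choose_qparams_blockwise(
    data: list[float],                         # flat buffer, WHCN indexing
    tensor_size: tuple[int, int, int, int],    # (W, H, C, N)
    block_size: tuple[int, int, int, int],     # (W, H, C, N)
    strides: tuple[int, int, int, int],        # element strides for (w, h, c, n)
):
    # num_blocks_vec: ceil(tensor_size / block_size) per dimension
    num_blocks = tuple(ceil_div(t, b) for t, b in zip(tensor_size, block_size))
    # block_stride_vec: {1, #W, #W*#H, #W*#H*#C} for linear block IDs
    bs = (1, num_blocks[0], num_blocks[0] * num_blocks[1],
          num_blocks[0] * num_blocks[1] * num_blocks[2])
    total_blocks = bs[3] * num_blocks[3]

    results = []
    # On the GPU, each thread takes block_ids in strides of
    # gl_WorkGroupSize.x * gl_NumWorkGroups.x; here we simply loop serially.
    for block_id in range(total_blocks):
        # block_id_to_coord: linear ID -> 4D WHCN block coordinate
        bc = tuple((block_id // bs[d]) % num_blocks[d] for d in range(4))
        # t0 inclusive start, tEnd exclusive end; clamping to the tensor edge
        # for non-divisible dimensions is an assumption.
        t0 = tuple(bc[d] * block_size[d] for d in range(4))
        t_end = tuple(min(t0[d] + block_size[d], tensor_size[d]) for d in range(4))

        lo, hi = math.inf, -math.inf
        for n in range(t0[3], t_end[3]):
            for c in range(t0[2], t_end[2]):
                for h in range(t0[1], t_end[1]):
                    for w in range(t0[0], t_end[0]):
                        # tidx_to_bufi: 4D tensor index -> linear buffer index
                        bufi = (w * strides[0] + h * strides[1]
                                + c * strides[2] + n * strides[3])
                        lo, hi = min(lo, data[bufi]), max(hi, data[bufi])
        results.append((block_id, lo, hi))  # feed (lo, hi) into calc_scale_zp
    return results


# Example: a 4x4 single-channel tensor scanned in 2x2 blocks
W, H, C, N = 4, 4, 1, 1
data = [float(i) for i in range(W * H * C * N)]
print(choose_qparams_blockwise(data, (W, H, C, N), (2, 2, 1, 1),
                               (1, W, W * H, W * H * C)))
```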
Differential Revision: [D78436638](https://our.internmc.facebook.com/intern/diff/D78436638/)

cc SS-JIA manuelcandales cbilgin

[ghstack-poisoned]

1 parent 0840efc commit 0b1b6c9

File tree

2 files changed: +97 -58 lines changed


backends/vulkan/_passes/fuse_quantized_ops.py

Lines changed: 97 additions & 57 deletions
```diff
@@ -215,57 +215,20 @@ def fuse_into_linear_qcnw_node(
 #########################
 
 
-def matches_linear_qta8a_qga4w_pattern(
-    program: ExportedProgram, node: torch.fx.Node
-) -> Optional[Tuple[int, int]]:
-    """
-    Checks if the nodes surrounding a linear node matches the pattern for dynamic
-    activation + grouped weight quantized linear (QTA8A_QGA4W).
-
-    This pattern involves:
-    1. Dynamic quantization of input activations (8-bit)
-    2. Grouped quantization of weights (4-bit with group size)
-
-    The expected pattern from Int8DynActInt4WeightQuantizer is:
-        scale, zero_point = choose_qparams_affine(input)
-        quantized_input = quantize_affine(input, scale, zero_point)
-        dequantized_input = dequantize_affine(quantized_input, ...)
-        dequantized_weight = dequantize_affine(weight, weight_scales, weight_zeros)
-        output = linear(dequantized_input, dequantized_weight)
-
-    If the pattern matches, return (group_size, weight_bits), otherwise None.
-    """
-    if not utils.is_linear_node(node):
-        return None
-
-    input_node = node.args[0]
-    weight_node = node.args[1]
-
-    # Type checking - ensure we have torch.fx.Node objects
-    if not isinstance(weight_node, torch.fx.Node):
-        return None
-    if not isinstance(input_node, torch.fx.Node):
-        return None
-
-    # Check if input is dequantized with dequantize_affine (from dynamic quantization)
-    if not (
-        input_node.op == "call_function"
-        and input_node.target is not None
-        and hasattr(input_node.target, "__name__")
-        and "dequantize_affine" in getattr(input_node.target, "__name__", "")
-    ):
-        return None
+def _is_dequantize_affine_node(node: torch.fx.Node) -> bool:
+    """Check if a node is a dequantize_affine function call."""
+    return (
+        node.op == "call_function"
+        and node.target is not None
+        and hasattr(node.target, "__name__")
+        and "dequantize_affine" in getattr(node.target, "__name__", "")
+    )
 
-    # Check if weight is dequantized with dequantize_affine
-    if not (
-        weight_node.op == "call_function"
-        and weight_node.target is not None
-        and hasattr(weight_node.target, "__name__")
-        and "dequantize_affine" in getattr(weight_node.target, "__name__", "")
-    ):
-        return None
 
-    # Get the original quantized weight and quantization parameters
+def _validate_qta8a_qga4w_nodes(
+    program: ExportedProgram, weight_node: torch.fx.Node
+) -> Optional[Tuple[torch.fx.Node, torch.fx.Node, torch.fx.Node]]:
+    """Validate and extract weight quantization nodes for QTA8A_QGA4W pattern."""
     if len(weight_node.args) < 4:
         return None
 
@@ -287,7 +250,16 @@ def matches_linear_qta8a_qga4w_pattern(
     if not is_param_node(program, weight_zeros):
         return None
 
-    # Get tensors to analyze the quantization scheme
+    return orig_weight, weight_scales, weight_zeros
+
+
+def _validate_qta8a_qga4w_tensors(
+    program: ExportedProgram,
+    orig_weight: torch.fx.Node,
+    weight_scales: torch.fx.Node,
+    weight_zeros: torch.fx.Node,
+) -> Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+    """Validate and extract weight tensors for QTA8A_QGA4W pattern."""
     orig_weight_tensor = get_param_tensor(program, orig_weight)
     weight_scales_tensor = get_param_tensor(program, weight_scales)
     weight_zeros_tensor = get_param_tensor(program, weight_zeros)
@@ -299,20 +271,24 @@ def matches_linear_qta8a_qga4w_pattern(
     if not isinstance(weight_zeros_tensor, torch.Tensor):
         return None
 
-    # Check if weight is quantized to 4 bits (values should be in [-8, 7] range)
+    return orig_weight_tensor, weight_scales_tensor, weight_zeros_tensor
+
+
+def _validate_4bit_quantization(orig_weight_tensor: torch.Tensor) -> bool:
+    """Check if weight tensor is quantized to 4 bits."""
     quant_min = orig_weight_tensor.min().item()
     quant_max = orig_weight_tensor.max().item()
+    return quant_min >= -8 and quant_max <= 7
 
-    if not (quant_min >= -8 and quant_max <= 7):
-        return None
-
-    # Determine group size from the scales tensor shape
-    # For grouped quantization, scales shape should be [out_features, in_features // group_size]
-    out_features, in_features = orig_weight_tensor.shape
 
+def _calculate_group_size(
+    orig_weight_tensor: torch.Tensor, weight_scales_tensor: torch.Tensor
+) -> Optional[int]:
+    """Calculate and validate group size from tensor shapes."""
     if len(weight_scales_tensor.shape) != 2:
         return None
 
+    out_features, in_features = orig_weight_tensor.shape
     scales_out_features, num_groups = weight_scales_tensor.shape
 
     if scales_out_features != out_features:
@@ -322,6 +298,70 @@ def matches_linear_qta8a_qga4w_pattern(
     if in_features % group_size != 0:
         return None
 
+    return group_size
+
+
+def matches_linear_qta8a_qga4w_pattern(
+    program: ExportedProgram, node: torch.fx.Node
+) -> Optional[Tuple[int, int]]:
+    """
+    Checks if the nodes surrounding a linear node matches the pattern for dynamic
+    activation + grouped weight quantized linear (QTA8A_QGA4W).
+
+    This pattern involves:
+    1. Dynamic quantization of input activations (8-bit)
+    2. Grouped quantization of weights (4-bit with group size)
+
+    The expected pattern from Int8DynActInt4WeightQuantizer is:
+        scale, zero_point = choose_qparams_affine(input)
+        quantized_input = quantize_affine(input, scale, zero_point)
+        dequantized_input = dequantize_affine(quantized_input, ...)
+        dequantized_weight = dequantize_affine(weight, weight_scales, weight_zeros)
+        output = linear(dequantized_input, dequantized_weight)
+
+    If the pattern matches, return (group_size, weight_bits), otherwise None.
+    """
+    if not utils.is_linear_node(node):
+        return None
+
+    input_node = node.args[0]
+    weight_node = node.args[1]
+
+    # Type checking - ensure we have torch.fx.Node objects
+    if not isinstance(weight_node, torch.fx.Node):
+        return None
+    if not isinstance(input_node, torch.fx.Node):
+        return None
+
+    # Check if input and weight are dequantized with dequantize_affine
+    if not _is_dequantize_affine_node(input_node):
+        return None
+    if not _is_dequantize_affine_node(weight_node):
+        return None
+
+    # Validate and extract weight quantization nodes
+    weight_nodes = _validate_qta8a_qga4w_nodes(program, weight_node)
+    if weight_nodes is None:
+        return None
+    orig_weight, weight_scales, weight_zeros = weight_nodes
+
+    # Validate and extract weight tensors
+    weight_tensors = _validate_qta8a_qga4w_tensors(
+        program, orig_weight, weight_scales, weight_zeros
+    )
+    if weight_tensors is None:
+        return None
+    orig_weight_tensor, weight_scales_tensor, weight_zeros_tensor = weight_tensors
+
+    # Check if weight is quantized to 4 bits
+    if not _validate_4bit_quantization(orig_weight_tensor):
+        return None
+
+    # Calculate and validate group size
+    group_size = _calculate_group_size(orig_weight_tensor, weight_scales_tensor)
+    if group_size is None:
+        return None
+
     # Verify this is 4-bit grouped quantization
     weight_bits = 4
 
```

backends/vulkan/custom_ops_lib.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -258,7 +258,6 @@ def linear_qta8a_qga4w(
         weight_zeros: Per-group zero points for weights
     """
     original_x_shape = x_quantized.shape
-    batch_size = original_x_shape[0]
    feature_dim = original_x_shape[-1]
 
     # Reshape for processing
```
