Commit d691080

Author: morelos
Update on "[ET-VK][Ops] torchao.dequantize_affine vulkan impl and shader and cleanup"
# Changes

* Implement the `torchao.dequantize_affine` operator in the Vulkan backend, with comprehensive texture and buffer storage support
* Add a block-wise dequantization mode to the `dequantize_texture.glsl` and `dequantize_buffer.glsl` shaders for configurable tensor-block dequantization
* Extend the dequantization infrastructure in `Dequantize.cpp` to handle affine transformations with configurable block sizes and quantization parameters
* Support integer-to-floating-point conversion with precise reconstruction of the original values

BE: Improved the shader documentation to be more detailed and clear.

# Motivation

The existing Vulkan quantization infrastructure lacked support for the `torchao.dequantize_affine` operator, which is essential for completing the quantization-dequantization cycle in dynamic quantization workflows. The `dequantize_affine` operator provides flexible block-wise dequantization that reconstructs floating-point values from quantized integer blocks, enabling:

* **Block-wise Dequantization**: Reconstructs floating-point values from configurable tensor blocks using separate scale and zero-point parameters, enabling precise recovery of the original data distributions
* **Affine Transformation**: Uses the formula `value = (qvalue - zero_point) * scale` for accurate integer-to-floating-point mapping
* **TorchAO Integration**: Provides seamless compatibility with TorchAO quantization workflows and completes the quantization-dequantization round trip

# Operator Description

The `dequantize_affine` operator converts n-bit integer tensor values back to floating-point representations using pre-computed quantization parameters (scale and zero_point) applied to configurable tensor blocks. Block-wise dequantization divides tensors into blocks and applies separate dequantization parameters to each block, allowing fine-grained reconstruction of the original floating-point precision.

The dequantization formula is: `value = (qvalue - zero_point) * scale`

**Storage Requirements**: Scale and zero_point tensors must use buffer storage with a width-packed layout. Input/output tensors support both buffer and texture storage with standard axis mapping. Input tensors must be integer types (kByte, kChar, kInt).

# Block-wise Dequantization Implementation

Block-wise dequantization enables fine-grained reconstruction by dividing tensors into blocks and applying separate dequantization parameters to each block, as in the reference sketch below.
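For reference, here is a minimal Python sketch of the block-wise semantics (illustrative only, not the Vulkan implementation; the shapes and the example at the end are assumptions):

```python
import itertools

import torch


def dequantize_affine_ref(q, block_size, scale, zero_point):
    """Reference semantics: value = (qvalue - zero_point) * scale, per block.

    q:          integer tensor (e.g. torch.int8); dims assumed divisible by block_size
    block_size: block extent per dimension, e.g. (1, 2, 3, 3) in NCHW order
    scale, zero_point: tensors shaped like the block grid (q.shape // block_size)
    """
    num_blocks = tuple(s // b for s, b in zip(q.shape, block_size))
    assert scale.shape == num_blocks and zero_point.shape == num_blocks
    out = torch.empty(q.shape, dtype=torch.float32)
    # Walk the block grid; each block is dequantized with its own parameters.
    for bcoord in itertools.product(*(range(n) for n in num_blocks)):
        region = tuple(slice(c * b, (c + 1) * b) for c, b in zip(bcoord, block_size))
        out[region] = (q[region].float() - zero_point[bcoord]) * scale[bcoord]
    return out


# Example: a 4x4 int8 tensor dequantized with 2x2 blocks (hypothetical shapes).
q = torch.randint(-128, 128, (4, 4), dtype=torch.int8)
x = dequantize_affine_ref(q, (2, 2), torch.rand(2, 2), torch.randint(-8, 8, (2, 2)))
```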
The implementation relies on key data structures computed in `Dequantize.cpp`:

* **`block_size_vec`**: WHCN-ordered block dimensions converted from the PyTorch NCHW layout (e.g., [3, 3, 2, 1] for 3×3×2×1 blocks)
* **`tensor_size_whcn`**: Input tensor dimensions converted to WHCN layout using `utils::make_whcn_ivec4()`
* **`num_blocks_vec`**: Number of blocks per dimension, calculated as `tensor_size_whcn / block_size_vec`
* **`block_stride_vec`**: Pre-computed linear strides `{1, #W, #W*#H, #W*#H*#C}` over the block grid, enabling efficient block ID calculation

The block coordinate is computed as `bcoord = tidx / blockSize`, where `tidx` is the tensor coordinate in WHCN layout; the linear block ID is then:

`block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w`

# Shader Algorithm Overview

## Texture Storage Implementation (`dequantize_texture.glsl`)

**Workgroup Configuration**:
- **Global WG Size**: Default sizing based on texture dimensions
- **Local WG Size**: Default, with special handling for batch-dimension dequantization (the Z dimension is set to 1 for proper workgroup dispatching when `global_workgroup_size[2] > 1`)

**Block-wise Mode Algorithm**: The shader processes 3D texture positions, where each position represents a texel containing 4 width-packed integer components. For each texel at position `pos`, it calculates a base tensor index `base_tidx = ivec4(pos.x * 4, pos.y, pos.z, 0)` to account for width-packing. For each of the 4 components in the texel, it computes the actual tensor coordinate `tidx = ivec4(base_tidx.x + i, base_tidx.y, (foldedZ % C_total), (foldedZ / C_total))`, where `foldedZ = pos.z` handles batch-channel folding in 4D tensors and `C_total = numBlocks.z * blockSize.z` is the total channel extent. The block coordinate is obtained by integer division, `bcoord = tidx / blockSize`, and the linear block ID uses the pre-computed strides as above. Each integer component is then dequantized with its block's parameters, `value = dequantize_val(qvalue, t_scale[block_id], t_zero_point[block_id])`, where `dequantize_val()` applies `(qvalue - zero_point) * scale`. The reconstructed floating-point values are written to the output texel, with proper type handling for double-precision outputs. The indexing logic is sketched below.
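Here is a rough Python transcription of that indexing (names mirror the shader's `ivec4` values; these helpers are illustrative, not the actual GLSL):

```python
def make_block_strides(num_blocks):
    """Pre-compute block-grid strides {1, #W, #W*#H, #W*#H*#C}, as in Dequantize.cpp."""
    nw, nh, nc, _ = num_blocks
    return (1, nw, nw * nh, nw * nh * nc)


def block_id_for_tidx(tidx, block_size, block_stride):
    """Map a WHCN tensor coordinate to its linear block ID."""
    # bcoord = tidx / blockSize (componentwise integer division)
    bcoord = tuple(t // b for t, b in zip(tidx, block_size))
    # block_id = dot(bcoord, blockStride)
    return sum(c * s for c, s in zip(bcoord, block_stride))


def texel_component_tidx(pos, i, c_total):
    """WHCN coordinate of width-packed component i of the texel at 3D position pos.

    pos[2] folds channel and batch for 4D tensors; c_total = numBlocks.z * blockSize.z.
    """
    folded_z = pos[2]
    return (pos[0] * 4 + i, pos[1], folded_z % c_total, folded_z // c_total)
```

Pre-computing the block strides turns the per-element block lookup into a single dot product rather than repeated divisions against the block grid.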
## Buffer Storage Implementation (`dequantize_buffer.glsl`)

**Workgroup Configuration**:
- **Global WG Size**: Default sizing based on buffer element count
- **Local WG Size**: Default sizing without special constraints

**Block-wise Mode Algorithm**: The shader processes linear buffer indices, using `gl_GlobalInvocationID.x` as the output buffer index. It converts this to tensor coordinates with `bufi_to_tidx(out_bufi, t_out_strides, out_dim_order)`, which handles the buffer-to-tensor index mapping with proper stride calculations. For each element, it computes the block coordinate directly, `bcoord = out_tidx / blockSize`, where `out_tidx` is the 4D tensor coordinate in WHCN layout; the linear block ID uses the same pre-computed stride approach as the texture shader. The quantized integer value is loaded via the corresponding input buffer index, `qvalue = t_in[in_bufi]`, where `in_bufi = tidx_to_bufi(out_tidx, t_in_strides)`. Dequantization then applies the block-specific parameters, `value = dequantize_val(qvalue, t_scale[block_id], t_zero_point[block_id])`, to reconstruct the original floating-point value. The index-mapping helpers are sketched below.
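A sketch of what `bufi_to_tidx`/`tidx_to_bufi` compute, assuming the standard stride-based mapping (illustrative Python, not the GLSL helpers themselves):

```python
def bufi_to_tidx(bufi, strides, dim_order):
    """Linear buffer index -> 4D WHCN tensor coordinate.

    dim_order lists dimensions from largest stride to smallest, so each
    coordinate is peeled off with one division and one remainder.
    """
    tidx = [0, 0, 0, 0]
    for dim in dim_order:
        tidx[dim] = bufi // strides[dim]
        bufi %= strides[dim]
    return tuple(tidx)


def tidx_to_bufi(tidx, strides):
    """4D tensor coordinate -> linear buffer index."""
    return sum(t * s for t, s in zip(tidx, strides))
```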
**Future Improvements**: Dynamic workgroup sizing based on block dimensions

Differential Revision: [D78435552](https://our.internmc.facebook.com/intern/diff/D78435552/)

cc SS-JIA manuelcandales cbilgin

[ghstack-poisoned]

2 parents b2a8443 + dab15b5 · commit d691080

2 files changed: +97 −58 lines

2 files changed

+97
-58
lines changed

backends/vulkan/_passes/fuse_quantized_ops.py

Lines changed: 97 additions & 57 deletions
```diff
@@ -215,57 +215,20 @@ def fuse_into_linear_qcnw_node(
 #########################
 
 
-def matches_linear_qta8a_qga4w_pattern(
-    program: ExportedProgram, node: torch.fx.Node
-) -> Optional[Tuple[int, int]]:
-    """
-    Checks if the nodes surrounding a linear node matches the pattern for dynamic
-    activation + grouped weight quantized linear (QTA8A_QGA4W).
-
-    This pattern involves:
-    1. Dynamic quantization of input activations (8-bit)
-    2. Grouped quantization of weights (4-bit with group size)
-
-    The expected pattern from Int8DynActInt4WeightQuantizer is:
-        scale, zero_point = choose_qparams_affine(input)
-        quantized_input = quantize_affine(input, scale, zero_point)
-        dequantized_input = dequantize_affine(quantized_input, ...)
-        dequantized_weight = dequantize_affine(weight, weight_scales, weight_zeros)
-        output = linear(dequantized_input, dequantized_weight)
-
-    If the pattern matches, return (group_size, weight_bits), otherwise None.
-    """
-    if not utils.is_linear_node(node):
-        return None
-
-    input_node = node.args[0]
-    weight_node = node.args[1]
-
-    # Type checking - ensure we have torch.fx.Node objects
-    if not isinstance(weight_node, torch.fx.Node):
-        return None
-    if not isinstance(input_node, torch.fx.Node):
-        return None
-
-    # Check if input is dequantized with dequantize_affine (from dynamic quantization)
-    if not (
-        input_node.op == "call_function"
-        and input_node.target is not None
-        and hasattr(input_node.target, "__name__")
-        and "dequantize_affine" in getattr(input_node.target, "__name__", "")
-    ):
-        return None
+def _is_dequantize_affine_node(node: torch.fx.Node) -> bool:
+    """Check if a node is a dequantize_affine function call."""
+    return (
+        node.op == "call_function"
+        and node.target is not None
+        and hasattr(node.target, "__name__")
+        and "dequantize_affine" in getattr(node.target, "__name__", "")
+    )
 
-    # Check if weight is dequantized with dequantize_affine
-    if not (
-        weight_node.op == "call_function"
-        and weight_node.target is not None
-        and hasattr(weight_node.target, "__name__")
-        and "dequantize_affine" in getattr(weight_node.target, "__name__", "")
-    ):
-        return None
 
-    # Get the original quantized weight and quantization parameters
+def _validate_qta8a_qga4w_nodes(
+    program: ExportedProgram, weight_node: torch.fx.Node
+) -> Optional[Tuple[torch.fx.Node, torch.fx.Node, torch.fx.Node]]:
+    """Validate and extract weight quantization nodes for QTA8A_QGA4W pattern."""
     if len(weight_node.args) < 4:
         return None
 
@@ -287,7 +250,16 @@ def matches_linear_qta8a_qga4w_pattern(
     if not is_param_node(program, weight_zeros):
         return None
 
-    # Get tensors to analyze the quantization scheme
+    return orig_weight, weight_scales, weight_zeros
+
+
+def _validate_qta8a_qga4w_tensors(
+    program: ExportedProgram,
+    orig_weight: torch.fx.Node,
+    weight_scales: torch.fx.Node,
+    weight_zeros: torch.fx.Node,
+) -> Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+    """Validate and extract weight tensors for QTA8A_QGA4W pattern."""
     orig_weight_tensor = get_param_tensor(program, orig_weight)
     weight_scales_tensor = get_param_tensor(program, weight_scales)
     weight_zeros_tensor = get_param_tensor(program, weight_zeros)
@@ -299,20 +271,24 @@ def matches_linear_qta8a_qga4w_pattern(
     if not isinstance(weight_zeros_tensor, torch.Tensor):
         return None
 
-    # Check if weight is quantized to 4 bits (values should be in [-8, 7] range)
+    return orig_weight_tensor, weight_scales_tensor, weight_zeros_tensor
+
+
+def _validate_4bit_quantization(orig_weight_tensor: torch.Tensor) -> bool:
+    """Check if weight tensor is quantized to 4 bits."""
     quant_min = orig_weight_tensor.min().item()
     quant_max = orig_weight_tensor.max().item()
+    return quant_min >= -8 and quant_max <= 7
 
-    if not (quant_min >= -8 and quant_max <= 7):
-        return None
-
-    # Determine group size from the scales tensor shape
-    # For grouped quantization, scales shape should be [out_features, in_features // group_size]
-    out_features, in_features = orig_weight_tensor.shape
 
+def _calculate_group_size(
+    orig_weight_tensor: torch.Tensor, weight_scales_tensor: torch.Tensor
+) -> Optional[int]:
+    """Calculate and validate group size from tensor shapes."""
     if len(weight_scales_tensor.shape) != 2:
         return None
 
+    out_features, in_features = orig_weight_tensor.shape
     scales_out_features, num_groups = weight_scales_tensor.shape
 
     if scales_out_features != out_features:
@@ -322,6 +298,70 @@ def matches_linear_qta8a_qga4w_pattern(
     if in_features % group_size != 0:
         return None
 
+    return group_size
+
+
+def matches_linear_qta8a_qga4w_pattern(
+    program: ExportedProgram, node: torch.fx.Node
+) -> Optional[Tuple[int, int]]:
+    """
+    Checks if the nodes surrounding a linear node matches the pattern for dynamic
+    activation + grouped weight quantized linear (QTA8A_QGA4W).
+
+    This pattern involves:
+    1. Dynamic quantization of input activations (8-bit)
+    2. Grouped quantization of weights (4-bit with group size)
+
+    The expected pattern from Int8DynActInt4WeightQuantizer is:
+        scale, zero_point = choose_qparams_affine(input)
+        quantized_input = quantize_affine(input, scale, zero_point)
+        dequantized_input = dequantize_affine(quantized_input, ...)
+        dequantized_weight = dequantize_affine(weight, weight_scales, weight_zeros)
+        output = linear(dequantized_input, dequantized_weight)
+
+    If the pattern matches, return (group_size, weight_bits), otherwise None.
+    """
+    if not utils.is_linear_node(node):
+        return None
+
+    input_node = node.args[0]
+    weight_node = node.args[1]
+
+    # Type checking - ensure we have torch.fx.Node objects
+    if not isinstance(weight_node, torch.fx.Node):
+        return None
+    if not isinstance(input_node, torch.fx.Node):
+        return None
+
+    # Check if input and weight are dequantized with dequantize_affine
+    if not _is_dequantize_affine_node(input_node):
+        return None
+    if not _is_dequantize_affine_node(weight_node):
+        return None
+
+    # Validate and extract weight quantization nodes
+    weight_nodes = _validate_qta8a_qga4w_nodes(program, weight_node)
+    if weight_nodes is None:
+        return None
+    orig_weight, weight_scales, weight_zeros = weight_nodes
+
+    # Validate and extract weight tensors
+    weight_tensors = _validate_qta8a_qga4w_tensors(
+        program, orig_weight, weight_scales, weight_zeros
+    )
+    if weight_tensors is None:
+        return None
+    orig_weight_tensor, weight_scales_tensor, weight_zeros_tensor = weight_tensors
+
+    # Check if weight is quantized to 4 bits
+    if not _validate_4bit_quantization(orig_weight_tensor):
+        return None
+
+    # Calculate and validate group size
+    group_size = _calculate_group_size(orig_weight_tensor, weight_scales_tensor)
+    if group_size is None:
+        return None
+
     # Verify this is 4-bit grouped quantization
     weight_bits = 4
```
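For context, a hypothetical sketch of how a fusion pass might drive the refactored matcher over an exported graph (the driver function below is an assumption for illustration, not part of this diff):

```python
from torch.export import ExportedProgram


def find_qta8a_qga4w_linears(program: ExportedProgram):
    """Yield (node, group_size, weight_bits) for every matching linear node."""
    for node in program.graph_module.graph.nodes:
        result = matches_linear_qta8a_qga4w_pattern(program, node)
        if result is not None:
            group_size, weight_bits = result
            yield node, group_size, weight_bits
```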

backends/vulkan/custom_ops_lib.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -258,7 +258,6 @@ def linear_qta8a_qga4w(
         weight_zeros: Per-group zero points for weights
     """
     original_x_shape = x_quantized.shape
-    batch_size = original_x_shape[0]
     feature_dim = original_x_shape[-1]
 
     # Reshape for processing
```