
Commit 8b67c74

Author: morelos (committed)
Update on "[ET-VK] linear_qta8a_qga4w graph pass"
# Changes

* Introduce `linear_qta8a_qga4w` custom operator in `custom_ops_lib.py` to handle dynamic activation + grouped weight quantized linear operations
* Add pattern matching and fusion logic in `FuseQuantizedOpsTransform` to detect and replace dequant + dequant + linear sequences with the new fused operator
* Implement comprehensive test coverage in `test_vulkan_passes.py` for the QTA8A_QGA4W fusion pattern validation
* Add 4-bit weight packing utilities and grouped quantization support for efficient memory usage (see the packing sketch below)

# Motivation

The existing quantization workflow in the Vulkan backend processes dynamic activation + grouped weight quantized linear operations as separate quantize/dequantize/linear steps, which creates performance overhead through:

* Multiple kernel dispatches instead of a single fused operation
* Intermediate tensor allocations for dequantized weights and activations
* Suboptimal memory bandwidth utilization

The new `linear_qta8a_qga4w` operator fuses the entire sequence into a single operation that:

* Directly processes 8-bit quantized activations with per-token scales/zero-points
* Handles 4-bit grouped quantized weights with configurable group sizes
* Eliminates intermediate dequantization steps by performing dequantization inline
* Reduces memory footprint through packed 4-bit weight storage

This aligns with the broader goal of optimizing quantized model inference in the Vulkan backend by leveraging graph-level transformations to improve computational efficiency while maintaining numerical accuracy.

Differential Revision: [D78291269](https://our.internmc.facebook.com/intern/diff/D78291269/)

[ghstack-poisoned]
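The last bullet above mentions 4-bit weight packing utilities, but the packing helper itself is not shown in this diff. The following is only a minimal sketch of one common convention, assuming two signed 4-bit values per byte with an offset into [0, 15] and the even column stored in the low nibble; the function name `pack_int4_weights` and the layout are assumptions for illustration, not the helper added by this change.

```python
import torch


def pack_int4_weights(weight_4bit: torch.Tensor) -> torch.Tensor:
    """Hypothetical packing: two 4-bit values (each in [-8, 7]) per uint8 byte."""
    assert weight_4bit.shape[-1] % 2 == 0, "need an even number of columns to pack pairs"
    # Shift the signed range [-8, 7] into the unsigned nibble range [0, 15].
    w = (weight_4bit + 8).to(torch.uint8)
    low = w[..., 0::2]   # even columns -> low nibble
    high = w[..., 1::2]  # odd columns  -> high nibble
    return (high << 4) | low


# Example: a [4, 8] 4-bit weight tensor packs into a [4, 4] uint8 tensor.
w = torch.randint(-8, 8, (4, 8), dtype=torch.int8)
assert pack_int4_weights(w).shape == (4, 4)
```

Halving the last dimension in this way is what provides the memory-footprint reduction the motivation section refers to.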
2 parents 1b7f43b + c5e7b1a commit 8b67c74

File tree

2 files changed: +97 -58 lines changed

backends/vulkan/_passes/fuse_quantized_ops.py

Lines changed: 97 additions & 57 deletions
@@ -215,57 +215,20 @@ def fuse_into_linear_qcnw_node(
 #########################


-def matches_linear_qta8a_qga4w_pattern(
-    program: ExportedProgram, node: torch.fx.Node
-) -> Optional[Tuple[int, int]]:
-    """
-    Checks if the nodes surrounding a linear node matches the pattern for dynamic
-    activation + grouped weight quantized linear (QTA8A_QGA4W).
-
-    This pattern involves:
-    1. Dynamic quantization of input activations (8-bit)
-    2. Grouped quantization of weights (4-bit with group size)
-
-    The expected pattern from Int8DynActInt4WeightQuantizer is:
-        scale, zero_point = choose_qparams_affine(input)
-        quantized_input = quantize_affine(input, scale, zero_point)
-        dequantized_input = dequantize_affine(quantized_input, ...)
-        dequantized_weight = dequantize_affine(weight, weight_scales, weight_zeros)
-        output = linear(dequantized_input, dequantized_weight)
-
-    If the pattern matches, return (group_size, weight_bits), otherwise None.
-    """
-    if not utils.is_linear_node(node):
-        return None
-
-    input_node = node.args[0]
-    weight_node = node.args[1]
-
-    # Type checking - ensure we have torch.fx.Node objects
-    if not isinstance(weight_node, torch.fx.Node):
-        return None
-    if not isinstance(input_node, torch.fx.Node):
-        return None
-
-    # Check if input is dequantized with dequantize_affine (from dynamic quantization)
-    if not (
-        input_node.op == "call_function"
-        and input_node.target is not None
-        and hasattr(input_node.target, "__name__")
-        and "dequantize_affine" in getattr(input_node.target, "__name__", "")
-    ):
-        return None
+def _is_dequantize_affine_node(node: torch.fx.Node) -> bool:
+    """Check if a node is a dequantize_affine function call."""
+    return (
+        node.op == "call_function"
+        and node.target is not None
+        and hasattr(node.target, "__name__")
+        and "dequantize_affine" in getattr(node.target, "__name__", "")
+    )

-    # Check if weight is dequantized with dequantize_affine
-    if not (
-        weight_node.op == "call_function"
-        and weight_node.target is not None
-        and hasattr(weight_node.target, "__name__")
-        and "dequantize_affine" in getattr(weight_node.target, "__name__", "")
-    ):
-        return None

-    # Get the original quantized weight and quantization parameters
+def _validate_qta8a_qga4w_nodes(
+    program: ExportedProgram, weight_node: torch.fx.Node
+) -> Optional[Tuple[torch.fx.Node, torch.fx.Node, torch.fx.Node]]:
+    """Validate and extract weight quantization nodes for QTA8A_QGA4W pattern."""
     if len(weight_node.args) < 4:
         return None

@@ -287,7 +250,16 @@ def matches_linear_qta8a_qga4w_pattern(
     if not is_param_node(program, weight_zeros):
         return None

-    # Get tensors to analyze the quantization scheme
+    return orig_weight, weight_scales, weight_zeros
+
+
+def _validate_qta8a_qga4w_tensors(
+    program: ExportedProgram,
+    orig_weight: torch.fx.Node,
+    weight_scales: torch.fx.Node,
+    weight_zeros: torch.fx.Node,
+) -> Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+    """Validate and extract weight tensors for QTA8A_QGA4W pattern."""
     orig_weight_tensor = get_param_tensor(program, orig_weight)
     weight_scales_tensor = get_param_tensor(program, weight_scales)
     weight_zeros_tensor = get_param_tensor(program, weight_zeros)
@@ -299,20 +271,24 @@ def matches_linear_qta8a_qga4w_pattern(
     if not isinstance(weight_zeros_tensor, torch.Tensor):
         return None

-    # Check if weight is quantized to 4 bits (values should be in [-8, 7] range)
+    return orig_weight_tensor, weight_scales_tensor, weight_zeros_tensor
+
+
+def _validate_4bit_quantization(orig_weight_tensor: torch.Tensor) -> bool:
+    """Check if weight tensor is quantized to 4 bits."""
     quant_min = orig_weight_tensor.min().item()
     quant_max = orig_weight_tensor.max().item()
+    return quant_min >= -8 and quant_max <= 7

-    if not (quant_min >= -8 and quant_max <= 7):
-        return None
-
-    # Determine group size from the scales tensor shape
-    # For grouped quantization, scales shape should be [out_features, in_features // group_size]
-    out_features, in_features = orig_weight_tensor.shape

+def _calculate_group_size(
+    orig_weight_tensor: torch.Tensor, weight_scales_tensor: torch.Tensor
+) -> Optional[int]:
+    """Calculate and validate group size from tensor shapes."""
     if len(weight_scales_tensor.shape) != 2:
         return None

+    out_features, in_features = orig_weight_tensor.shape
     scales_out_features, num_groups = weight_scales_tensor.shape

     if scales_out_features != out_features:
@@ -322,6 +298,70 @@ def matches_linear_qta8a_qga4w_pattern(
     if in_features % group_size != 0:
         return None

+    return group_size
+
+
+def matches_linear_qta8a_qga4w_pattern(
+    program: ExportedProgram, node: torch.fx.Node
+) -> Optional[Tuple[int, int]]:
+    """
+    Checks if the nodes surrounding a linear node matches the pattern for dynamic
+    activation + grouped weight quantized linear (QTA8A_QGA4W).
+
+    This pattern involves:
+    1. Dynamic quantization of input activations (8-bit)
+    2. Grouped quantization of weights (4-bit with group size)
+
+    The expected pattern from Int8DynActInt4WeightQuantizer is:
+        scale, zero_point = choose_qparams_affine(input)
+        quantized_input = quantize_affine(input, scale, zero_point)
+        dequantized_input = dequantize_affine(quantized_input, ...)
+        dequantized_weight = dequantize_affine(weight, weight_scales, weight_zeros)
+        output = linear(dequantized_input, dequantized_weight)
+
+    If the pattern matches, return (group_size, weight_bits), otherwise None.
+    """
+    if not utils.is_linear_node(node):
+        return None
+
+    input_node = node.args[0]
+    weight_node = node.args[1]
+
+    # Type checking - ensure we have torch.fx.Node objects
+    if not isinstance(weight_node, torch.fx.Node):
+        return None
+    if not isinstance(input_node, torch.fx.Node):
+        return None
+
+    # Check if input and weight are dequantized with dequantize_affine
+    if not _is_dequantize_affine_node(input_node):
+        return None
+    if not _is_dequantize_affine_node(weight_node):
+        return None
+
+    # Validate and extract weight quantization nodes
+    weight_nodes = _validate_qta8a_qga4w_nodes(program, weight_node)
+    if weight_nodes is None:
+        return None
+    orig_weight, weight_scales, weight_zeros = weight_nodes
+
+    # Validate and extract weight tensors
+    weight_tensors = _validate_qta8a_qga4w_tensors(
+        program, orig_weight, weight_scales, weight_zeros
+    )
+    if weight_tensors is None:
+        return None
+    orig_weight_tensor, weight_scales_tensor, weight_zeros_tensor = weight_tensors
+
+    # Check if weight is quantized to 4 bits
+    if not _validate_4bit_quantization(orig_weight_tensor):
+        return None
+
+    # Calculate and validate group size
+    group_size = _calculate_group_size(orig_weight_tensor, weight_scales_tensor)
+    if group_size is None:
+        return None
+
     # Verify this is 4-bit grouped quantization
     weight_bits = 4

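A note on the refactored helpers above: `_calculate_group_size` recovers the group size purely from tensor shapes, since grouped quantization stores one scale per (output row, input group). The line that performs the division falls between the hunks shown, but from the surrounding checks it is presumably `group_size = in_features // num_groups`. A small illustration with hypothetical shapes:

```python
import torch

# Hypothetical 4-bit weight of shape [64, 256] quantized with group size 32.
orig_weight_tensor = torch.randint(-8, 8, (64, 256), dtype=torch.int8)
weight_scales_tensor = torch.rand(64, 256 // 32)  # one scale per group -> shape [64, 8]

out_features, in_features = orig_weight_tensor.shape          # 64, 256
scales_out_features, num_groups = weight_scales_tensor.shape  # 64, 8

assert scales_out_features == out_features
group_size = in_features // num_groups  # 256 // 8 == 32
assert in_features % group_size == 0
```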
backends/vulkan/custom_ops_lib.py

Lines changed: 0 additions & 1 deletion
@@ -258,7 +258,6 @@ def linear_qta8a_qga4w(
         weight_zeros: Per-group zero points for weights
     """
     original_x_shape = x_quantized.shape
-    batch_size = original_x_shape[0]
     feature_dim = original_x_shape[-1]

     # Reshape for processing
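Only a few lines of the `linear_qta8a_qga4w` reference implementation are visible in the hunk above. As a rough mental model of what the fused operator computes, here is a minimal sketch assuming per-token activation scales/zero-points and per-group weight scales/zeros; apart from `x_quantized`, `weight_scales`, and `weight_zeros`, which appear in the diff, the argument names, layouts, and the flattened output shape are assumptions rather than the operator's actual signature.

```python
import torch


def linear_qta8a_qga4w_reference(
    x_quantized: torch.Tensor,    # int8 activations, shape [..., in_features]
    x_scale: torch.Tensor,        # per-token scales, shape [num_tokens]
    x_zero_point: torch.Tensor,   # per-token zero points, shape [num_tokens]
    weight_4bit: torch.Tensor,    # values in [-8, 7], shape [out_features, in_features]
    weight_scales: torch.Tensor,  # per-group scales, shape [out_features, in_features // group_size]
    weight_zeros: torch.Tensor,   # per-group zero points, same shape as weight_scales
    group_size: int,
) -> torch.Tensor:
    out_features, in_features = weight_4bit.shape
    num_groups = in_features // group_size

    # Dequantize activations per token: (q - zero_point) * scale.
    x = x_quantized.reshape(-1, in_features).to(torch.float32)
    x = (x - x_zero_point.reshape(-1, 1).to(torch.float32)) * x_scale.reshape(-1, 1)

    # Dequantize weights per group: (q - zero_point) * scale, broadcast within each group.
    w = weight_4bit.to(torch.float32).reshape(out_features, num_groups, group_size)
    w = (w - weight_zeros.reshape(out_features, num_groups, 1)) * weight_scales.reshape(
        out_features, num_groups, 1
    )

    return torch.nn.functional.linear(x, w.reshape(out_features, in_features))
```

The fused Vulkan kernel performs this dequantization inline rather than materializing `x` and `w`, which is where the kernel-dispatch and memory-bandwidth savings described in the commit message come from.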
