Commit a0f18c8

morelos authored and committed
[ET-VK] linear_qta8a_qga4w graph pass
# Changes

* Introduce `linear_qta8a_qga4w` custom operator in `custom_ops_lib.py` to handle dynamic activation + grouped weight quantized linear operations
* Add pattern matching and fusion logic in `FuseQuantizedOpsTransform` to detect and replace dequant + dequant + linear sequences with the new fused operator
* Implement comprehensive test coverage in `test_vulkan_passes.py` for the QTA8A_QGA4W fusion pattern validation
* Add 4-bit weight packing utilities and grouped quantization support for efficient memory usage

# Motivation

The existing quantization workflow in the Vulkan backend processes dynamic activation + grouped weight quantized linear operations as separate quantize/dequantize/linear steps, which creates performance overhead through:

* Multiple kernel dispatches instead of a single fused operation
* Intermediate tensor allocations for dequantized weights and activations
* Suboptimal memory bandwidth utilization

The new `linear_qta8a_qga4w` operator fuses the entire sequence into a single operation that:

* Directly processes 8-bit quantized activations with per-token scales/zero-points
* Handles 4-bit grouped quantized weights with configurable group sizes
* Eliminates intermediate dequantization steps by performing dequantization inline
* Reduces memory footprint through packed 4-bit weight storage

This aligns with the broader goal of optimizing quantized model inference in the Vulkan backend by leveraging graph-level transformations to improve computational efficiency while maintaining numerical accuracy.

Differential Revision: [D78291269](https://our.internmc.facebook.com/intern/diff/D78291269/)

[ghstack-poisoned]
1 parent 47aac0d commit a0f18c8
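
To make the fusion concrete, here is a minimal eager-mode sketch of the dequantize → dequantize → linear chain that the pass removes, and the single node that replaces it. The shapes, the broadcast layout of the per-token scale/zero-point, and the helper name `unfused_reference` are illustrative assumptions, not code from this commit.

```python
import torch

def unfused_reference(x_q, x_scale, x_zp, w_q, w_scales, w_zeros, group_size):
    """Sketch of the dequant -> dequant -> linear chain before fusion."""
    # Per-token activation dequant: one scale/zero-point per row of x_q
    # (assumed broadcastable, e.g. shape [batch, 1]).
    x = (x_q.float() - x_zp) * x_scale

    # Grouped weight dequant: one scale/zero-point per group of `group_size` columns.
    out_f, in_f = w_q.shape
    w = w_q.float().view(out_f, in_f // group_size, group_size)
    w = (w - w_zeros.unsqueeze(-1)) * w_scales.unsqueeze(-1)

    # Three separate steps (and kernel dispatches) in the exported graph.
    return torch.nn.functional.linear(x, w.view(out_f, in_f))

# After FuseQuantizedOpsTransform, the same computation is expressed as one graph node:
#   et_vk.linear_qta8a_qga4w(x_q, x_scale, x_zp, packed_w, group_size, w_scales, w_zeros)
```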

File tree

6 files changed: +371 −1 lines

backends/vulkan/_passes/fuse_quantized_ops.py

Lines changed: 205 additions & 0 deletions
```diff
@@ -210,19 +210,224 @@ def fuse_into_linear_qcnw_node(
     graph_module.graph.erase_node(dq_weight_node)
 
 
+#########################
+## linear_qta8a_qga4w ##
+#########################
+
+
+def matches_linear_qta8a_qga4w_pattern(
+    program: ExportedProgram, node: torch.fx.Node
+) -> Optional[Tuple[int, int]]:
+    """
+    Checks if the nodes surrounding a linear node matches the pattern for dynamic
+    activation + grouped weight quantized linear (QTA8A_QGA4W).
+
+    This pattern involves:
+    1. Dynamic quantization of input activations (8-bit)
+    2. Grouped quantization of weights (4-bit with group size)
+
+    The expected pattern from Int8DynActInt4WeightQuantizer is:
+        scale, zero_point = choose_qparams_affine(input)
+        quantized_input = quantize_affine(input, scale, zero_point)
+        dequantized_input = dequantize_affine(quantized_input, ...)
+        dequantized_weight = dequantize_affine(weight, weight_scales, weight_zeros)
+        output = linear(dequantized_input, dequantized_weight)
+
+    If the pattern matches, return (group_size, weight_bits), otherwise None.
+    """
+    if not utils.is_linear_node(node):
+        return None
+
+    input_node = node.args[0]
+    weight_node = node.args[1]
+
+    # Type checking - ensure we have torch.fx.Node objects
+    if not isinstance(weight_node, torch.fx.Node):
+        return None
+    if not isinstance(input_node, torch.fx.Node):
+        return None
+
+    # Check if input is dequantized with dequantize_affine (from dynamic quantization)
+    if not (
+        input_node.op == "call_function"
+        and input_node.target is not None
+        and hasattr(input_node.target, "__name__")
+        and "dequantize_affine" in getattr(input_node.target, "__name__", "")
+    ):
+        return None
+
+    # Check if weight is dequantized with dequantize_affine
+    if not (
+        weight_node.op == "call_function"
+        and weight_node.target is not None
+        and hasattr(weight_node.target, "__name__")
+        and "dequantize_affine" in getattr(weight_node.target, "__name__", "")
+    ):
+        return None
+
+    # Get the original quantized weight and quantization parameters
+    if len(weight_node.args) < 4:
+        return None
+
+    orig_weight = weight_node.args[0]
+    weight_scales = weight_node.args[2]
+    weight_zeros = weight_node.args[3]
+
+    # Type checking
+    if not isinstance(orig_weight, torch.fx.Node):
+        return None
+    if not is_param_node(program, orig_weight):
+        return None
+    if not isinstance(weight_scales, torch.fx.Node):
+        return None
+    if not is_param_node(program, weight_scales):
+        return None
+    if not isinstance(weight_zeros, torch.fx.Node):
+        return None
+    if not is_param_node(program, weight_zeros):
+        return None
+
+    # Get tensors to analyze the quantization scheme
+    orig_weight_tensor = get_param_tensor(program, orig_weight)
+    weight_scales_tensor = get_param_tensor(program, weight_scales)
+    weight_zeros_tensor = get_param_tensor(program, weight_zeros)
+
+    if not isinstance(orig_weight_tensor, torch.Tensor):
+        return None
+    if not isinstance(weight_scales_tensor, torch.Tensor):
+        return None
+    if not isinstance(weight_zeros_tensor, torch.Tensor):
+        return None
+
+    # Check if weight is quantized to 4 bits (values should be in [-8, 7] range)
+    quant_min = orig_weight_tensor.min().item()
+    quant_max = orig_weight_tensor.max().item()
+
+    if not (quant_min >= -8 and quant_max <= 7):
+        return None
+
+    # Determine group size from the scales tensor shape
+    # For grouped quantization, scales shape should be [out_features, in_features // group_size]
+    out_features, in_features = orig_weight_tensor.shape
+
+    if len(weight_scales_tensor.shape) != 2:
+        return None
+
+    scales_out_features, num_groups = weight_scales_tensor.shape
+
+    if scales_out_features != out_features:
+        return None
+
+    group_size = in_features // num_groups
+    if in_features % group_size != 0:
+        return None
+
+    # Verify this is 4-bit grouped quantization
+    weight_bits = 4
+
+    return group_size, weight_bits
+
+
+def fuse_into_linear_qta8a_qga4w_node(
+    program: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    linear_node: torch.fx.Node,
+    group_size: int,
+    weight_bits: int,
+) -> None:
+    """
+    Fuse the dynamic activation + grouped weight quantized linear pattern into
+    a single linear_qta8a_qga4w operator.
+
+    The pattern:
+        dequantized_input = dequantize_affine(quantized_input, block_size, scale, zero_point, ...)
+        dequantized_weight = dequantize_affine(weight, block_size, weight_scales, weight_zeros, ...)
+        output = linear(dequantized_input, dequantized_weight)
+
+    Becomes:
+        output = linear_qta8a_qga4w(quantized_input, input_scale, input_zero_point,
+                                    weight, group_size, weight_scales, weight_zeros)
+    """
+    dq_input_node = linear_node.args[0]
+    dq_weight_node = linear_node.args[1]
+
+    assert isinstance(dq_input_node, torch.fx.Node)
+    assert isinstance(dq_weight_node, torch.fx.Node)
+
+    # Get the quantized input and quantization parameters from the input dequantize_affine node
+    # Args: (input, block_size, scale, zero_point, input_dtype, quant_min, quant_max, output_dtype)
+    quantized_input = dq_input_node.args[0]
+    input_scale = dq_input_node.args[2]  # scale is the 3rd argument
+    input_zero_point = dq_input_node.args[3] if len(dq_input_node.args) > 3 else None
+
+    # Get the weight and its quantization parameters from dequantize_affine
+    # Args: (weight, block_size, weight_scales, weight_zeros, input_dtype, quant_min, quant_max, output_dtype)
+    orig_weight = dq_weight_node.args[0]
+    weight_scales = dq_weight_node.args[2]
+    weight_zeros = dq_weight_node.args[3]
+
+    # Pack the 4-bit weight tensor for efficient storage
+    assert isinstance(orig_weight, torch.fx.Node)
+    orig_weight_tensor = get_param_tensor(program, orig_weight)
+    assert isinstance(orig_weight_tensor, torch.Tensor)
+    packed_weight_tensor = pack_4bit_weight_tensor(orig_weight_tensor)
+    utils.update_program_state_dict(
+        program,
+        orig_weight.name,
+        packed_weight_tensor,
+    )
+    # Update the metadata to reflect the new packed shape
+    orig_weight.meta["val"] = orig_weight.meta["val"][:, ::2].to(torch.uint8)
+
+    # Create the linear_qta8a_qga4w node
+    with graph_module.graph.inserting_before(linear_node):
+        linear_qta8a_qga4w_node = graph_module.graph.create_node(
+            "call_function",
+            exir_ops.edge.et_vk.linear_qta8a_qga4w.default,
+            (
+                quantized_input,  # quantized input (int8)
+                input_scale,  # mat1_scale
+                input_zero_point,  # mat1_zero_point
+                orig_weight,  # mat2_data (packed 4-bit weights)
+                group_size,  # group_size (int)
+                weight_scales,  # weight_scales
+                weight_zeros,  # weight_zeros
+            ),
+        )
+
+    # Replace the linear node with the new fused node
+    linear_node.replace_all_uses_with(linear_qta8a_qga4w_node)
+
+    # Erase nodes in the correct order (users first, then dependencies)
+    graph_module.graph.erase_node(linear_node)
+    graph_module.graph.erase_node(dq_weight_node)
+    graph_module.graph.erase_node(dq_input_node)
+
+
 class FuseQuantizedOpsTransform(ExportPass):
     def __init__(self, exported_program: ExportedProgram) -> None:
         super().__init__()
         self.program = exported_program
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         for node in graph_module.graph.nodes:
+            # Check for linear_qcnw pattern (weight-only quantization)
             qcnw_details = matches_linear_qcnw_pattern(self.program, node)
             if qcnw_details is not None:
                 qcnw_method, qcnw_nbits = qcnw_details
                 fuse_into_linear_qcnw_node(
                     self.program, graph_module, node, qcnw_method, qcnw_nbits
                 )
+                continue
+
+            # Check for linear_qta8a_qga4w pattern (dynamic activation + grouped weight quantization)
+            qta8a_qga4w_details = matches_linear_qta8a_qga4w_pattern(self.program, node)
+            if qta8a_qga4w_details is not None:
+                group_size, weight_bits = qta8a_qga4w_details
+                fuse_into_linear_qta8a_qga4w_node(
+                    self.program, graph_module, node, group_size, weight_bits
+                )
+                continue
 
         graph_module.recompile()
         dead_code_elimination_pass(graph_module)
```
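
The fusion above calls a `pack_4bit_weight_tensor` helper that is not part of this diff. As a rough sketch of what such packing could look like, inferred only from the `[:, ::2]` metadata update above and the unpacking code in `custom_ops_lib.py` (the function name and exact layout are assumptions, not the actual ExecuTorch implementation), two signed 4-bit values are stored per byte, even columns in the high nibble and odd columns in the low nibble:

```python
import torch

def pack_4bit_weight_tensor_sketch(w: torch.Tensor) -> torch.Tensor:
    """Hypothetical packing sketch: pairs of signed 4-bit values -> one uint8 byte.

    Assumes w is an int8 tensor of shape [out_features, in_features] with values
    already in [-8, 7] and an even number of columns.
    """
    assert w.dtype == torch.int8 and w.shape[-1] % 2 == 0
    # Re-encode each signed value as an unsigned nibble (two's complement in 4 bits).
    nibbles = (w.to(torch.int16) & 0xF).to(torch.uint8)
    high = nibbles[:, ::2]   # even columns -> high nibble
    low = nibbles[:, 1::2]   # odd columns  -> low nibble
    return (high << 4) | low  # shape [out_features, in_features // 2]
```

The reference implementation of the fused operator in the next file reverses exactly this layout when it unpacks the weights.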

backends/vulkan/custom_ops_lib.py

Lines changed: 90 additions & 0 deletions
```diff
@@ -231,6 +231,96 @@ def linear_qcs4w(
 lib.impl(name, linear_qcs4w, "CompositeExplicitAutograd")
 linear_qc4w_op = getattr(getattr(torch.ops, namespace), name)
 
+########################
+## linear_qta8a_qga4w ##
+########################
+
+
+def linear_qta8a_qga4w(
+    x_quantized: torch.Tensor,
+    input_scale: torch.Tensor,
+    input_zero_point: torch.Tensor,
+    weights_4bit: torch.Tensor,
+    group_size: int,
+    weight_scales: torch.Tensor,
+    weight_zeros: torch.Tensor,
+):
+    """
+    Dynamic activation + grouped weight quantized linear (QTA8A_QGA4W).
+
+    Args:
+        x_quantized: Already quantized input tensor (int8, per-token quantized)
+        input_scale: Scale for per-token quantization of input (shape: [batch_size])
+        input_zero_point: Zero point for per-token quantization of input (shape: [batch_size])
+        weights_4bit: Packed 4-bit quantized weights
+        group_size: Group size for weight quantization (int)
+        weight_scales: Per-group scales for weights
+        weight_zeros: Per-group zero points for weights
+    """
+    original_x_shape = x_quantized.shape
+    batch_size = original_x_shape[0]
+    feature_dim = original_x_shape[-1]
+
+    # Reshape for processing
+    x_quantized_2d = x_quantized.reshape(-1, feature_dim)
+
+    # Unpack 4-bit weights
+    unpacked_weights_shape = weights_4bit.shape
+    out_features = unpacked_weights_shape[0]
+    in_features = unpacked_weights_shape[1]
+
+    weights_unpacked = torch.empty(
+        (out_features, in_features * 2), dtype=torch.int8, device=weights_4bit.device
+    )
+
+    weights_unpacked[:, ::2] = weights_4bit >> 4
+    weights_unpacked[:, 1::2] = weights_4bit & 0x0F
+
+    # Convert to signed 4-bit range [-8, 7]
+    weights_unpacked = torch.where(
+        weights_unpacked > 7, weights_unpacked - 16, weights_unpacked
+    )
+
+    # Dequantize weights using grouped quantization
+    actual_in_features = in_features * 2
+    num_groups = actual_in_features // group_size
+
+    # Reshape weights for grouped dequantization
+    weights_grouped = weights_unpacked.view(out_features, num_groups, group_size)
+
+    # Expand scales and zeros to match grouped weights
+    scales_expanded = weight_scales.unsqueeze(-1).expand(-1, -1, group_size)
+    zeros_expanded = weight_zeros.unsqueeze(-1).expand(-1, -1, group_size)
+
+    # Dequantize: (quantized - zero_point) * scale
+    dq_weights_grouped = (weights_grouped.float() - zeros_expanded) * scales_expanded
+    dq_weights = dq_weights_grouped.view(out_features, actual_in_features)
+
+    # Dequantize input (per-token)
+    # For per-token quantization, each token (row) has its own scale and zero_point
+    x_dequantized = torch.ops.quantized_decomposed.dequantize_per_token(
+        x_quantized_2d,
+        input_scale,
+        input_zero_point,
+        -128,
+        127,
+        torch.int8,
+        torch.float32,
+    )
+
+    # Perform linear operation
+    out = torch.nn.functional.linear(x_dequantized, dq_weights)
+    out_shape = original_x_shape[:-1] + (out_features,)
+    return out.reshape(out_shape)
+
+
+name = "linear_qta8a_qga4w"
+lib.define(
+    f"{name}(Tensor self, Tensor input_scale, Tensor input_zero_point, Tensor weight, int group_size, Tensor weight_scales, Tensor weight_zeros) -> Tensor"
+)
+lib.impl(name, linear_qta8a_qga4w, "CompositeExplicitAutograd")
+linear_qta8a_qga4w_op = getattr(getattr(torch.ops, namespace), name)
+
 ######################
 ## apply_rotary_emb ##
 ######################
```
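
A rough usage sketch of this reference implementation is shown below. It assumes the module import path `executorch.backends.vulkan.custom_ops_lib` registers the op under the `et_vk` namespace and that the `quantized_decomposed` ops are available; the tensor shapes, group size, and the keepdim layout of the per-token scale/zero-point are illustrative assumptions.

```python
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  (registers quantized_decomposed ops)
from executorch.backends.vulkan import custom_ops_lib  # noqa: F401  (assumed import path; registers et_vk ops)

batch, in_features, out_features, group_size = 2, 64, 16, 32

# Per-token quantized activations: one (scale, zero_point) per row, kept as
# [batch, 1] so it broadcasts inside the reference implementation.
x = torch.randn(batch, in_features)
input_scale = x.abs().amax(dim=-1, keepdim=True) / 127.0
input_zp = torch.zeros(batch, 1, dtype=torch.int64)
x_q = torch.clamp(torch.round(x / input_scale), -128, 127).to(torch.int8)

# Grouped 4-bit weights packed two values per byte, plus per-group scales/zeros.
w_q = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int8)
packed = ((w_q[:, ::2] & 0xF).to(torch.uint8) << 4) | (w_q[:, 1::2] & 0xF).to(torch.uint8)
w_scales = torch.rand(out_features, in_features // group_size) * 0.1
w_zeros = torch.zeros(out_features, in_features // group_size)

out = torch.ops.et_vk.linear_qta8a_qga4w(
    x_q, input_scale, input_zp, packed, group_size, w_scales, w_zeros
)
print(out.shape)  # expected: torch.Size([2, 16])
```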

backends/vulkan/op_registry.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -487,7 +487,12 @@ def register_int8_mm_op(features: OpFeatures):
     return features
 
 
-@update_features(exir_ops.edge.et_vk.linear_weight_int4.default)
+@update_features(
+    [
+        exir_ops.edge.et_vk.linear_weight_int4.default,
+        exir_ops.edge.et_vk.linear_qta8a_qga4w.default,
+    ]
+)
 def register_int4_mm_op(features: OpFeatures):
     features.buffer_impl = True
     features.texture_impl = TextureImplFeatures(
```

backends/vulkan/test/TARGETS

Lines changed: 1 addition & 0 deletions
```diff
@@ -34,6 +34,7 @@ python_unittest(
         "//executorch/backends/vulkan/_passes:vulkan_passes",
         "//executorch/backends/vulkan/quantizer:vulkan_quantizer",
         "//executorch/backends/vulkan:vulkan_preprocess",
+        "//pytorch/ao:torchao",  # @manual
     ]
 )
```
