
Commit 21ab2ec

pytorchbot, morelos, ahmtox, and Gasoonjia authored
[ET-VK] Creating get_symmetric_quantization_config (#12999)
This PR was created by the merge bot to help merge the original PR into the main branch.

ghstack PR number: #12573 by @ahmtox
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/ahmtox/41/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/ahmtox/41/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/ahmtox/40/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/ahmtox/41/orig
@diff-train-skip-merge

---------

Co-authored-by: morelos <[email protected]>
Co-authored-by: ahmtox <[email protected]>
Co-authored-by: Gasoonjia <[email protected]>
1 parent da0c80a commit 21ab2ec

28 files changed, 3355 insertions(+), 730 deletions(-)

backends/vulkan/_passes/fuse_quantized_ops.py

Lines changed: 283 additions & 0 deletions
@@ -210,19 +210,302 @@ def fuse_into_linear_qcnw_node(
    graph_module.graph.erase_node(dq_weight_node)


#########################
## linear_qta8a_qga4w ##
#########################


def _is_dequantize_affine_node(node: torch.fx.Node) -> bool:
    """Check if a node is a dequantize_affine operation."""
    return (
        node.op == "call_function"
        and node.target is not None
        and hasattr(node.target, "__name__")
        and "dequantize_affine" in getattr(node.target, "__name__", "")
    )


def _is_view_copy_node(node: torch.fx.Node) -> bool:
    """Check if a node is a view_copy operation."""
    return (
        node.op == "call_function"
        and node.target is not None
        and hasattr(node.target, "__name__")
        and "view_copy" in getattr(node.target, "__name__", "")
    )


def _validate_qta8a_qga4w_nodes(
    input_node: torch.fx.node.Argument, weight_node: torch.fx.node.Argument
) -> Optional[torch.fx.Node]:
    """
    Validate input and weight nodes for QTA8A_QGA4W pattern.
    Returns the actual input node (after handling view operations) or None if invalid.
    """
    # Type checking - ensure we have torch.fx.Node objects
    if not isinstance(weight_node, torch.fx.Node) or not isinstance(
        input_node, torch.fx.Node
    ):
        return None

    # Input may be preprocessed with a view node
    actual_input_node = input_node
    if _is_view_copy_node(input_node):
        actual_input_node = input_node.args[0]
        if not isinstance(actual_input_node, torch.fx.Node):
            return None

    # Check if input is dequantized with dequantize_affine (from dynamic quantization)
    if not _is_dequantize_affine_node(actual_input_node):
        return None

    # Check if weight is dequantized with dequantize_affine
    if not _is_dequantize_affine_node(weight_node):
        return None

    return actual_input_node


def _extract_weight_params(
    program: ExportedProgram, weight_node: torch.fx.Node
) -> Optional[Tuple[torch.fx.Node, torch.fx.Node, torch.fx.Node]]:
    """Extract and validate weight parameters from dequantize_affine node."""
    # Get the original quantized weight and quantization parameters
    if len(weight_node.args) < 4:
        return None

    orig_weight = weight_node.args[0]
    weight_scales = weight_node.args[2]
    weight_zeros = weight_node.args[3]

    # Type checking
    if not isinstance(orig_weight, torch.fx.Node) or not is_param_node(
        program, orig_weight
    ):
        return None
    if not isinstance(weight_scales, torch.fx.Node) or not is_param_node(
        program, weight_scales
    ):
        return None
    if not isinstance(weight_zeros, torch.fx.Node) or not is_param_node(
        program, weight_zeros
    ):
        return None

    return orig_weight, weight_scales, weight_zeros


def _validate_4bit_quantization(weight_tensor: torch.Tensor) -> bool:
    """Check if weight tensor is quantized to 4 bits (values in [-8, 7] range)."""
    quant_min = weight_tensor.min().item()
    quant_max = weight_tensor.max().item()
    return quant_min >= -8 and quant_max <= 7


def _calculate_group_size(
    orig_weight_tensor: torch.Tensor, weight_scales_tensor: torch.Tensor
) -> Optional[int]:
    """Calculate and validate group size from weight and scales tensors."""
    out_features, in_features = orig_weight_tensor.shape

    if len(weight_scales_tensor.shape) != 2:
        return None

    scales_out_features, num_groups = weight_scales_tensor.shape

    if scales_out_features != out_features:
        return None

    group_size = in_features // num_groups
    if in_features % group_size != 0:
        return None

    return group_size


def matches_linear_qta8a_qga4w_pattern(
    program: ExportedProgram, node: torch.fx.Node
) -> Optional[Tuple[int, int]]:
    """
    Checks if the nodes surrounding a linear node matches the pattern for dynamic
    activation + grouped weight quantized linear (QTA8A_QGA4W).

    This pattern involves:
    1. Dynamic quantization of input activations (8-bit)
    2. Grouped quantization of weights (4-bit with group size)

    The expected pattern from Int8DynActInt4WeightQuantizer is:
        scale, zero_point = choose_qparams_affine(input)
        quantized_input = quantize_affine(input, scale, zero_point)
        dequantized_input = dequantize_affine(quantized_input, ...)
        dequantized_weight = dequantize_affine(weight, weight_scales, weight_zeros)
        output = linear(dequantized_input, dequantized_weight)

    If the pattern matches, return (group_size, weight_bits), otherwise None.
    """
    if not utils.is_linear_node(node):
        return None

    input_node = node.args[0]
    weight_node = node.args[1]

    # Validate nodes and get actual input node
    actual_input_node = _validate_qta8a_qga4w_nodes(input_node, weight_node)
    if actual_input_node is None:
        return None

    # Extract weight parameters
    if not isinstance(weight_node, torch.fx.Node):
        return None
    weight_params = _extract_weight_params(program, weight_node)
    if weight_params is None:
        return None

    orig_weight, weight_scales, weight_zeros = weight_params

    # Get tensors to analyze the quantization scheme
    orig_weight_tensor = get_param_tensor(program, orig_weight)
    weight_scales_tensor = get_param_tensor(program, weight_scales)
    weight_zeros_tensor = get_param_tensor(program, weight_zeros)

    if not isinstance(orig_weight_tensor, torch.Tensor):
        return None
    if not isinstance(weight_scales_tensor, torch.Tensor):
        return None
    if not isinstance(weight_zeros_tensor, torch.Tensor):
        return None

    # Check if weight is quantized to 4 bits
    if not _validate_4bit_quantization(orig_weight_tensor):
        return None

    # Calculate group size
    group_size = _calculate_group_size(orig_weight_tensor, weight_scales_tensor)
    if group_size is None:
        return None

    # Verify this is 4-bit grouped quantization
    weight_bits = 4

    return group_size, weight_bits


def fuse_into_linear_qta8a_qga4w_node(
    program: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    linear_node: torch.fx.Node,
    group_size: int,
    weight_bits: int,
) -> None:
    """
    Fuse the dynamic activation + grouped weight quantized linear pattern into
    a single linear_qta8a_qga4w operator.

    The pattern:
        dequantized_input = dequantize_affine(quantized_input, block_size, scale, zero_point, ...)
        dequantized_weight = dequantize_affine(weight, block_size, weight_scales, weight_zeros, ...)
        output = linear(dequantized_input, dequantized_weight)

    Becomes:
        output = linear_qta8a_qga4w(quantized_input, input_scale, input_zero_point,
                                    weight, group_size, weight_scales, weight_zeros)
    """
    dq_input_node = linear_node.args[0]
    dq_weight_node = linear_node.args[1]

    assert isinstance(dq_input_node, torch.fx.Node)

    input_view_node = None
    # Input may be preprocessed with a view node
    if (
        dq_input_node.op == "call_function"
        and dq_input_node.target is not None
        and hasattr(dq_input_node.target, "__name__")
        and "view_copy" in getattr(dq_input_node.target, "__name__", "")
    ):
        input_view_node = dq_input_node
        dq_input_node = dq_input_node.args[0]
        assert isinstance(dq_input_node, torch.fx.Node)

    assert isinstance(dq_input_node, torch.fx.Node)
    assert isinstance(dq_weight_node, torch.fx.Node)

    # Get the quantized input and quantization parameters from the input dequantize_affine node
    # Args: (input, block_size, scale, zero_point, input_dtype, quant_min, quant_max, output_dtype)
    quantized_input = dq_input_node.args[0]
    input_scale = dq_input_node.args[2]  # scale is the 3rd argument
    input_zero_point = dq_input_node.args[3] if len(dq_input_node.args) > 3 else None

    # Get the weight and its quantization parameters from dequantize_affine
    # Args: (weight, block_size, weight_scales, weight_zeros, input_dtype, quant_min, quant_max, output_dtype)
    orig_weight = dq_weight_node.args[0]
    weight_scales = dq_weight_node.args[2]
    weight_zeros = dq_weight_node.args[3]

    # Pack the 4-bit weight tensor for efficient storage
    assert isinstance(orig_weight, torch.fx.Node)
    orig_weight_tensor = get_param_tensor(program, orig_weight)
    assert isinstance(orig_weight_tensor, torch.Tensor)
    packed_weight_tensor = pack_4bit_weight_tensor(orig_weight_tensor)
    utils.update_program_state_dict(
        program,
        orig_weight.name,
        packed_weight_tensor,
    )
    # Update the metadata to reflect the new packed shape
    orig_weight.meta["val"] = orig_weight.meta["val"][:, ::2].to(torch.uint8)

    # Create the linear_qta8a_qga4w node
    with graph_module.graph.inserting_before(linear_node):
        linear_qta8a_qga4w_node = graph_module.graph.create_node(
            "call_function",
            exir_ops.edge.et_vk.linear_qta8a_qga4w.default,
            (
                quantized_input,  # quantized input (int8)
                input_scale,  # mat1_scale
                input_zero_point,  # mat1_zero_point
                orig_weight,  # mat2_data (packed 4-bit weights)
                group_size,  # group_size (int)
                weight_scales,  # weight_scales
                weight_zeros,  # weight_zeros
            ),
        )

    # Replace the linear node with the new fused node
    linear_node.replace_all_uses_with(linear_qta8a_qga4w_node)

    # Erase nodes in the correct order (users first, then dependencies)
    graph_module.graph.erase_node(linear_node)
    if input_view_node is not None:
        graph_module.graph.erase_node(input_view_node)
    graph_module.graph.erase_node(dq_weight_node)
    graph_module.graph.erase_node(dq_input_node)


class FuseQuantizedOpsTransform(ExportPass):
    def __init__(self, exported_program: ExportedProgram) -> None:
        super().__init__()
        self.program = exported_program

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        for node in graph_module.graph.nodes:
            # Check for linear_qcnw pattern (weight-only quantization)
            qcnw_details = matches_linear_qcnw_pattern(self.program, node)
            if qcnw_details is not None:
                qcnw_method, qcnw_nbits = qcnw_details
                fuse_into_linear_qcnw_node(
                    self.program, graph_module, node, qcnw_method, qcnw_nbits
                )
                continue

            # Check for linear_qta8a_qga4w pattern (dynamic activation + grouped weight quantization)
            qta8a_qga4w_details = matches_linear_qta8a_qga4w_pattern(self.program, node)
            if qta8a_qga4w_details is not None:
                group_size, weight_bits = qta8a_qga4w_details
                fuse_into_linear_qta8a_qga4w_node(
                    self.program, graph_module, node, group_size, weight_bits
                )
                continue

        graph_module.recompile()
        dead_code_elimination_pass(graph_module)
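
Note on the matcher above: the group size is inferred purely from tensor shapes. For a (out_features, in_features) weight and (out_features, num_groups) scales tensor, group_size = in_features // num_groups, and the raw weight values must already lie in the signed 4-bit range [-8, 7]. A minimal standalone illustration of that shape arithmetic (toy shapes chosen for this note, not taken from the PR):

import torch

# Toy shapes only: mirrors the checks in _validate_4bit_quantization and
# _calculate_group_size above.
out_features, in_features, num_groups = 8, 16, 4

orig_weight = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int8)
weight_scales = torch.rand(out_features, num_groups)

# 4-bit check: every weight value must fall in [-8, 7]
assert orig_weight.min().item() >= -8 and orig_weight.max().item() <= 7

# Group size is inferred from the scales' second dimension
group_size = in_features // num_groups
assert in_features % group_size == 0
print(group_size)  # 4 -> the matcher would report (group_size=4, weight_bits=4)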

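For context on how the transform is driven (not shown in this diff), here is a rough, untested sketch of constructing the pass against an exported program. The import path is assumed from the file path above; on an unquantized toy model both pattern matchers return None, so the pass is effectively a no-op and only the plumbing is illustrated:

import torch
from executorch.backends.vulkan._passes.fuse_quantized_ops import (
    FuseQuantizedOpsTransform,
)

class TinyLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

# Export a toy model; in the real flow this would be the program produced by the
# Int8DynActInt4WeightQuantizer flow described in the docstrings above.
exported = torch.export.export(TinyLinear(), (torch.randn(2, 16),))

# Construct the pass with the ExportedProgram and run it over the graph module.
fuse_pass = FuseQuantizedOpsTransform(exported)
fuse_pass.call(exported.graph_module)
print(exported.graph_module.graph)
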
backends/vulkan/custom_ops_lib.py

Lines changed: 89 additions & 0 deletions
@@ -231,6 +231,95 @@ def linear_qcs4w(
lib.impl(name, linear_qcs4w, "CompositeExplicitAutograd")
linear_qc4w_op = getattr(getattr(torch.ops, namespace), name)

########################
## linear_qta8a_qga4w ##
########################


def linear_qta8a_qga4w(
    x_quantized: torch.Tensor,
    input_scale: torch.Tensor,
    input_zero_point: torch.Tensor,
    weights_4bit: torch.Tensor,
    group_size: int,
    weight_scales: torch.Tensor,
    weight_zeros: torch.Tensor,
):
    """
    Dynamic activation + grouped weight quantized linear (QTA8A_QGA4W).

    Args:
        x_quantized: Already quantized input tensor (int8, per-token quantized)
        input_scale: Scale for per-token quantization of input (shape: [batch_size])
        input_zero_point: Zero point for per-token quantization of input (shape: [batch_size])
        weights_4bit: Packed 4-bit quantized weights
        group_size: Group size for weight quantization (int)
        weight_scales: Per-group scales for weights
        weight_zeros: Per-group zero points for weights
    """
    original_x_shape = x_quantized.shape
    feature_dim = original_x_shape[-1]

    # Reshape for processing
    x_quantized_2d = x_quantized.reshape(-1, feature_dim)

    # Unpack 4-bit weights
    unpacked_weights_shape = weights_4bit.shape
    out_features = unpacked_weights_shape[0]
    in_features = unpacked_weights_shape[1]

    weights_unpacked = torch.empty(
        (out_features, in_features * 2), dtype=torch.int8, device=weights_4bit.device
    )

    weights_unpacked[:, ::2] = weights_4bit >> 4
    weights_unpacked[:, 1::2] = weights_4bit & 0x0F

    # Convert to signed 4-bit range [-8, 7]
    weights_unpacked = torch.where(
        weights_unpacked > 7, weights_unpacked - 16, weights_unpacked
    )

    # Dequantize weights using grouped quantization
    actual_in_features = in_features * 2
    num_groups = actual_in_features // group_size

    # Reshape weights for grouped dequantization
    weights_grouped = weights_unpacked.view(out_features, num_groups, group_size)

    # Expand scales and zeros to match grouped weights
    scales_expanded = weight_scales.unsqueeze(-1).expand(-1, -1, group_size)
    zeros_expanded = weight_zeros.unsqueeze(-1).expand(-1, -1, group_size)

    # Dequantize: (quantized - zero_point) * scale
    dq_weights_grouped = (weights_grouped.float() - zeros_expanded) * scales_expanded
    dq_weights = dq_weights_grouped.view(out_features, actual_in_features)

    # Dequantize input (per-token)
    # For per-token quantization, each token (row) has its own scale and zero_point
    x_dequantized = torch.ops.quantized_decomposed.dequantize_per_token(
        x_quantized_2d,
        input_scale,
        input_zero_point,
        -128,
        127,
        torch.int8,
        torch.float32,
    )

    # Perform linear operation
    out = torch.nn.functional.linear(x_dequantized, dq_weights)
    out_shape = original_x_shape[:-1] + (out_features,)
    return out.reshape(out_shape)


name = "linear_qta8a_qga4w"
lib.define(
    f"{name}(Tensor self, Tensor input_scale, Tensor input_zero_point, Tensor weight, int group_size, Tensor weight_scales, Tensor weight_zeros) -> Tensor"
)
lib.impl(name, linear_qta8a_qga4w, "CompositeExplicitAutograd")
linear_qta8a_qga4w_op = getattr(getattr(torch.ops, namespace), name)

######################
## apply_rotary_emb ##
######################
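
The schema registered above can be exercised eagerly through the CompositeExplicitAutograd reference implementation. A small illustrative round trip with toy sizes follows; it assumes executorch's custom_ops_lib has been imported so that the et_vk namespace is populated (the namespace name is taken from the exir_ops.edge.et_vk reference in the fusion pass), and that the quantized_decomposed ops are registered:

import torch
import torch.ao.quantization.fx._decomposed  # registers quantized_decomposed.* ops if not already loaded
import executorch.backends.vulkan.custom_ops_lib  # registers torch.ops.et_vk.linear_qta8a_qga4w

out_features, in_features, group_size, tokens = 8, 16, 8, 2

# Packed 4-bit weights: two nibbles per byte, high nibble first (matches the unpacking above)
weights_4bit = torch.randint(0, 256, (out_features, in_features // 2), dtype=torch.uint8)
weight_scales = torch.rand(out_features, in_features // group_size)
weight_zeros = torch.zeros(out_features, in_features // group_size)

# Per-token dynamically quantized activations; scale/zero-point kept 2-D so they
# broadcast row-wise inside dequantize_per_token
x = torch.randn(tokens, in_features)
input_scale = x.abs().amax(dim=-1, keepdim=True) / 127.0
input_zero_point = torch.zeros(tokens, 1, dtype=torch.int64)
x_q = torch.clamp((x / input_scale).round(), -128, 127).to(torch.int8)

out = torch.ops.et_vk.linear_qta8a_qga4w(
    x_q, input_scale, input_zero_point, weights_4bit,
    group_size, weight_scales, weight_zeros,
)
print(out.shape)  # torch.Size([2, 8])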
