
Commit 9c85266

[ET-VK][DCE] Remove redundant quantized linear implementations (#14198)
As title. With the recently added support for dynamically quantized and weight-only quantized int4/int8 linear, which is consolidated under `QuantizedLinear.cpp`, the operators implemented in `QuantizedLinearQGANW.cpp` and `QuantizedLinear_QTA8A_QGA4W.cpp` are no longer required. This diff removes all code related to the operators implemented in those files, namely:

* `et_vk.linear_weight_int4.default`
* `et_vk.linear_qta8a_qga4w.default`

The AOT export logic needed to support those ops is also removed.

Differential Revision: [D82120824](https://our.internmc.facebook.com/intern/diff/D82120824/)
1 parent 63481e3 commit 9c85266

28 files changed, +3 -3174 lines
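For orientation, the removed `et_vk.linear_qta8a_qga4w.default` op covered a dynamic-activation (8-bit, per-token) plus group-quantized 4-bit weight linear, as described in the docstrings of the deleted fusion pass below. The following is a minimal eager-mode sketch of that computation; it is not ExecuTorch code, and every name in it is illustrative only:

```python
# Hypothetical reference (illustrative names only): dynamically quantize the
# activation to int8 per token, dequantize group-quantized 4-bit weights, and
# run a plain linear -- the numerics the removed fused op stood for.
import torch
import torch.nn.functional as F


def qta8a_qga4w_reference(x, q_weight, w_scales, w_zeros, group_size):
    # Dynamic per-token activation quantization to int8, then dequantize.
    x_scale = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) / 127.0
    x_dq = torch.clamp(torch.round(x / x_scale), -128, 127) * x_scale

    # Group-wise dequantization of 4-bit weights (stored here as int8 in [-8, 7]).
    out_f, in_f = q_weight.shape
    w = q_weight.view(out_f, -1, group_size).float()
    w_dq = ((w - w_zeros.unsqueeze(-1)) * w_scales.unsqueeze(-1)).reshape(out_f, in_f)

    return F.linear(x_dq, w_dq)


x = torch.randn(2, 8)
q_w = torch.randint(-8, 8, (4, 8), dtype=torch.int8)   # out_features=4, in_features=8
scales, zeros = torch.rand(4, 2), torch.zeros(4, 2)    # out_features x num_groups
print(qta8a_qga4w_reference(x, q_w, scales, zeros, group_size=4).shape)  # (2, 4)
```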

backends/vulkan/_passes/TARGETS

Lines changed: 0 additions & 15 deletions
@@ -34,20 +34,6 @@ runtime.python_library(
     ],
 )
 
-runtime.python_library(
-    name = "int4_weight_only_quantizer",
-    srcs = [
-        "int4_weight_only_quantizer.py",
-    ],
-    visibility = [
-        "//executorch/backends/...",
-    ],
-    deps = [
-        "//executorch/backends/vulkan:custom_ops_lib",
-        "//pytorch/ao:torchao",
-    ]
-)
-
 runtime.python_library(
     name = "squeeze_unsqueeze_inputs",
     srcs = [
@@ -161,7 +147,6 @@ runtime.python_library(
         ":fuse_patterns",
         ":fuse_quantized_ops",
         ":insert_prepack_nodes",
-        ":int4_weight_only_quantizer",
         ":remove_asserts",
         ":remove_local_scalar_dense",
         ":remove_redundant_ops",

backends/vulkan/_passes/__init__.py

Lines changed: 0 additions & 4 deletions
@@ -12,9 +12,6 @@
     FuseQuantizedOpsTransform,
 )
 from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
-from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
-    VkInt4WeightOnlyQuantizer,
-)
 from executorch.backends.vulkan._passes.remove_asserts import (
     remove_asserts,
     RemoveAssertsTransform,
@@ -35,7 +32,6 @@
     "FusePatternsPass",
     "FuseQuantizedOpsTransform",
     "insert_prepack_nodes",
-    "VkInt4WeightOnlyQuantizer",
     "remove_asserts",
     "RemoveAssertsTransform",
     "RemoveLocalScalarDenseOpsTransform",

backends/vulkan/_passes/fuse_quantized_ops.py

Lines changed: 0 additions & 281 deletions
@@ -210,278 +210,6 @@ def fuse_into_linear_qcnw_node(
     graph_module.graph.erase_node(dq_weight_node)
 
 
-#########################
-## linear_qta8a_qga4w ##
-#########################
-
-
-def _is_dequantize_affine_node(node: torch.fx.Node) -> bool:
-    """Check if a node is a dequantize_affine operation."""
-    return (
-        node.op == "call_function"
-        and node.target is not None
-        and hasattr(node.target, "__name__")
-        and "dequantize_affine" in getattr(node.target, "__name__", "")
-    )
-
-
-def _is_view_copy_node(node: torch.fx.Node) -> bool:
-    """Check if a node is a view_copy operation."""
-    return (
-        node.op == "call_function"
-        and node.target is not None
-        and hasattr(node.target, "__name__")
-        and "view_copy" in getattr(node.target, "__name__", "")
-    )
-
-
-def _validate_qta8a_qga4w_nodes(
-    input_node: torch.fx.node.Argument, weight_node: torch.fx.node.Argument
-) -> Optional[torch.fx.Node]:
-    """
-    Validate input and weight nodes for QTA8A_QGA4W pattern.
-    Returns the actual input node (after handling view operations) or None if invalid.
-    """
-    # Type checking - ensure we have torch.fx.Node objects
-    if not isinstance(weight_node, torch.fx.Node) or not isinstance(
-        input_node, torch.fx.Node
-    ):
-        return None
-
-    # Input may be preprocessed with a view node
-    actual_input_node = input_node
-    if _is_view_copy_node(input_node):
-        actual_input_node = input_node.args[0]
-        if not isinstance(actual_input_node, torch.fx.Node):
-            return None
-
-    # Check if input is dequantized with dequantize_affine (from dynamic quantization)
-    if not _is_dequantize_affine_node(actual_input_node):
-        return None
-
-    # Check if weight is dequantized with dequantize_affine
-    if not _is_dequantize_affine_node(weight_node):
-        return None
-
-    return actual_input_node
-
-
-def _extract_weight_params(
-    program: ExportedProgram, weight_node: torch.fx.Node
-) -> Optional[Tuple[torch.fx.Node, torch.fx.Node, torch.fx.Node]]:
-    """Extract and validate weight parameters from dequantize_affine node."""
-    # Get the original quantized weight and quantization parameters
-    if len(weight_node.args) < 4:
-        return None
-
-    orig_weight = weight_node.args[0]
-    weight_scales = weight_node.args[2]
-    weight_zeros = weight_node.args[3]
-
-    # Type checking
-    if not isinstance(orig_weight, torch.fx.Node) or not is_param_node(
-        program, orig_weight
-    ):
-        return None
-    if not isinstance(weight_scales, torch.fx.Node) or not is_param_node(
-        program, weight_scales
-    ):
-        return None
-    if not isinstance(weight_zeros, torch.fx.Node) or not is_param_node(
-        program, weight_zeros
-    ):
-        return None
-
-    return orig_weight, weight_scales, weight_zeros
-
-
-def _validate_4bit_quantization(weight_tensor: torch.Tensor) -> bool:
-    """Check if weight tensor is quantized to 4 bits (values in [-8, 7] range)."""
-    quant_min = weight_tensor.min().item()
-    quant_max = weight_tensor.max().item()
-    return quant_min >= -8 and quant_max <= 7
-
-
-def _calculate_group_size(
-    orig_weight_tensor: torch.Tensor, weight_scales_tensor: torch.Tensor
-) -> Optional[int]:
-    """Calculate and validate group size from weight and scales tensors."""
-    out_features, in_features = orig_weight_tensor.shape
-
-    if len(weight_scales_tensor.shape) != 2:
-        return None
-
-    scales_out_features, num_groups = weight_scales_tensor.shape
-
-    if scales_out_features != out_features:
-        return None
-
-    group_size = in_features // num_groups
-    if in_features % group_size != 0:
-        return None
-
-    return group_size
-
-
-def matches_linear_qta8a_qga4w_pattern(
-    program: ExportedProgram, node: torch.fx.Node
-) -> Optional[Tuple[int, int]]:
-    """
-    Checks if the nodes surrounding a linear node matches the pattern for dynamic
-    activation + grouped weight quantized linear (QTA8A_QGA4W).
-
-    This pattern involves:
-    1. Dynamic quantization of input activations (8-bit)
-    2. Grouped quantization of weights (4-bit with group size)
-
-    The expected pattern from Int8DynActInt4WeightQuantizer is:
-        scale, zero_point = choose_qparams_affine(input)
-        quantized_input = quantize_affine(input, scale, zero_point)
-        dequantized_input = dequantize_affine(quantized_input, ...)
-        dequantized_weight = dequantize_affine(weight, weight_scales, weight_zeros)
-        output = linear(dequantized_input, dequantized_weight)
-
-    If the pattern matches, return (group_size, weight_bits), otherwise None.
-    """
-    if not utils.is_linear_node(node):
-        return None
-
-    input_node = node.args[0]
-    weight_node = node.args[1]
-
-    # Validate nodes and get actual input node
-    actual_input_node = _validate_qta8a_qga4w_nodes(input_node, weight_node)
-    if actual_input_node is None:
-        return None
-
-    # Extract weight parameters
-    if not isinstance(weight_node, torch.fx.Node):
-        return None
-    weight_params = _extract_weight_params(program, weight_node)
-    if weight_params is None:
-        return None
-
-    orig_weight, weight_scales, weight_zeros = weight_params
-
-    # Get tensors to analyze the quantization scheme
-    orig_weight_tensor = get_param_tensor(program, orig_weight)
-    weight_scales_tensor = get_param_tensor(program, weight_scales)
-    weight_zeros_tensor = get_param_tensor(program, weight_zeros)
-
-    if not isinstance(orig_weight_tensor, torch.Tensor):
-        return None
-    if not isinstance(weight_scales_tensor, torch.Tensor):
-        return None
-    if not isinstance(weight_zeros_tensor, torch.Tensor):
-        return None
-
-    # Check if weight is quantized to 4 bits
-    if not _validate_4bit_quantization(orig_weight_tensor):
-        return None
-
-    # Calculate group size
-    group_size = _calculate_group_size(orig_weight_tensor, weight_scales_tensor)
-    if group_size is None:
-        return None
-
-    # Verify this is 4-bit grouped quantization
-    weight_bits = 4
-
-    return group_size, weight_bits
-
-
-def fuse_into_linear_qta8a_qga4w_node(
-    program: ExportedProgram,
-    graph_module: torch.fx.GraphModule,
-    linear_node: torch.fx.Node,
-    group_size: int,
-    weight_bits: int,
-) -> None:
-    """
-    Fuse the dynamic activation + grouped weight quantized linear pattern into
-    a single linear_qta8a_qga4w operator.
-
-    The pattern:
-        dequantized_input = dequantize_affine(quantized_input, block_size, scale, zero_point, ...)
-        dequantized_weight = dequantize_affine(weight, block_size, weight_scales, weight_zeros, ...)
-        output = linear(dequantized_input, dequantized_weight)
-
-    Becomes:
-        output = linear_qta8a_qga4w(quantized_input, input_scale, input_zero_point,
-                                    weight, group_size, weight_scales, weight_zeros)
-    """
-    dq_input_node = linear_node.args[0]
-    dq_weight_node = linear_node.args[1]
-
-    assert isinstance(dq_input_node, torch.fx.Node)
-
-    input_view_node = None
-    # Input may be preprocessed with a view node
-    if (
-        dq_input_node.op == "call_function"
-        and dq_input_node.target is not None
-        and hasattr(dq_input_node.target, "__name__")
-        and "view_copy" in getattr(dq_input_node.target, "__name__", "")
-    ):
-        input_view_node = dq_input_node
-        dq_input_node = dq_input_node.args[0]
-        assert isinstance(dq_input_node, torch.fx.Node)
-
-    assert isinstance(dq_input_node, torch.fx.Node)
-    assert isinstance(dq_weight_node, torch.fx.Node)
-
-    # Get the quantized input and quantization parameters from the input dequantize_affine node
-    # Args: (input, block_size, scale, zero_point, input_dtype, quant_min, quant_max, output_dtype)
-    quantized_input = dq_input_node.args[0]
-    input_scale = dq_input_node.args[2]  # scale is the 3rd argument
-    input_zero_point = dq_input_node.args[3] if len(dq_input_node.args) > 3 else None
-
-    # Get the weight and its quantization parameters from dequantize_affine
-    # Args: (weight, block_size, weight_scales, weight_zeros, input_dtype, quant_min, quant_max, output_dtype)
-    orig_weight = dq_weight_node.args[0]
-    weight_scales = dq_weight_node.args[2]
-    weight_zeros = dq_weight_node.args[3]
-
-    # Pack the 4-bit weight tensor for efficient storage
-    assert isinstance(orig_weight, torch.fx.Node)
-    orig_weight_tensor = get_param_tensor(program, orig_weight)
-    assert isinstance(orig_weight_tensor, torch.Tensor)
-    packed_weight_tensor = pack_4bit_weight_tensor(orig_weight_tensor)
-    utils.update_program_state_dict(
-        program,
-        orig_weight.name,
-        packed_weight_tensor,
-    )
-    # Update the metadata to reflect the new packed shape
-    orig_weight.meta["val"] = orig_weight.meta["val"][:, ::2].to(torch.uint8)
-
-    # Create the linear_qta8a_qga4w node
-    with graph_module.graph.inserting_before(linear_node):
-        linear_qta8a_qga4w_node = graph_module.graph.create_node(
-            "call_function",
-            exir_ops.edge.et_vk.linear_qta8a_qga4w.default,
-            (
-                quantized_input,  # quantized input (int8)
-                input_scale,  # mat1_scale
-                input_zero_point,  # mat1_zero_point
-                orig_weight,  # mat2_data (packed 4-bit weights)
-                group_size,  # group_size (int)
-                weight_scales,  # weight_scales
-                weight_zeros,  # weight_zeros
-            ),
-        )
-
-    # Replace the linear node with the new fused node
-    linear_node.replace_all_uses_with(linear_qta8a_qga4w_node)
-
-    # Erase nodes in the correct order (users first, then dependencies)
-    graph_module.graph.erase_node(linear_node)
-    if input_view_node is not None:
-        graph_module.graph.erase_node(input_view_node)
-    graph_module.graph.erase_node(dq_weight_node)
-    graph_module.graph.erase_node(dq_input_node)
-
-
 class FuseQuantizedOpsTransform(ExportPass):
     def __init__(self, exported_program: ExportedProgram) -> None:
         super().__init__()
@@ -498,15 +226,6 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                 )
                 continue
 
-            # Check for linear_qta8a_qga4w pattern (dynamic activation + grouped weight quantization)
-            qta8a_qga4w_details = None
-            if qta8a_qga4w_details is not None:
-                group_size, weight_bits = qta8a_qga4w_details
-                fuse_into_linear_qta8a_qga4w_node(
-                    self.program, graph_module, node, group_size, weight_bits
-                )
-                continue
-
         graph_module.recompile()
         dead_code_elimination_pass(graph_module)
 
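One detail worth noting in the removed `fuse_into_linear_qta8a_qga4w_node` above: the 4-bit weight is packed two values per byte before the fused node is created, which is why the weight metadata is rewritten to `orig_weight.meta["val"][:, ::2]`. A rough sketch of one such packing scheme, assuming a low/high-nibble layout (the real `pack_4bit_weight_tensor` helper lives elsewhere in the backend and may pack differently):

```python
# Hypothetical packing sketch: fold two signed 4-bit values into one uint8 byte,
# halving the last dimension -- matching the [:, ::2] shape update above.
# The real pack_4bit_weight_tensor may use a different nibble order or layout.
import torch


def pack_4bit_weight_sketch(w: torch.Tensor) -> torch.Tensor:
    assert w.dim() == 2 and w.shape[-1] % 2 == 0
    assert int(w.min()) >= -8 and int(w.max()) <= 7, "expects signed 4-bit values"
    u = (w + 8).to(torch.uint8)            # shift [-8, 7] -> [0, 15] nibbles
    low, high = u[:, ::2], u[:, 1::2]      # pair up adjacent columns
    return (high << 4) | low               # two weights per byte


packed = pack_4bit_weight_sketch(torch.randint(-8, 8, (4, 8)))
print(packed.shape)  # torch.Size([4, 4]) -- half the original last dim
```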