diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py index ccf15fd2c7f..2c4588ac43d 100644 --- a/backends/vulkan/_passes/__init__.py +++ b/backends/vulkan/_passes/__init__.py @@ -6,6 +6,7 @@ # pyre-strict +from executorch.backends.vulkan._passes.fold_qdq import FoldQDQPass from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass from executorch.backends.vulkan._passes.fuse_quantized_ops import ( FuseQuantizedOpsTransform, @@ -30,6 +31,7 @@ from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass __all__ = [ + "FoldQDQPass", "FusePatternsPass", "FuseQuantizedOpsTransform", "insert_prepack_nodes", diff --git a/backends/vulkan/_passes/fold_qdq.py b/backends/vulkan/_passes/fold_qdq.py new file mode 100644 index 00000000000..e28a1d13b39 --- /dev/null +++ b/backends/vulkan/_passes/fold_qdq.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import executorch.backends.vulkan.utils as utils +import torch + +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import dead_code_elimination_pass + + +class FoldQDQPass(ExportPass): + """ + Erase Q/DQ chain introduced by PT2E quantization workflow. It is assumed that all + valid quant op patterns have already been fused before this pass. + """ + + def __init__(self, edge_program: torch.export.ExportedProgram): + super(FoldQDQPass, self).__init__() + self.edge_program = edge_program + + def call(self, graph_module: torch.fx.GraphModule): + for node in graph_module.graph.nodes: + # Criteria for a foldable Q/DQ node: + # - only one user (dequantize) + if utils.is_quant_node(node): + if len(node.users) > 1: + continue + + dq_node = None + for user in node.users: + if utils.is_dequant_node(user): + dq_node = user + + if dq_node is None: + continue + + original_node = node.args[0] + assert isinstance(original_node, torch.fx.Node) + dq_node.replace_all_uses_with(original_node) + + graph_module.recompile() + dead_code_elimination_pass(graph_module) + # Re-trace to validate everything is ok + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, True) diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index bc61b44ce78..36f7f0ed982 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from typing import Optional
+
 import executorch.backends.vulkan.patterns as vk_patterns
 import torch.library
@@ -321,6 +323,152 @@ def linear_qta8a_qga4w(
 lib.impl(name, linear_qta8a_qga4w, "CompositeExplicitAutograd")
 linear_qta8a_qga4w_op = getattr(getattr(torch.ops, namespace), name)
+#######################
+## linear_q8ta_q8csw ##
+#######################
+
+
+def linear_q8ta_q8csw(
+    x: torch.Tensor,
+    input_scale: float,
+    input_zero_point: int,
+    qweights: torch.Tensor,
+    weight_sums: torch.Tensor,
+    weight_scales: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+):
+    weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
+    qweights = qweights.transpose(0, 1)
+    weights = torch.ops.quantized_decomposed.dequantize_per_channel(
+        qweights,
+        weight_scales,
+        weight_zeros,
+        0,
+        -127,
+        127,
+        torch.int8,
+    )
+
+    # Perform linear operation
+    out = torch.nn.functional.linear(x, weights)
+    if bias is not None:
+        out = out + bias
+
+    return out
+
+
+name = "linear_q8ta_q8csw"
+lib.define(
+    f"""
+    {name}(
+        Tensor x,
+        float input_scale,
+        int input_zero_point,
+        Tensor qweight,
+        Tensor weight_sums,
+        Tensor weight_scales,
+        Tensor? bias = None) -> Tensor
+    """
+)
+lib.impl(name, linear_q8ta_q8csw, "CompositeExplicitAutograd")
+qa_q8csw_linear = getattr(getattr(torch.ops, namespace), name)
+
+#######################
+## conv2d_q8ta_q8csw ##
+#######################
+
+
+def conv2d_q8ta_q8csw(
+    x: torch.Tensor,
+    input_scale: float,
+    input_zero_point: int,
+    qweights: torch.Tensor,
+    weight_sums: torch.Tensor,
+    weight_scales: torch.Tensor,
+    bias: Optional[torch.Tensor],
+    kernel_size: list,
+    stride: list,
+    padding: list,
+    dilation: list,
+    groups: int,
+):
+    """
+    Quantized convolution reference implementation that restores the weight tensor
+    to its original 4D format, dequantizes it, and executes a regular convolution.
+
+    Args:
+        x: Input tensor
+        input_scale: Input quantization scale
+        input_zero_point: Input quantization zero point
+        qweights: Quantized weights in reshaped 2D format (IC * H * W, OC)
+        weight_sums: Pre-computed weight sums per output channel
+        weight_scales: Weight quantization scales per output channel
+        bias: Optional bias tensor
+        kernel_size: Kernel size [H, W]
+        stride: Convolution stride
+        padding: Convolution padding
+        dilation: Convolution dilation
+        groups: Number of groups for grouped convolution
+    """
+    weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
+
+    # Restore weight tensor from 2D format (IC * H * W, OC) back to 4D format (OC, IC, H, W)
+    # First transpose to get (OC, IC * H * W)
+    qweights_transposed = qweights.transpose(0, 1)
+
+    # Extract kernel dimensions from the provided kernel_size
+    H, W = kernel_size[0], kernel_size[1]
+
+    # Calculate dimensions
+    OC = qweights_transposed.shape[0]
+    IC_H_W = qweights_transposed.shape[1]
+    IC = IC_H_W // (H * W)
+
+    # Reshape to original 4D format (OC, IC, H, W)
+    qweights_4d = qweights_transposed.view(OC, IC, H, W)
+
+    # Dequantize weights
+    weights = torch.ops.quantized_decomposed.dequantize_per_channel(
+        qweights_4d,
+        weight_scales,
+        weight_zeros,
+        0,  # axis=0 for output channel quantization
+        -127,
+        127,
+        torch.int8,
+    )
+
+    # Perform convolution
+    out = torch.nn.functional.conv2d(
+        x, weights, bias, stride, padding, dilation, groups
+    )
+
+    return out
+
+
+name = "conv2d_q8ta_q8csw"
+lib.define(
+    f"""
+    {name}(
+        Tensor x,
+        float input_scale,
+        int input_zero_point,
+        Tensor qweight,
+        Tensor weight_sums,
+        Tensor weight_scales,
+        Tensor? 
bias, + SymInt[] kernel_size, + SymInt[] stride, + SymInt[] padding, + SymInt[] dilation, + SymInt groups) -> Tensor + """ +) +lib.impl(name, conv2d_q8ta_q8csw, "CompositeExplicitAutograd") +conv2d_q8ta_q8csw_op = getattr(getattr(torch.ops, namespace), name) ###################### ## apply_rotary_emb ## ###################### diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 55c36463b51..aba7545cdc2 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -318,6 +318,18 @@ def register_int8_mm_op(): ) +@update_features( + [ + exir_ops.edge.et_vk.linear_q8ta_q8csw.default, + ] +) +def register_qa_qw_linear(): + return OpFeatures( + inputs_storage=utils.CONTIGUOUS_ANY, + supports_prepacking=True, + ) + + @update_features( [ exir_ops.edge.et_vk.linear_weight_int4.default, @@ -457,6 +469,32 @@ def register_convolution_op(): ) +@update_features( + [ + exir_ops.edge.et_vk.conv2d_q8ta_q8csw.default, + ] +) +def register_quantized_conv_op(): + return OpFeatures( + inputs_storage=[ + utils.CHANNELS_PACKED_TEXTURE, # input + utils.NO_STORAGE, # input_scale (non tensor) + utils.NO_STORAGE, # input_zero_point (non tensor) + utils.NO_STORAGE, # weight (prepacked) + utils.NO_STORAGE, # weight_sums (prepacked) + utils.NO_STORAGE, # weight_scales (prepacked) + utils.NO_STORAGE, # bias (prepacked) + utils.NO_STORAGE, # kernel_size (non tensor) + utils.NO_STORAGE, # stride (non tensor) + utils.NO_STORAGE, # padding (non tensor) + utils.NO_STORAGE, # dilation (non tensor) + utils.NO_STORAGE, # groups (non tensor) + ], + supports_resize=False, + supports_prepacking=True, + ) + + @update_features("llama::sdpa_with_kv_cache") def register_sdpa_with_kv_cache_op(): return OpFeatures( diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 06db2a58f12..e5b2d0f7864 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -22,6 +22,8 @@ vulkan_supported_ops, ) +from executorch.backends.vulkan.patterns import PatternMatch + from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( VkMemoryLayout, VkStorageType, @@ -41,7 +43,6 @@ from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupportBase -from torch.fx.passes.utils.matcher_utils import InternalMatch # pyre-ignore ops_not_to_decompose = [ @@ -60,7 +61,7 @@ def __init__( require_dynamic_shape: bool = False, operator_blocklist: Optional[Set[OpKey]] = None, operator_allowlist: Optional[Set[OpKey]] = None, - fusable_subgraphs: Optional[List[InternalMatch]] = None, + fusable_subgraphs: Optional[List[PatternMatch]] = None, nn_module_blocklist: Optional[Set[str]] = None, nn_module_allowlist: Optional[Set[str]] = None, ) -> None: @@ -72,13 +73,13 @@ def __init__( operator_blocklist if operator_blocklist is not None else set() ) self.operator_allowlist = operator_allowlist - self.fusable_subgraphs: List[InternalMatch] = ( + self.fusable_subgraphs: List[PatternMatch] = ( fusable_subgraphs if fusable_subgraphs is not None else [] ) # Create a set of all nodes that are part of fusable subgraphs for quick lookup self.fusable_nodes: Set[torch.fx.Node] = set() for match in self.fusable_subgraphs: - self.fusable_nodes.update(match.nodes_map.values()) + self.fusable_nodes.update(match.all_nodes) self.nn_module_blocklist = nn_module_blocklist self.nn_module_allowlist = nn_module_allowlist diff --git 
a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py index b8026f517e6..8790eda9fe2 100644 --- a/backends/vulkan/patterns/__init__.py +++ b/backends/vulkan/patterns/__init__.py @@ -6,6 +6,8 @@ from typing import List +import executorch.backends.vulkan.patterns.quantized_convolution # noqa + import executorch.backends.vulkan.patterns.quantized_linear # noqa import executorch.backends.vulkan.patterns.rope # noqa @@ -13,9 +15,13 @@ import torch from executorch.backends.vulkan.patterns.pattern_registry import ( + create_pattern_match_from_internal_match, CreateReplacementFn, + DetectorFn, fusable_patterns, GetGraphFn, + PatternMatch, + register_pattern_detector, register_pattern_graph, register_pattern_replacement, ) @@ -24,15 +30,18 @@ from executorch.exir import ExportedProgram -from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher +from torch.fx.passes.utils.matcher_utils import SubgraphMatcher __all__ = [ + "PatternMatch", "GetGraphFn", + "DetectorFn", "CreateReplacementFn", "RotaryEmbeddingPattern", "fusable_patterns", "register_pattern_graph", + "register_pattern_detector", "register_pattern_replacement", ] @@ -48,14 +57,22 @@ def all_fusable_graph_patterns() -> List[torch.fx.GraphModule]: def get_all_fusable_subgraphs( graph_module: torch.fx.GraphModule, -) -> List[InternalMatch]: +) -> List[PatternMatch]: fusable_subgraphs = [] fuse_patterns = all_fusable_graph_patterns() for pattern in fuse_patterns: sm = SubgraphMatcher(pattern.graph, ignore_literals=True) matches = list(sm.match(graph_module.graph)) - fusable_subgraphs.extend(matches) + for match in matches: + fusable_subgraphs.append(create_pattern_match_from_internal_match(match)) + + for node in graph_module.graph.nodes: + for entry in fusable_patterns.values(): + if entry.detector_fn is not None: + maybe_match = entry.detector_fn(node) + if maybe_match is not None: + fusable_subgraphs.append(maybe_match) return fusable_subgraphs @@ -73,7 +90,8 @@ def create_replacement_for_pattern( matches = list(sm.match(graph_module.graph)) for partition_to_replace in matches: - create_replacement_func(ep, graph_module, partition_to_replace) + pattern = create_pattern_match_from_internal_match(partition_to_replace) + create_replacement_func(ep, graph_module, pattern) total_replaced += 1 # Remove dead code so they won't be matched again graph_module.graph.eliminate_dead_code() @@ -87,6 +105,7 @@ def replace_all_fusable_subgraphs( ) -> int: total_replaced = 0 + # Handle patterns identified with SubgraphMatcher for entry in fusable_patterns.values(): if entry.get_graphs_fn is not None and entry.create_replacement_fn is not None: total_replaced += create_replacement_for_pattern( @@ -97,4 +116,17 @@ def replace_all_fusable_subgraphs( entry.create_replacement_fn, ) + # Handle patterns identified with custom detector function + for node in graph_module.graph.nodes: + for entry in fusable_patterns.values(): + if ( + entry.detector_fn is not None + and entry.create_replacement_fn is not None + ): + maybe_match = entry.detector_fn(node) + if maybe_match is not None: + entry.create_replacement_fn(ep, graph_module, maybe_match) + total_replaced += 1 + + graph_module.graph.eliminate_dead_code() return total_replaced diff --git a/backends/vulkan/patterns/pattern_registry.py b/backends/vulkan/patterns/pattern_registry.py index 37fa0bcca8c..9a906cd8770 100644 --- a/backends/vulkan/patterns/pattern_registry.py +++ b/backends/vulkan/patterns/pattern_registry.py @@ -13,22 +13,65 @@ from 
torch.fx.passes.utils.matcher_utils import InternalMatch GetGraphFn = Callable[[], List[torch.fx.GraphModule]] + + +class PatternMatch: + __slots__ = ("input_nodes", "output_nodes", "all_nodes", "anchor_node") + """ + The design of this class is based on InternalMatch from + torch.fx.passes.utils.matcher_utils. It represents nodes in a graph that + match a particular pattern. + + The reason to not use InternalMatch directly is to enable more (i.e. custom) + methods to detect and represent matches other than through SubgraphMatcher. + """ + + def __init__( + self, + input_nodes: List[torch.fx.Node], + output_nodes: List[torch.fx.Node], + all_nodes: List[torch.fx.Node], + anchor_node: Optional[torch.fx.Node] = None, + ): + self.input_nodes = input_nodes + self.output_nodes = output_nodes + self.all_nodes = all_nodes + self.anchor_node = anchor_node + + +def create_pattern_match_from_internal_match( + internal_match: InternalMatch, +) -> PatternMatch: + return PatternMatch( + internal_match.placeholder_nodes, + internal_match.returning_nodes, + list(internal_match.nodes_map.values()), + ) + + CreateReplacementFn = Callable[ - [ExportedProgram, torch.fx.GraphModule, InternalMatch], None + [ExportedProgram, torch.fx.GraphModule, PatternMatch], None ] +DetectorFn = Callable[[torch.fx.Node], Optional[PatternMatch]] + + class PatternEntry: def __init__( self, get_graphs_fn: Optional[GetGraphFn] = None, + detector_fn: Optional[DetectorFn] = None, create_replacement_fn: Optional[CreateReplacementFn] = None, ): self.get_graphs_fn = get_graphs_fn + self.detector_fn = detector_fn self.create_replacement_fn = create_replacement_fn def is_valid(self): - return self.get_graphs_fn is not None and self.create_replacement_fn is not None + return ( + self.get_graphs_fn is not None or self.detector_fn is not None + ) and self.create_replacement_fn is not None fusable_patterns: Dict[str, PatternEntry] = {} @@ -39,7 +82,24 @@ def decorator(fn: GetGraphFn): if pattern_name not in fusable_patterns: fusable_patterns[pattern_name] = PatternEntry() + # Cannot define both get_graphs_fn and detector_fn + assert fusable_patterns[pattern_name].detector_fn is None fusable_patterns[pattern_name].get_graphs_fn = fn + + return fn + + return decorator + + +def register_pattern_detector(pattern_name: str): + def decorator(fn: DetectorFn): + if pattern_name not in fusable_patterns: + fusable_patterns[pattern_name] = PatternEntry() + + # Cannot define both get_graphs_fn and detector_fn + assert fusable_patterns[pattern_name].get_graphs_fn is None + fusable_patterns[pattern_name].detector_fn = fn + return fn return decorator diff --git a/backends/vulkan/patterns/quantized_convolution.py b/backends/vulkan/patterns/quantized_convolution.py new file mode 100644 index 00000000000..ad40c3dedb4 --- /dev/null +++ b/backends/vulkan/patterns/quantized_convolution.py @@ -0,0 +1,211 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+from typing import Optional
+
+import executorch.backends.vulkan.utils as utils
+
+import torch
+
+from executorch.backends.transforms.utils import (
+    create_constant_placeholder,
+    get_param_tensor,
+)
+
+from executorch.backends.vulkan.patterns.pattern_registry import (
+    PatternMatch,
+    register_pattern_detector,
+    register_pattern_replacement,
+)
+
+from executorch.exir import ExportedProgram
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from torch.export.graph_signature import InputKind
+
+
+class QuantizedConvolutionMatch(PatternMatch):
+    def __init__(self, conv_node: torch.fx.Node) -> None:
+        self.anchor_node = conv_node
+        self.match_found = False
+        self.all_nodes = [self.anchor_node]
+
+        # Extract convolution parameters
+        self.stride = conv_node.args[3] if len(conv_node.args) > 3 else [1, 1]
+        self.padding = conv_node.args[4] if len(conv_node.args) > 4 else [0, 0]
+        self.dilation = conv_node.args[5] if len(conv_node.args) > 5 else [1, 1]
+        self.groups = conv_node.args[8] if len(conv_node.args) > 8 else 1
+
+        const_node, arg_chain = utils.trace_args_until_placeholder(
+            self.anchor_node.args[1]
+        )
+
+        # weight is not a constant tensor - no match
+        if const_node is None:
+            return
+
+        dequantize_weight_node = None
+        # Search for a dequantize node in the arg chain of the weight
+        for node in arg_chain:
+            if isinstance(node, torch.fx.Node) and utils.is_dequant_node(node):
+                dequantize_weight_node = node
+        # weight is not quantized - no match
+        if dequantize_weight_node is None:
+            return
+
+        self.weight_node = const_node
+        self.dequantize_weight_node = dequantize_weight_node
+        self.all_nodes.extend(arg_chain)
+
+        # Identify weight quantization parameter nodes
+        self.weight_scales_node, arg_chain = utils.trace_args_until_placeholder(
+            self.dequantize_weight_node.args[1]
+        )
+        assert self.weight_scales_node is not None
+        self.all_nodes.extend(arg_chain)
+
+        self.weight_zeros_node, arg_chain = utils.trace_args_until_placeholder(
+            self.dequantize_weight_node.args[2]
+        )
+        assert self.weight_zeros_node is not None
+        self.all_nodes.extend(arg_chain)
+
+        # Identify output node
+        self.output_node = self.anchor_node
+
+        # Identify bias node, if applicable
+        self.bias_node = None
+        if len(self.anchor_node.args) > 2 and self.anchor_node.args[2] is not None:
+            self.bias_node, arg_chain = utils.trace_args_until_placeholder(
+                self.anchor_node.args[2]
+            )
+            if self.bias_node is not None:
+                self.all_nodes.extend(arg_chain)
+
+        # Identify input node
+        self.fp_input_node, self.quantize_input_node, dq_node = (
+            utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0])
+        )
+        assert self.fp_input_node is not None
+        self.all_nodes.append(self.fp_input_node)
+        assert self.quantize_input_node is not None
+        assert dq_node is not None
+
+        self.input_scales_node = self.quantize_input_node.args[1]
+        self.input_zeros_node = self.quantize_input_node.args[2]
+
+        self.all_nodes.extend(
+            [
+                self.quantize_input_node,
+                dq_node,
+            ]
+        )
+
+        self.match_found = True
+
+
+convolution_anchor_nodes = {
+    exir_ops.edge.aten.conv2d.default,
+    exir_ops.edge.aten.convolution.default,
+}
+
+
+@register_pattern_detector("quantized_convolution")
+def find_quantized_convolution_patterns(
+    node: torch.fx.Node,
+) -> Optional[QuantizedConvolutionMatch]:
+    if node.target not in convolution_anchor_nodes:
+        return None
+
+    matched_pattern = QuantizedConvolutionMatch(node)
+    if matched_pattern.match_found:
+        return matched_pattern
+
+    return None
+
+
+##
+## Pattern Replacement
+##
+
+
+@register_pattern_replacement("quantized_convolution")
+def make_conv2d_q8ta_q8csw_custom_op(
+    ep: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    match: QuantizedConvolutionMatch,
+):
+    weight_tensor = get_param_tensor(ep, match.weight_node)
+    assert weight_tensor is not None
+
+    assert match.weight_scales_node is not None
+    weight_scales_tensor = get_param_tensor(ep, match.weight_scales_node)
+    assert weight_scales_tensor is not None
+
+    assert match.weight_zeros_node is not None
+    weight_zeros_tensor = get_param_tensor(ep, match.weight_zeros_node)
+    assert weight_zeros_tensor is not None
+
+    # Reshape weight tensor from (OC, IC, H, W) to (IC * H * W, OC) for matrix multiplication
+    # This prepares the weights for Im2Col-based convolution computation
+    OC, IC, H, W = weight_tensor.shape
+
+    weight_tensor_reshaped = (
+        weight_tensor.permute(2, 3, 1, 0).contiguous().view(IC * H * W, OC)
+    )
+    utils.update_program_state_dict(ep, match.weight_node.name, weight_tensor_reshaped)
+    # Need to make sure the fake tensor matches the updated tensor's properties
+    match.weight_node.meta["val"] = (
+        match.weight_node.meta["val"]
+        .permute(1, 2, 3, 0)
+        .contiguous()
+        .view(IC * H * W, OC)
+    )
+
+    first_graph_node = list(graph_module.graph.nodes)[0]
+    with graph_module.graph.inserting_before(first_graph_node):
+        qweight_tensor_name = utils.get_tensor_name(ep, match.weight_node)
+        # Pre-compute the weight sums which are needed to apply activation zero point
+        # when using integer accumulation. For the reshaped 2D weight matrix (IC * H * W, OC),
+        # sum over dimension 0 to get sums per output channel
+        sum_per_output_channel = (
+            weight_tensor_reshaped.sum(dim=0).to(torch.float).contiguous()
+        )
+        sums_name = qweight_tensor_name + "_sums"
+        weight_sums_node = create_constant_placeholder(
+            exp_program=ep,
+            graph=graph_module.graph,
+            kind=InputKind.CONSTANT_TENSOR,
+            name=sums_name,
+            data=sum_per_output_channel,
+        )
+
+    with graph_module.graph.inserting_before(match.output_node):
+        qconv_node = graph_module.graph.create_node(
+            "call_function",
+            exir_ops.edge.et_vk.conv2d_q8ta_q8csw.default,
+            args=(
+                match.fp_input_node,
+                match.input_scales_node,
+                match.input_zeros_node,
+                match.weight_node,
+                weight_sums_node,
+                match.weight_scales_node,
+                match.bias_node,  # bias follows weight_scales
+                [H, W],  # kernel size is passed before stride
+                match.stride,
+                match.padding,
+                match.dilation,
+                match.groups,
+            ),
+        )
+
+        qconv_node.meta["val"] = match.output_node.meta["val"]
+        match.output_node.replace_all_uses_with(qconv_node)
diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py
index 34476adeeb4..c45d8bbf155 100644
--- a/backends/vulkan/patterns/quantized_linear.py
+++ b/backends/vulkan/patterns/quantized_linear.py
@@ -4,131 +4,145 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from functools import lru_cache -from typing import Callable, List, Optional +from typing import Optional import executorch.backends.vulkan.utils as utils import torch import torch.nn.functional as F -from executorch.backends.transforms.utils import get_param_tensor, is_param_node +from backends.vulkan.utils import trace_args_until_placeholder + +from executorch.backends.transforms.utils import ( + create_constant_placeholder, + get_param_tensor, +) from executorch.backends.vulkan.patterns.pattern_registry import ( - register_pattern_graph, + PatternMatch, + register_pattern_detector, register_pattern_replacement, ) -from executorch.exir import EdgeCompileConfig, ExportedProgram, to_edge +from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops -from torch.export import export -from torch.fx.passes.utils.matcher_utils import InternalMatch +from torch.export.graph_signature import InputKind -from torchao.quantization.granularity import PerGroup -from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ -from torchao.utils import unwrap_tensor_subclass +class QuantizedLinearMatch(PatternMatch): + def __init__(self, mm_node: torch.fx.Node) -> None: + self.anchor_node = mm_node + self.match_found = False + self.all_nodes = [self.anchor_node] -class TorchAOWeightOnlyQuantizedLinearPattern(torch.nn.Module): - """ - Quantized linear pattern produced when quantizing linear layers using - `torchao.quantization.quant_api.quantize_()` with IntxWeightOnlyConfig. - """ + const_node, arg_chain = utils.trace_args_until_placeholder( + self.anchor_node.args[1] + ) + + # mat2 is not a constant tensor - no match + if const_node is None: + return + + dequantize_weight_node = None + # Search for a dequantize node in the arg chain of weight + for node in arg_chain: + if isinstance(node, torch.fx.Node) and utils.is_dequant_node(node): + dequantize_weight_node = node + # weight is not quantized - no match + if dequantize_weight_node is None: + return + + self.weight_node = const_node + self.dequantize_weight_node = dequantize_weight_node + self.all_nodes.extend(arg_chain) + + # By default, assume dequant node is from quantized_decomposed namespace + scales_arg_idx = 1 + zeros_arg_idx = 2 + # torchao dequantize has a different function schema than quantized_decomposed + if ( + self.dequantize_weight_node.target + == exir_ops.edge.torchao.dequantize_affine.default + ): + scales_arg_idx = 2 + zeros_arg_idx = 3 + + # Identify weight quantization parameter nodes + self.weight_scales_node, arg_chain = utils.trace_args_until_placeholder( + self.dequantize_weight_node.args[scales_arg_idx] + ) + assert self.weight_scales_node is not None + self.all_nodes.extend(arg_chain) - def __init__( - self, - in_features: int = 512, - out_features: int = 256, - bias: bool = False, - group_size: int = 64, - weight_bits: int = 4, - granularity_class: Optional[Callable] = None, - ) -> None: - super().__init__() - self.linear = torch.nn.Linear(in_features, out_features, bias=bias) - self.group_size = group_size - self.weight_bits = weight_bits - - if self.weight_bits == 4: - # pyre-ignore[16] - self.weight_dtype = torch.int4 - else: - self.weight_dtype = torch.int8 - - if granularity_class is not None: - self.quant_granularity = granularity_class(self.group_size) - else: - self.quant_granularity = PerGroup(self.group_size) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear(x) - - def apply_quantization(self): - q_config = IntxWeightOnlyConfig( - 
weight_dtype=self.weight_dtype, - granularity=self.quant_granularity, + self.weight_zeros_node, arg_chain = utils.trace_args_until_placeholder( + self.dequantize_weight_node.args[zeros_arg_idx] ) - quantize_(self, q_config) - unwrap_tensor_subclass(self) - return self - - -@lru_cache(maxsize=None) -@register_pattern_graph("torchao_wo_quantized_linear") -def get_torchao_wo_quantized_linear_graphs() -> List[torch.fx.GraphModule]: - graphs = [] - - # Different configurations to test - configs = [ - # gemv pattern - (1, 1, 128, 128, False, 64, 4, PerGroup), - # gemm pattern - (1, 8, 128, 128, False, 64, 4, PerGroup), - ] - - for ( - batch_size, - seq_len, - in_features, - out_features, - bias, - group_size, - weight_bits, - granularity_class, - ) in configs: - for dtype in [torch.float32]: - xs = [] - xs.append(torch.randn(batch_size, seq_len, in_features, dtype=dtype)) - if batch_size == 1: - xs.append(torch.randn(seq_len, in_features, dtype=dtype)) - - for x in xs: - # Create and quantize the pattern - pattern = TorchAOWeightOnlyQuantizedLinearPattern( - in_features=in_features, - out_features=out_features, - bias=bias, - group_size=group_size, - weight_bits=weight_bits, - granularity_class=granularity_class, - ) - - # Apply quantization - pattern = pattern.apply_quantization() - - # Export the quantized pattern - edge = to_edge( - export( - pattern, - (x,), - ), - compile_config=EdgeCompileConfig(_check_ir_validity=False), - ) - gm = edge.exported_program().graph_module - graphs.append(gm) - - return graphs + assert self.weight_zeros_node is not None + self.all_nodes.extend(arg_chain) + + # Identify output node + self.output_node = self.anchor_node + + # Identify input node + self.fp_input_node, self.quantize_input_node, dq_node = ( + utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0]) + ) + assert self.fp_input_node is not None + self.all_nodes.append(self.fp_input_node) + + # Identify bias node, if applicable + self.bias_node = None + if self.anchor_node.target == exir_ops.edge.aten.addmm.default: + self.bias_node, arg_chain = trace_args_until_placeholder( + self.anchor_node.args[2] + ) + assert self.bias_node is not None + self.all_nodes.extend(arg_chain) + + # If input is not quantized, then we are done + if self.quantize_input_node is None: + self.match_found = True + return + + self.input_scales_node = self.quantize_input_node.args[1] + self.input_zeros_node = self.quantize_input_node.args[2] + + assert dq_node is not None + self.all_nodes.extend( + [ + self.quantize_input_node, + dq_node, + ] + ) + + self.match_found = True + + +linear_anchor_nodes = { + exir_ops.edge.aten.linear.default, + exir_ops.edge.aten.mm.default, + exir_ops.edge.aten.addmm.default, +} + + +@register_pattern_detector("quantized_linear") +def find_quantized_linear_patterns( + node: torch.fx.Node, +) -> Optional[QuantizedLinearMatch]: + if node.target not in linear_anchor_nodes: + return None + + matched_pattern = QuantizedLinearMatch(node) + if matched_pattern.match_found: + return matched_pattern + + return None + + +## +## Constant tensor manipulation +## def pack_4bit_weight_tensor(inp: torch.Tensor) -> torch.Tensor: @@ -192,117 +206,132 @@ def make_combined_scales_and_zeros_tensor( return torch.cat((scales_reshaped, zeros_scaled), dim=2) -def identify_wo_quantized_linear_io_nodes( # noqa: C901 - ep: ExportedProgram, - graph_module: torch.fx.GraphModule, - match: InternalMatch, -) -> Optional[List[torch.fx.Node]]: - dequant_node = None - # First, find the dequant node - for node in 
match.nodes_map.values(): - if utils.is_dequant_node(node): - dequant_node = node - break - - if dequant_node is None: - return None - - quantized_weight = dequant_node.args[0] - quant_scales = dequant_node.args[2] - quant_zeros = dequant_node.args[3] +## +## Pattern Replacement +## - if not isinstance(quantized_weight, torch.fx.Node) or not is_param_node( - ep, quantized_weight - ): - return None - if not isinstance(quant_scales, torch.fx.Node) or not is_param_node( - ep, quant_scales - ): - return None - if not isinstance(quant_zeros, torch.fx.Node) or not is_param_node(ep, quant_zeros): - return None - - input_nodes = match.placeholder_nodes - if len(input_nodes) != 4: - return None - in_tensor_node = None - for node in input_nodes: - if node not in dequant_node.args: - in_tensor_node = node - break - - if in_tensor_node is None: - return None - - output_nodes = match.returning_nodes - - if len(output_nodes) != 1: - return None - - out_tensor_node = output_nodes[0] - if not isinstance(out_tensor_node, torch.fx.Node): - return None - - return [ - in_tensor_node, - quantized_weight, - quant_scales, - quant_zeros, - out_tensor_node, - ] - - -# wo = "weight only" -@register_pattern_replacement("torchao_wo_quantized_linear") -def create_wo_quantized_linear_custom_op( +def make_linear_q4ga_op( ep: ExportedProgram, graph_module: torch.fx.GraphModule, - match: InternalMatch, + match: QuantizedLinearMatch, ): - io_nodes = identify_wo_quantized_linear_io_nodes(ep, graph_module, match) - if io_nodes is None: - return + weight_tensor = get_param_tensor(ep, match.weight_node) + assert weight_tensor is not None - assert len(io_nodes) == 5 - in_tensor, quantized_weight, quant_scales, quant_zeros, out_tensor = io_nodes + assert match.weight_scales_node is not None + weight_scales_tensor = get_param_tensor(ep, match.weight_scales_node) + assert weight_scales_tensor is not None - quantized_weight_tensor = get_param_tensor(ep, quantized_weight) - if not isinstance(quantized_weight_tensor, torch.Tensor): - return - packed_quantized_weight_tensor = pack_4bit_weight_tensor(quantized_weight_tensor) + assert match.weight_zeros_node is not None + weight_zeros_tensor = get_param_tensor(ep, match.weight_zeros_node) + assert weight_zeros_tensor is not None + + packed_quantized_weight_tensor = pack_4bit_weight_tensor(weight_tensor) utils.update_program_state_dict( - ep, quantized_weight.name, packed_quantized_weight_tensor + ep, match.weight_node.name, packed_quantized_weight_tensor + ) + # Need to make sure corresponding FakeTensor has same size + match.weight_node.meta["val"] = match.weight_node.meta["val"][:, ::2].to( + torch.uint8 ) - quantized_weight.meta["val"] = quantized_weight.meta["val"][:, ::2].to(torch.uint8) - - quant_scales_tensor = get_param_tensor(ep, quant_scales) - quant_zeros_tensor = get_param_tensor(ep, quant_zeros) - - assert quantized_weight_tensor is not None - assert quant_scales_tensor is not None - assert quant_zeros_tensor is not None - group_size = quantized_weight_tensor.shape[1] // quant_scales_tensor.shape[1] + group_size = weight_tensor.shape[1] // weight_scales_tensor.shape[1] combined_scales_zeros_tensor = make_combined_scales_and_zeros_tensor( - quant_scales_tensor, quant_zeros_tensor + weight_scales_tensor, weight_zeros_tensor ) - combined_scales_zeros_name = f"{quantized_weight.name}_scales_zeros" + combined_scales_zeros_name = f"{match.weight_node.name}_scales_zeros" graph_module.register_parameter( combined_scales_zeros_name, torch.nn.Parameter(combined_scales_zeros_tensor) ) 
- with graph_module.graph.inserting_before(out_tensor): + with graph_module.graph.inserting_before(match.output_node): combined_scales_zeros = graph_module.graph.get_attr(combined_scales_zeros_name) - wo_qlinear = graph_module.graph.create_node( + linear_q4ga_node = graph_module.graph.create_node( "call_function", exir_ops.edge.et_vk.linear_weight_int4.default, - args=(in_tensor, quantized_weight, group_size, combined_scales_zeros, 1), + args=( + match.fp_input_node, + match.weight_node, + group_size, + combined_scales_zeros, + 1, + ), ) - if hasattr(out_tensor, "meta") and "val" in out_tensor.meta: - wo_qlinear.meta["val"] = out_tensor.meta["val"] + linear_q4ga_node.meta["val"] = match.output_node.meta["val"] + match.output_node.replace_all_uses_with(linear_q4ga_node) + - out_tensor.replace_all_uses_with(wo_qlinear) +def make_linear_q8ta_q8csw_custom_op( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: QuantizedLinearMatch, +): + weight_tensor = get_param_tensor(ep, match.weight_node) + assert weight_tensor is not None + + assert match.weight_scales_node is not None + weight_scales_tensor = get_param_tensor(ep, match.weight_scales_node) + assert weight_scales_tensor is not None + + assert match.weight_zeros_node is not None + weight_zeros_tensor = get_param_tensor(ep, match.weight_zeros_node) + assert weight_zeros_tensor is not None + + # Transpose the weight matrix + weight_transposed = weight_tensor.transpose(0, 1).contiguous() + utils.update_program_state_dict(ep, match.weight_node.name, weight_transposed) + weight_tensor = weight_transposed + # Need to make sure the fake tensor matches the updated tensor's properties + match.weight_node.meta["val"] = match.weight_node.meta["val"].transpose(0, 1) + + first_graph_node = list(graph_module.graph.nodes)[0] + with graph_module.graph.inserting_before(first_graph_node): + qweight_tensor_name = utils.get_tensor_name(ep, match.weight_node) + # Pre-compute the weight sums which are needed to apply activation zero point + # when using integer accumulation. 
+        sum_per_output_channel = (
+            weight_transposed.sum(dim=0).to(torch.float).contiguous()
+        )
+        sums_name = qweight_tensor_name + "_sums"
+        weight_sums_node = create_constant_placeholder(
+            exp_program=ep,
+            graph=graph_module.graph,
+            kind=InputKind.CONSTANT_TENSOR,
+            name=sums_name,
+            data=sum_per_output_channel,
+        )
+
+    with graph_module.graph.inserting_before(match.output_node):
+        qlinear_node = graph_module.graph.create_node(
+            "call_function",
+            exir_ops.edge.et_vk.linear_q8ta_q8csw.default,
+            args=(
+                match.fp_input_node,
+                match.input_scales_node,
+                match.input_zeros_node,
+                match.weight_node,
+                weight_sums_node,
+                match.weight_scales_node,
+            ),
+        )
+
+        qlinear_node.meta["val"] = match.output_node.meta["val"]
+        match.output_node.replace_all_uses_with(qlinear_node)
+
+
+@register_pattern_replacement("quantized_linear")
+def replace_quantized_linear_patterns(
+    ep: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    match: QuantizedLinearMatch,
+):
+    if match.quantize_input_node is None:
+        make_linear_q4ga_op(ep, graph_module, match)
+    else:
+        make_linear_q8ta_q8csw_custom_op(ep, graph_module, match)
diff --git a/backends/vulkan/patterns/rope.py b/backends/vulkan/patterns/rope.py
index e0c2e4c5501..b174224ab78 100644
--- a/backends/vulkan/patterns/rope.py
+++ b/backends/vulkan/patterns/rope.py
@@ -12,6 +12,7 @@
 import torch
 from executorch.backends.vulkan.patterns.pattern_registry import (
+    PatternMatch,
     register_pattern_graph,
     register_pattern_replacement,
 )
@@ -20,7 +21,6 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch.export import export
-from torch.fx.passes.utils.matcher_utils import InternalMatch
 class RotaryEmbeddingPattern(torch.nn.Module):
@@ -111,16 +111,16 @@ def get_rope_graphs() -> List[torch.fx.GraphModule]:
 def identify_rotary_emb_io_nodes(
     ep: ExportedProgram,
     graph_module: torch.fx.GraphModule,
-    match: InternalMatch,
+    match: PatternMatch,
 ) -> Optional[List[torch.fx.Node]]:
-    # Get the input placeholders (xq, xk, freqs_cos, freqs_sin)
-    placeholder_nodes = match.placeholder_nodes
-    if len(placeholder_nodes) != 4:
+    # Get the input nodes (xq, xk, freqs_cos, freqs_sin)
+    input_nodes = match.input_nodes
+    if len(input_nodes) != 4:
         return None
-    xq, xk, freqs_cos, freqs_sin = placeholder_nodes
+    xq, xk, freqs_cos, freqs_sin = input_nodes
-    output_nodes = match.returning_nodes
+    output_nodes = match.output_nodes
     if len(output_nodes) != 2:
         return None
@@ -133,7 +133,7 @@ def identify_rotary_emb_io_nodes(
 def create_rotary_emb_custom_op(
     ep: ExportedProgram,
     graph_module: torch.fx.GraphModule,
-    match: InternalMatch,
+    match: PatternMatch,
 ):
     io_nodes = identify_rotary_emb_io_nodes(ep, graph_module, match)
     if io_nodes is None:
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index 687a8761c6b..858a1595b52 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -10,6 +10,8 @@
 import unittest
 from typing import Tuple
+import executorch.backends.vulkan.test.utils as test_utils
+
 import torch
 from executorch.backends.transforms.convert_dtype_pass import I64toI32
@@ -18,12 +20,23 @@
 from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
+
 from executorch.exir import (
     EdgeCompileConfig,
     EdgeProgramManager,
     ExecutorchProgramManager,
+    to_edge_transform_and_lower,
+)
+from executorch.extension.pybindings.portable_lib import (  # @manual
+    _load_for_executorch_from_buffer,
+)
-from torch.export import Dim, export, export_for_training, ExportedProgram
+from executorch.extension.pytree import tree_flatten
+from torch.export import Dim, export, ExportedProgram
+
 from torchao.quantization.granularity import PerGroup
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
@@ -32,14 +45,10 @@
 from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
 from torchao.utils import unwrap_tensor_subclass
-ctypes.CDLL("libvulkan.so.1")
-
-
-from executorch.exir import to_edge_transform_and_lower
-from executorch.extension.pybindings.portable_lib import (  # @manual
-    _load_for_executorch_from_buffer,
-)
-from executorch.extension.pytree import tree_flatten
+try:
+    ctypes.CDLL("libvulkan.so.1")
+except OSError:
+    pass
 def lower_module(
@@ -83,7 +92,7 @@ def quantize_and_lower_module(
         _skip_dim_order=False,  # TODO(T182928844): Delegate dim order op to backend.
     )
-    program = export_for_training(
+    program = export(
        model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
     ).module()
@@ -129,37 +146,47 @@ def assert_outputs_equal(
             # Multiple outputs executor always returns tuple, even if there is one output
             self.assertTrue(len(ref_output) == len(model_output))
             if first_output_only:
-                self.assertTrue(
-                    torch.allclose(
-                        model_output[0],
-                        ref_output[0],
+                result = torch.allclose(
+                    model_output[0],
+                    ref_output[0],
+                    atol=atol,
+                    rtol=rtol,
+                    equal_nan=equal_nan,
+                )
+                if not result:
+                    test_utils.print_tensor_comparison_errors(
+                        model_output[0], ref_output[0], atol, rtol
+                    )
+                self.assertTrue(result)
+            else:
+                for i in range(len(ref_output)):
+                    result = torch.allclose(
+                        model_output[i],
+                        ref_output[i],
                         atol=atol,
                         rtol=rtol,
                         equal_nan=equal_nan,
                     )
-                )
-            else:
-                for i in range(len(ref_output)):
-                    self.assertTrue(
-                        torch.allclose(
-                            model_output[i],
-                            ref_output[i],
-                            atol=atol,
-                            rtol=rtol,
-                            equal_nan=equal_nan,
+                    if not result:
+                        print(f"\n=== Output {i} comparison failed ===")
+                        test_utils.print_tensor_comparison_errors(
+                            model_output[i], ref_output[i], atol, rtol
                         )
-                    )
+                    self.assertTrue(result)
         else:
             # If one output, eager returns tensor while executor tuple of size 1
-            self.assertTrue(
-                torch.allclose(
-                    model_output[0],
-                    ref_output,
-                    atol=atol,
-                    rtol=rtol,
-                    equal_nan=equal_nan,
-                )
+            result = torch.allclose(
+                model_output[0],
+                ref_output,
+                atol=atol,
+                rtol=rtol,
+                equal_nan=equal_nan,
             )
+            if not result:
+                test_utils.print_tensor_comparison_errors(
+                    model_output[0], ref_output, atol, rtol
+                )
+            self.assertTrue(result)
     def check_no_delegation(self, et_program: ExecutorchProgramManager):
         self.assertEqual(
@@ -2388,3 +2420,294 @@ def apply_quantization(self):
         self.lower_module_and_test_output(
             quantized_linear_module_gemm, sample_inputs_gemm, atol=1e-2, rtol=1e-2
         )
+
+    def test_vulkan_backend_xnnpack_pt2e_quantized_linear_sequence(self):
+        """
+        Test a sequence of linear layers quantized with the XNNPACK quantization config.
+        This test creates a module with multiple linear layers in sequence and applies
+        XNNPACK symmetric quantization to test the quantized model execution.
+        """
+
+        class LinearSequenceModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear1 = torch.nn.Linear(128, 64, bias=False)
+                self.linear2 = torch.nn.Linear(64, 32, bias=False)
+                self.linear3 = torch.nn.Linear(32, 16, bias=False)
+
+                MAX = 0.75
+                MIN = -0.25
+                self.linear1.weight.data = test_utils.random_uniform_tensor(
+                    self.linear1.weight.shape, MIN, MAX
+                )
+                self.linear2.weight.data = test_utils.random_uniform_tensor(
+                    self.linear2.weight.shape, MIN, MAX
+                )
+                self.linear3.weight.data = test_utils.random_uniform_tensor(
+                    self.linear3.weight.shape, MIN, MAX
+                )
+
+            def forward(self, x):
+                x = self.linear1(x)
+                x = self.linear2(x)
+                x = self.linear3(x)
+                return x
+
+        # Create the module
+        linear_sequence_module = LinearSequenceModule()
+
+        M = 32
+        # Create sample inputs
+        sample_inputs = (
+            (
+                test_utils.random_uniform_tensor(
+                    (M, linear_sequence_module.linear1.in_features),
+                    -0.25,
+                    0.75,
+                )
+            ),
+        )
+
+        # Create XNNPACK quantizer with symmetric quantization config
+        quantizer = XNNPACKQuantizer()
+        operator_config = get_symmetric_quantization_config(
+            is_per_channel=True,
+            is_dynamic=False,
+        )
+        quantizer.set_global(operator_config)
+
+        # Test the quantized module using the existing quantize_and_lower_module function
+        # Use higher tolerance since quantization introduces some error
+        edge_program = quantize_and_lower_module(
+            linear_sequence_module, sample_inputs, quantizer
+        )
+
+        et_program = edge_program.to_executorch()
+        self.check_vk_delegation(et_program)
+
+        self.run_delegated_model_and_check_output(
+            et_program,
+            linear_sequence_module,
+            sample_inputs,
+            atol=1e-2,
+            rtol=1e-1,
+        )
+
+    def test_vulkan_backend_8da4w_quantized_linear_sequence(self):
+        """
+        Test a sequence of linear layers quantized with the 8da4w method.
+        This test creates a module with multiple linear layers in sequence and applies
+        8da4w (8-bit dynamic activation, 4-bit weight) quantization to test the quantized
+        model execution using the Vulkan backend.
+        """
+        from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+        from torchao.utils import unwrap_tensor_subclass
+
+        class LinearSequenceModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear1 = torch.nn.Linear(8, 4, bias=False)
+                self.linear2 = torch.nn.Linear(128, 64, bias=False)
+                self.linear3 = torch.nn.Linear(64, 32, bias=False)
+
+            def forward(self, x):
+                x = self.linear1(x)
+                # x = self.linear2(x)
+                # x = self.linear3(x)
+                return x
+
+        # Create the module
+        linear_sequence_module = LinearSequenceModule()
+
+        # Apply 8da4w quantization with group_size=128 (default from quantize.py)
+        group_size = 128
+        quantize_(
+            linear_sequence_module,
+            int8_dynamic_activation_int4_weight(group_size=group_size),
+        )
+        linear_sequence_module = unwrap_tensor_subclass(linear_sequence_module)
+
+        # Create sample inputs
+        sample_inputs = (torch.randn(size=(4, 8), dtype=torch.float32),)
+
+        # Test the quantized module
+        # Use higher tolerance since quantization introduces some error
+        self.lower_module_and_test_output(
+            linear_sequence_module,
+            sample_inputs,
+            atol=1e-1,
+            rtol=1e-1,
+        )
+
+    def test_vulkan_backend_xnnpack_pt2e_quantized_conv_sequence(self):
+        """
+        Test a sequence of convolution layers quantized with PT2E quantization.
+        This test creates a module with multiple Conv2d layers in sequence and applies
+        XNNPACK symmetric quantization to test the quantized model execution.
+        Similar to the linear sequence test but using convolution layers.
+        """
+
+        class ConvSequenceModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(
+                    in_channels=3,
+                    out_channels=16,
+                    kernel_size=3,
+                    padding=1,
+                    bias=False,
+                )
+                self.conv2 = torch.nn.Conv2d(
+                    in_channels=16,
+                    out_channels=32,
+                    kernel_size=3,
+                    padding=1,
+                    bias=False,
+                )
+                self.conv3 = torch.nn.Conv2d(
+                    in_channels=32,
+                    out_channels=64,
+                    kernel_size=3,
+                    padding=1,
+                    bias=False,
+                )
+
+                MAX = 0.75
+                MIN = -0.25
+                self.conv1.weight.data = test_utils.random_uniform_tensor(
+                    self.conv1.weight.shape, MIN, MAX
+                )
+                self.conv2.weight.data = test_utils.random_uniform_tensor(
+                    self.conv2.weight.shape, MIN, MAX
+                )
+                self.conv3.weight.data = test_utils.random_uniform_tensor(
+                    self.conv3.weight.shape, MIN, MAX
+                )
+
+            def forward(self, x):
+                x = self.conv1(x)
+                x = self.conv2(x)
+                x = self.conv3(x)
+                return x
+
+        # Create the module
+        conv_sequence_module = ConvSequenceModule()
+
+        input_tensor = test_utils.random_uniform_tensor(
+            (1, 3, 32, 32),
+            -0.25,
+            0.75,
+        )
+
+        # Create sample inputs
+        sample_inputs = (input_tensor,)
+
+        # Create XNNPACK quantizer with symmetric quantization config
+        quantizer = XNNPACKQuantizer()
+        operator_config = get_symmetric_quantization_config(
+            is_per_channel=True,
+            is_dynamic=False,
+        )
+        quantizer.set_global(operator_config)
+
+        # Test the quantized module using the existing quantize_and_lower_module function
+        # Use higher tolerance since quantization introduces some error
+        edge_program = quantize_and_lower_module(
+            conv_sequence_module, sample_inputs, quantizer
+        )
+
+        et_program = edge_program.to_executorch()
+        self.check_vk_delegation(et_program)
+
+        self.run_delegated_model_and_check_output(
+            et_program,
+            conv_sequence_module,
+            sample_inputs,
+            atol=1e-2,
+            rtol=1e-1,
+        )
+
+    def test_vulkan_backend_xnnpack_pt2e_quantized_conv_sequence_easy(self):
+        """
+        Reduced version of the quantized conv sequence test above that uses a single
+        small Conv2d layer.
+        """
+
+        class ConvSequenceModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(
+                    in_channels=4,
+                    out_channels=4,
+                    kernel_size=3,
+                    padding=1,
+                    bias=False,
+                )
+
+                MAX = 0.75
+                MIN = -0.25
+                self.conv1.weight.data = test_utils.random_uniform_tensor(
+                    self.conv1.weight.shape, MIN, MAX
+                )
+
+            def forward(self, x):
+                x = self.conv1(x)
+                return x
+
+        # Create the module
+        conv_sequence_module = ConvSequenceModule()
+
+        input_tensor_shape = (1, 4, 5, 5)
+        input_tensor = test_utils.random_uniform_tensor(
+            input_tensor_shape,
+            -0.25,
+            0.75,
+        )
+
+        # Create sample inputs
+        sample_inputs = (input_tensor,)
+
+        # Create XNNPACK quantizer with symmetric quantization config
+        quantizer = XNNPACKQuantizer()
+        operator_config = get_symmetric_quantization_config(
+            is_per_channel=True,
+            is_dynamic=False,
+        )
+        quantizer.set_global(operator_config)
+
+        # Test the quantized module using the existing quantize_and_lower_module function
+        # Use higher tolerance since quantization introduces some error
+        edge_program = quantize_and_lower_module(
+            conv_sequence_module, sample_inputs, quantizer
+        )
+
+        et_program = edge_program.to_executorch()
+
+        self.check_vk_delegation(et_program)
+
+        self.run_delegated_model_and_check_output(
+            et_program,
+            conv_sequence_module,
+            sample_inputs,
+            atol=1e-2,
+            rtol=1e-1,
+        )
diff --git a/backends/vulkan/test/utils.py b/backends/vulkan/test/utils.py
index 0e9ea6bc9d8..0a6884bfd6b 100644
--- a/backends/vulkan/test/utils.py
+++ b/backends/vulkan/test/utils.py
@@ -30,6 +30,13 @@
 from torch.export import export, export_for_training
+def random_uniform_tensor(shape, low=0.0, high=1.0, device=None, dtype=None):
+    if dtype is None:
+        dtype = torch.float32
+
+    return torch.empty(shape, device=device, dtype=dtype).uniform_(low, high)
+
+
 def export_model_to_vulkan(
     model,
     sample_inputs,
@@ -108,6 +115,74 @@ def export_model_to_xnnpack(model, sample_inputs, dynamic_shapes=None):
     return executorch_program
+def print_tensor_comparison_errors(
+    tensor1, tensor2, atol=1e-03, rtol=1e-03, max_errors=10
+):
+    """
+    Print the first max_errors tensor indices that exceed the absolute/relative tolerance
+    and the error at each of those locations.
+ + Args: + tensor1: First tensor to compare + tensor2: Second tensor to compare + atol: Absolute tolerance + rtol: Relative tolerance + max_errors: Maximum number of errors to print (default: 10) + """ + # Handle lists/tuples of tensors + if isinstance(tensor1, (list, tuple)) and isinstance(tensor2, (list, tuple)): + if len(tensor1) != len(tensor2): + print(f"Tensor count mismatch: {len(tensor1)} vs {len(tensor2)}") + return + + for i, (t1, t2) in enumerate(zip(tensor1, tensor2)): + print(f"\n=== Tensor {i} comparison ===") + print_tensor_comparison_errors(t1, t2, atol, rtol, max_errors) + return + + # Handle single tensor comparison + if not isinstance(tensor1, torch.Tensor) or not isinstance(tensor2, torch.Tensor): + print("Error: Both inputs must be torch.Tensor objects") + return + + if tensor1.shape != tensor2.shape: + print(f"Shape mismatch: {tensor1.shape} vs {tensor2.shape}") + return + + # Calculate absolute and relative errors + abs_diff = torch.abs(tensor1 - tensor2) + rel_diff = abs_diff / ( + torch.abs(tensor2) + 1e-8 + ) # Add small epsilon to avoid division by zero + + # Find locations where tolerance is exceeded + tolerance_mask = (abs_diff > atol) & (rel_diff > rtol) + + if not tolerance_mask.any(): + print("All values are within tolerance") + return + + # Get indices where tolerance is exceeded + error_indices = torch.nonzero(tolerance_mask, as_tuple=False) + total_errors = error_indices.shape[0] + + print(f"Found {total_errors} values exceeding tolerance (atol={atol}, rtol={rtol})") + print(f"Showing first {min(max_errors, total_errors)} errors:") + print("Index -> tensor1_value, tensor2_value, abs_error, rel_error") + + # Print first max_errors locations + for i in range(min(max_errors, total_errors)): + idx = tuple(error_indices[i].tolist()) + val1 = tensor1[idx].item() + val2 = tensor2[idx].item() + abs_err = abs_diff[idx].item() + rel_err = rel_diff[idx].item() + + print( + f"{idx} -> {val1:.6f}, {val2:.6f}, abs_err={abs_err:.6f}, rel_err={rel_err:.6f}" + ) + + def check_outputs_equal( model_output, ref_output, atol=1e-03, rtol=1e-03, first_output_only=False ): @@ -123,19 +198,34 @@ def check_outputs_equal( if isinstance(ref_output, tuple) or isinstance(ref_output, list): # Multiple outputs executor always returns tuple, even if there is one output if len(ref_output) != len(model_output): + print_tensor_comparison_errors(model_output, ref_output, atol, rtol) return False if first_output_only: - return torch.allclose(model_output[0], ref_output[0], atol=atol, rtol=rtol) + result = torch.allclose( + model_output[0], ref_output[0], atol=atol, rtol=rtol + ) + if not result: + print_tensor_comparison_errors( + model_output[0], ref_output[0], atol, rtol + ) + return result else: for i in range(len(ref_output)): if not torch.allclose( model_output[i], ref_output[i], atol=atol, rtol=rtol ): + print(f"\n=== Output {i} comparison failed ===") + print_tensor_comparison_errors( + model_output[i], ref_output[i], atol, rtol + ) return False return True else: # If one output, eager returns tensor while executor tuple of size 1 - return torch.allclose(model_output[0], ref_output, atol=atol, rtol=rtol) + result = torch.allclose(model_output[0], ref_output, atol=atol, rtol=rtol) + if not result: + print_tensor_comparison_errors(model_output[0], ref_output, atol, rtol) + return result def run_and_check_output( diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 3b3e27acfbd..6c4737a9bc7 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -22,7 
+22,7 @@ from executorch.exir.tensor import TensorSpec -from torch._export.utils import is_buffer, is_param +from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param from torch._subclasses.fake_tensor import FakeTensor @@ -31,6 +31,8 @@ from torch.export.exported_program import InputKind from torch.export.graph_signature import TensorArgument +TorchOpType = Union[EdgeOpOverload, torch._ops.OpOverload, str] + _DQ_OPS = { "dequantize_per_tensor.tensor", "dequantize_per_tensor.default", @@ -275,6 +277,45 @@ def node_comes_from_any_nn_module_in_set( return False +def get_tensor_name(exp_prog: ExportedProgram, node: torch.fx.Node) -> str: + if node is None: + return "" + if is_param(exp_prog, node): + return exp_prog.graph_signature.inputs_to_parameters[node.name] + elif is_buffer(exp_prog, node): + return exp_prog.graph_signature.inputs_to_buffers[node.name] + elif is_lifted_tensor_constant(exp_prog, node): + return exp_prog.graph_signature.inputs_to_lifted_tensor_constants[node.name] + else: + assert isinstance(node.target, str) + return node.target + + return "" + + +def find_dequant_user(node: torch.fx.Node) -> Optional[torch.fx.Node]: + """ + Search the direct users of the given node and return the first one that is a + dequantization op. Returns None if no dequantization op is found. + """ + for user in node.users: + if is_dequant_node(user): + return user + return None + + +def find_quant_user(node: torch.fx.Node) -> Optional[torch.fx.Node]: + """ + Search the direct users of the given node and return the first one that is a + quantization op. Returns None if no quantization op is found. + """ + for user in node.users: + if is_quant_node(user): + return user + + return None + + ## ## Memory Layout, Storage Type Determination ## @@ -1068,6 +1109,69 @@ def get_node_repr(node) -> Union[TensorRepr, TensorReprList]: return get_node_spec_attr(node, "etvk_node_repr", False) +## +## Graph Pattern Matching +## + + +def maybe_skip_q_dq_arg_chain( + arg: torch.fx.node.Argument, +) -> Tuple[Optional[torch.fx.Node], Optional[torch.fx.Node], Optional[torch.fx.Node]]: + """ + Check if the given node argument is part of a Quantize/Dequantize chain produced by + the quant workflow. If so, return the source tensor that is the input to the Q/DQ + chain and the quantize/dequantize nodes in the chain. Otherwise, return the argument + as is and None, None + """ + if not isinstance(arg, torch.fx.Node): + return None, None, None + + if is_dequant_node(arg): + dequant_node = arg + quant_node = dequant_node.args[0] + assert isinstance(quant_node, torch.fx.Node) + source_arg = quant_node.args[0] + assert isinstance(source_arg, torch.fx.Node) + return source_arg, quant_node, dequant_node + else: + return arg, None, None + + +def trace_args_until_placeholder( + node: torch.fx.node.Argument, max_search_depth: int = 4 +) -> Tuple[Optional[torch.fx.Node], List[torch.fx.Node]]: + """ + Trace through node.args[0] of a given initial node until a placeholder node is found + then return it and the list of nodes traversed. If no placeholder node is found, + returns None and an empty list. 
+ """ + cur_node = node + search_depth = 0 + + if not isinstance(cur_node, torch.fx.Node): + return None, [] + + traversed = [cur_node] + while cur_node.op != "placeholder" and search_depth < max_search_depth: + # Break if cur_node has no args + if len(cur_node.args) == 0: + break + + cur_node = cur_node.args[0] + if not isinstance(cur_node, torch.fx.Node): + break + traversed.append(cur_node) + search_depth += 1 + + if not isinstance(cur_node, torch.fx.Node): + return None, [] + if cur_node.op != "placeholder": + return None, [] + + assert isinstance(cur_node, torch.fx.Node) + return cur_node, traversed + + ## ## Misc ## diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 5db5d7a4ff4..69d3cdef75d 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -19,6 +19,7 @@ ViewCopyToSqueezeUnsqueezePass, ) from executorch.backends.vulkan._passes import ( + FoldQDQPass, FuseQuantizedOpsTransform, insert_prepack_nodes, RemoveLocalScalarDenseOpsTransform, @@ -157,6 +158,7 @@ def preprocess( # noqa: C901 RemoveRedundantOpsTransform(), AddmmToLinearTransform(), FuseQuantizedOpsTransform(program), + FoldQDQPass(program), SqueezeUnsqueezeInputs(), FuseViewCopyTransform(), ViewCopyToSqueezeUnsqueezePass(),
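
For reference, below is a minimal sketch of how a new fusion pattern could plug into the detector-based registry introduced in this change. It is not part of the patch: the pattern name "my_custom_pattern" and the anchor-op check are hypothetical, while the decorators, the PatternMatch constructor, and the replacement signature follow the definitions in pattern_registry.py and the quantized conv/linear patterns above.

    from typing import Optional

    import torch
    from executorch.backends.vulkan.patterns.pattern_registry import (
        PatternMatch,
        register_pattern_detector,
        register_pattern_replacement,
    )
    from executorch.exir import ExportedProgram
    from executorch.exir.dialects._ops import ops as exir_ops


    @register_pattern_detector("my_custom_pattern")  # hypothetical pattern name
    def find_my_custom_pattern(node: torch.fx.Node) -> Optional[PatternMatch]:
        # The detector is called once per graph node; return None if this node
        # cannot anchor a match.
        if node.target != exir_ops.edge.aten.linear.default:  # example anchor op
            return None
        # input_nodes / output_nodes / all_nodes describe the matched subgraph.
        return PatternMatch(
            input_nodes=list(node.all_input_nodes),
            output_nodes=[node],
            all_nodes=[node],
            anchor_node=node,
        )


    @register_pattern_replacement("my_custom_pattern")
    def replace_my_custom_pattern(
        ep: ExportedProgram,
        graph_module: torch.fx.GraphModule,
        match: PatternMatch,
    ) -> None:
        # Rewrite the graph here; replace_all_fusable_subgraphs() invokes this once
        # per detected match and runs dead code elimination afterwards.
        ...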