From c46b635ff79254c8bb03831450f406bbd6dab103 Mon Sep 17 00:00:00 2001 From: Abhinay Kukkadapu Date: Wed, 2 Jul 2025 12:49:48 -0700 Subject: [PATCH] Deprecate tag q/dq pass in XNNPACK backend (#12170) Summary: This diff decentralizes implicit q/dq node tagging to the individual partition configs instead of tagging as part of a backend pass. Changes in this diff: 1. Deprecate the tag q/dq pass. 2. Remove all places where this pass was used in the backend preprocess phase. 3. Decentralize the tagging to individual configs: a. `generic_node_configs` handles most of the non-GEMM nodes. b. `gemm_configs` handles GEMM nodes. c. The channels-last pass adds the (copy q dq) or (dq copy q) patterns and tags the relevant nodes. d. Tag q/dq in the conv1d unsqueeze pass. e. Tag q/dq in the decompose cat pass. 4. Deprecate configs.py, where the collection of supported nodes was maintained. Fixes: https://github.com/pytorch/executorch/issues/11588 Test Plan: Imported from GitHub, without a `Test Plan:` line. Rollback Plan: Differential Revision: D77055623 Pulled By: abhinaykukkadapu --- backends/xnnpack/_passes/TARGETS | 1 - backends/xnnpack/_passes/__init__.py | 4 - .../channels_last_tagged_reshape_pass.py | 36 ++- .../xnnpack/_passes/conv1d_unsqueeze_pass.py | 16 +- backends/xnnpack/_passes/decompose_cat.py | 8 +- .../xnnpack/_passes/tag_implicit_q_dq_pass.py | 217 ------------------ .../xnnpack/operators/op_quant_dequant.py | 8 +- backends/xnnpack/operators/quant_params.py | 11 +- backends/xnnpack/partition/TARGETS | 15 -- .../xnnpack/partition/config/gemm_configs.py | 7 +- .../partition/config/generic_node_configs.py | 29 ++- backends/xnnpack/partition/configs.py | 164 ------------- .../test_channels_last_tagged_reshape.py | 79 +++++++ .../passes/test_tag_implicit_q_dq_pass.py | 86 ------- backends/xnnpack/utils/quant_utils.py | 10 + backends/xnnpack/xnnpack_preprocess.py | 4 - 16 files changed, 178 insertions(+), 517 deletions(-) delete mode 100644 backends/xnnpack/_passes/tag_implicit_q_dq_pass.py delete mode 100644 backends/xnnpack/partition/configs.py delete mode 100644 backends/xnnpack/test/passes/test_tag_implicit_q_dq_pass.py diff --git a/backends/xnnpack/_passes/TARGETS b/backends/xnnpack/_passes/TARGETS index 972980570ec..5a038383f20 100644 --- a/backends/xnnpack/_passes/TARGETS +++ b/backends/xnnpack/_passes/TARGETS @@ -9,7 +9,6 @@ python_library( "//caffe2:torch", "//executorch/backends/transforms:addmm_mm_to_linear", "//executorch/backends/transforms:lib", - "//executorch/backends/xnnpack/partition:configs", "//executorch/backends/xnnpack/partition:partitioner_graphs", "//executorch/backends/xnnpack/serialization:xnnpack_schema", "//executorch/backends/xnnpack/utils:xnnpack_utils", diff --git a/backends/xnnpack/_passes/__init__.py b/backends/xnnpack/_passes/__init__.py index 4bf5bdfb079..5d7e388ef0a 100644 --- a/backends/xnnpack/_passes/__init__.py +++ b/backends/xnnpack/_passes/__init__.py @@ -25,9 +25,6 @@ FuseBatchNormWithConvPass, ) from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass -from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import ( - TagImplicitQDqPass, -) from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass from executorch.exir.pass_base import ExportPass @@ -70,7 +67,6 @@ def __init__( Conv1dUnsqueezePass, PReLUReshapePass, ChannelsLastTaggedReshapePass, - TagImplicitQDqPass, ] else: self.passes = passes diff --git a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py index 
1d824d234ee..112cc26c075 100644 --- a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py +++ b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py @@ -8,7 +8,12 @@ import torch from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass -from executorch.backends.xnnpack.utils.quant_utils import is_dynamic_qdq +from executorch.backends.xnnpack.utils.quant_utils import ( + is_dequant, + is_dynamic_qdq, + is_tagged_as_implicit_q_dq, + tag_as_implicit_q_dq, +) from executorch.backends.xnnpack.utils.utils import is_param_node from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import PassResult @@ -144,7 +149,7 @@ def insert_copy_q_dq( target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, args=(copy,) + q_params, ) - q.meta = copy.meta + q.meta = copy.meta.copy() with graph_module.graph.inserting_after(q): dq = self.create_call_function_node( @@ -152,9 +157,24 @@ def insert_copy_q_dq( target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, args=(q,) + q_params, ) - dq.meta = q.meta + dq.meta = q.meta.copy() - after.replace_input_with(before, dq) + # Always tag q as implicit + tag_as_implicit_q_dq(q) + + # Tag the relevant q/dq nodes. + # Ex: Original: G = conv -> q1 (Tag) -> dq1 (No Tag) -> output + # After inserting the (copy q2 dq2) pattern: G = conv -> q1 -> dq1 -> copy -> q2 -> dq2 -> output + # The new dq2 takes over dq1's users, so it inherits dq1's tag status, while dq1, + # now internal to the copy pattern, is always tagged: G = conv -> q1 (Tag) -> dq1 (Tag) -> copy -> q2 (Tag) -> dq2 (No Tag) -> output + + if is_dequant(before) and is_tagged_as_implicit_q_dq(before): + tag_as_implicit_q_dq(dq) + if is_dequant(before): + tag_as_implicit_q_dq(before) + + before.replace_all_uses_with(dq) + copy.replace_input_with(dq, before) def insert_dq_copy_q( self, @@ -170,7 +190,7 @@ def insert_dq_copy_q( target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, args=(before,) + q_params, ) - dq.meta = before.meta + dq.meta = before.meta.copy() with graph_module.graph.inserting_after(copy): q = self.create_call_function_node( @@ -178,7 +198,11 @@ def insert_dq_copy_q( target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, args=(copy,) + q_params, ) - q.meta = copy.meta + q.meta = copy.meta.copy() + + # Always tag q/dq as implicit + tag_as_implicit_q_dq(dq) + tag_as_implicit_q_dq(q) copy.replace_input_with(before, dq) after.replace_input_with(before, q) diff --git a/backends/xnnpack/_passes/conv1d_unsqueeze_pass.py b/backends/xnnpack/_passes/conv1d_unsqueeze_pass.py index 3173cab2746..7a6b031160a 100644 --- a/backends/xnnpack/_passes/conv1d_unsqueeze_pass.py +++ b/backends/xnnpack/_passes/conv1d_unsqueeze_pass.py @@ -8,7 +8,11 @@ import torch from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass -from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant +from executorch.backends.xnnpack.utils.quant_utils import ( + is_dequant, + is_quant, + tag_as_implicit_q_dq, +) from executorch.backends.xnnpack.utils.utils import get_param_tensor, is_param_node from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import PassResult @@ -51,7 +55,10 @@ def insert_q_dq_pair( op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, args=(), # We add the argument last ) - q.meta = anchor.meta + q.meta = anchor.meta.copy() + + # Tag q as implicit + tag_as_implicit_q_dq(q) with graph.inserting_after(q): dq = self.create_node( @@ -59,7 +66,10 @@ def 
insert_q_dq_pair( op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, args=(q,) + q_params, ) - dq.meta = q.meta + dq.meta = q.meta.copy() + + # Tag dq as implicit + tag_as_implicit_q_dq(dq) anchor.replace_all_uses_with(dq) # We add this last so the replace all uses above does not replace the quqntized diff --git a/backends/xnnpack/_passes/decompose_cat.py b/backends/xnnpack/_passes/decompose_cat.py index b9057c43e16..41c8fe0083a 100644 --- a/backends/xnnpack/_passes/decompose_cat.py +++ b/backends/xnnpack/_passes/decompose_cat.py @@ -7,7 +7,11 @@ import logging import torch -from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant +from executorch.backends.xnnpack.utils.quant_utils import ( + is_dequant, + is_quant, + tag_as_implicit_q_dq, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -79,6 +83,7 @@ def call(self, graph_module: torch.fx.GraphModule): args=(node,) + q_params, kwargs=q_kwargs, ) + tag_as_implicit_q_dq(q_node) with gm.graph.inserting_after(q_node): dq_node = gm.graph.create_node( "call_function", @@ -86,6 +91,7 @@ def call(self, graph_module: torch.fx.GraphModule): args=(q_node,) + q_params, kwargs=q_kwargs, ) + tag_as_implicit_q_dq(dq_node) remainder_concat_node.args = ( [dq_node] + remainder_nodes_to_concat, ) + node.args[1:] diff --git a/backends/xnnpack/_passes/tag_implicit_q_dq_pass.py b/backends/xnnpack/_passes/tag_implicit_q_dq_pass.py deleted file mode 100644 index dc488081025..00000000000 --- a/backends/xnnpack/_passes/tag_implicit_q_dq_pass.py +++ /dev/null @@ -1,217 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import cast, List, Optional - -import torch -from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass -from executorch.backends.xnnpack.partition.configs import ( - SUPPORTED_IMPLICIT_Q_DQ_MODULES_SET, - SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET, -) -from executorch.backends.xnnpack.utils.quant_utils import ( - is_dequant, - is_dynamic_qdq, - is_quant, -) -from executorch.backends.xnnpack.utils.utils import is_param_node -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult - - -class TagImplicitQDqPass(XNNPACKPass): - """ - This pass is used to tag "implicit" q/dq nodes, which should be ignored - during preprocessing. - - A q or dq node is deemed to be "implicit" if any of the following hold: - a) All of its inputs are constants (get_attr nodes or parameter (placeholder) nodes), - since (de)quantizing constants is done outside of executing the graph - b) It is the q or dq surrounding a "supported" group of nodes, ordered as - dq -> [supported group] -> q. A "supported" group is comprised of one of - the following: - ( i) A single supported op, from SUPPORTED_QUANT_OPS_SET, - ( ii) A single supported module, from SUPPORTED_QUANT_MODULES_SET, or - (iii) a chain of nodes matching a supported chain from - SUPPORTED_QUANT_CHAINS. - q/dq nodes which match this condition should be - ignore during preprocessing because they are only used as signaling for q - params of node inputs - c) It is a dq followed by aten.linear.default and then an output node. 
This - is because aten.linear.default is a special op corresponding with - dqlinear which doesn't necessarily have an q after it - """ - - _END_OF_CHAIN_MARKER = "END_OF_CHAIN" - # TODO: @salilsdesai Avoid hardcoding quant module chains here (instead get from quantizer) - SUPPORTED_QUANT_CHAINS = { - exir_ops.edge.aten.add.Tensor.name(): { - exir_ops.edge.aten.relu.default.name(): { - _END_OF_CHAIN_MARKER: True, - } - }, - exir_ops.edge.aten.convolution.default.name(): { - exir_ops.edge.aten.relu.default.name(): { - _END_OF_CHAIN_MARKER: True, - } - }, - exir_ops.edge.aten.mul.Tensor.name(): { - exir_ops.edge.aten.relu.default.name(): { - _END_OF_CHAIN_MARKER: True, - } - }, - exir_ops.edge.aten.sub.Tensor.name(): { - exir_ops.edge.aten.relu.default.name(): { - _END_OF_CHAIN_MARKER: True, - } - }, - exir_ops.edge.aten.linear.default.name(): { - exir_ops.edge.aten.relu.default.name(): { - _END_OF_CHAIN_MARKER: True, - } - }, - } - IS_IMPLICIT_Q_DQ_TAG = "IS_IMPLICIT_Q_DQ_TAG" - - def is_output_node(self, node: torch.fx.Node) -> bool: - return node.op == "output" - - def is_dynamically_quantized(self, node: torch.fx.Node) -> bool: - return is_dynamic_qdq(node) - - def is_supported_quant_op(self, node: torch.fx.Node) -> bool: - if node.op != "call_function": - return False - - op_name = cast(torch._ops.OpOverload, node.target).name() - - # Weight and Input should both be quantized - if op_name == exir_ops.edge.aten.convolution.default.name(): - if isinstance(node.args[1], torch.fx.Node): - # pyre-ignore Incompatible parameter type [6]: is_dequant expects Node - return is_dequant(node.args[1]) - - return op_name in SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET - - def is_supported_quant_module(self, node: torch.fx.Node) -> bool: - is_supported = ( - "source_fn_stack" in node.meta - and node.meta["source_fn_stack"][-1][1] - in SUPPORTED_IMPLICIT_Q_DQ_MODULES_SET - ) - if is_supported and self.is_supported_quant_op(node): - raise RuntimeError( - f"The same node should not be both a supported quant op and supported quant module: {node}" - ) - return is_supported - - def tag_as_implicit_q_dq(self, node: torch.fx.Node) -> None: - node.meta[TagImplicitQDqPass.IS_IMPLICIT_Q_DQ_TAG] = True - - @staticmethod - def is_tagged_as_implicit_q_dq(node: torch.fx.Node) -> bool: - return node.meta.get(TagImplicitQDqPass.IS_IMPLICIT_Q_DQ_TAG, False) - - def get_ending_implicit_q_nodes( - self, start_node: torch.fx.Node - ) -> Optional[List[torch.fx.Node]]: - """ - Returns a list of implicit q nodes which end the potential "supported" - group of nodes starting with start_node (which came after a dq), or None - if no such "supported" group exists. This list will either contain - one or zero elements. 
- """ - # If the node after the dq has multiple users then the dq can't be - # implicit - if len(start_node.users) != 1: - return None - - next_node = list(start_node.users)[0] - - if is_quant(next_node): - # Check if second_node (which is between dq and q nodes) is in - # supported quant ops or modules set - if self.is_supported_quant_op(start_node) or self.is_supported_quant_module( - start_node - ): - return [next_node] - elif self.is_output_node(next_node): - # if node following dq is output node - return None - else: - # Check if nodes between the dq node and the next q match - # a supported quant chain - available_chains = TagImplicitQDqPass.SUPPORTED_QUANT_CHAINS - current_node = start_node - while ( - # Not yet at end of chain in graph - not is_quant(current_node) - # Right number of users to continue chain - and len(current_node.users) == 1 - # Can continue following an available chain - and ( - current_node.op == "call_function" - and cast(torch._ops.OpOverload, current_node.target).name() - in available_chains - ) - ): - available_chains = available_chains[ - cast(torch._ops.OpOverload, current_node.target).name() - ] - current_node = list(current_node.users)[0] - - if ( - is_quant(current_node) - and TagImplicitQDqPass._END_OF_CHAIN_MARKER in available_chains - ): - # The chain of nodes between the dq and q nodes matches - # a supported quant chain - return [current_node] - - return None - - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - for first_node in graph_module.graph.nodes: - if (is_dequant(first_node) or is_quant(first_node)) and all( - is_param_node(self.exported_program, n) - for n in first_node.all_input_nodes - ): - # All of the q or dq node's inputs are constants - self.tag_as_implicit_q_dq(first_node) - continue - - if not is_dequant(first_node): - continue - - if len(first_node.users) == 0: - continue - - ending_implicit_q_nodes = [] - for user in first_node.users: - if self.is_dynamically_quantized(first_node): - # if the dq is a dynamic dq, then it is implicit - break - user_end_nodes = self.get_ending_implicit_q_nodes(user) - if user_end_nodes is None: - # This user isn't part of a "supported" group - ending_implicit_q_nodes = None - break - ending_implicit_q_nodes.extend(user_end_nodes) - - if ending_implicit_q_nodes is None: - # There was a user which isn't part of a "supported" group - # Don't tag anything as implicit for this iteration - continue - - self.tag_as_implicit_q_dq(first_node) - for node in ending_implicit_q_nodes: - self.tag_as_implicit_q_dq(node) - - # Since we are overriding "call", we need to call the parent's "call" - # to retrace the graph and regenerate metadata - graph_module = super().call(graph_module).graph_module - - return PassResult(graph_module, True) diff --git a/backends/xnnpack/operators/op_quant_dequant.py b/backends/xnnpack/operators/op_quant_dequant.py index 521a8b6475a..8a035849c06 100644 --- a/backends/xnnpack/operators/op_quant_dequant.py +++ b/backends/xnnpack/operators/op_quant_dequant.py @@ -7,9 +7,6 @@ from typing import Dict import torch -from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import ( - TagImplicitQDqPass, -) from executorch.backends.xnnpack.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -22,6 +19,7 @@ ) from executorch.backends.xnnpack.utils.quant_utils import ( is_per_channel_group, + is_tagged_as_implicit_q_dq, validate_quant_scales, validate_quant_zeropoints, ) @@ -86,7 +84,7 @@ def define_node( # check scales and zp are valid 
super().define_node(node, xnn_graph, vals_to_ids, debug_handle) - if not TagImplicitQDqPass.is_tagged_as_implicit_q_dq(node): + if not is_tagged_as_implicit_q_dq(node): dq_input = get_input_node(node, 0) input_quant_params = QuantParams.from_q_dq_node(node) # fp32 output @@ -137,7 +135,7 @@ def define_node( super().define_node(node, xnn_graph, vals_to_ids, debug_handle) q_input = get_input_node(node, 0) - if not TagImplicitQDqPass.is_tagged_as_implicit_q_dq(node): + if not is_tagged_as_implicit_q_dq(node): input_quant_params = QuantParams.from_q_dq_node(node) # fp32 input self.define_tensor(q_input, xnn_graph, vals_to_ids) diff --git a/backends/xnnpack/operators/quant_params.py b/backends/xnnpack/operators/quant_params.py index bdde1c59689..88a1f660f0e 100644 --- a/backends/xnnpack/operators/quant_params.py +++ b/backends/xnnpack/operators/quant_params.py @@ -9,9 +9,6 @@ from typing import cast, Optional, Union import torch -from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import ( - TagImplicitQDqPass, -) from executorch.backends.xnnpack.utils.quant_utils import ( extract_qdq_affine_op_args_for_decomposed_ops, is_affine_qdq, @@ -20,6 +17,7 @@ is_per_channel, is_per_channel_group, is_quant, + is_tagged_as_implicit_q_dq, ) from executorch.backends.xnnpack.utils.utils import ( check_or_raise, @@ -299,16 +297,13 @@ def from_inputs( cls, tensor_node: torch.fx.Node, ep: ExportedProgram ) -> Optional[QuantParams]: # tensor_node is quantized if it is produced by a dequant node - if is_dequant(tensor_node) and TagImplicitQDqPass.is_tagged_as_implicit_q_dq( - tensor_node - ): + if is_dequant(tensor_node) and is_tagged_as_implicit_q_dq(tensor_node): dq_input = cast(torch.fx.Node, tensor_node.args[0]) if is_quant(dq_input): q_input = cast(torch.fx.Node, dq_input.args[0]) if is_param_node(ep, q_input): return cls.from_q_dq_node(dq_input) return cls.from_q_dq_node(tensor_node) - return None @classmethod @@ -317,7 +312,7 @@ def from_outputs(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]: if len(tensor_node.users) == 1: q = list(tensor_node.users.keys())[0] # Check if user is a q node - if is_quant(q) and TagImplicitQDqPass.is_tagged_as_implicit_q_dq(q): + if is_quant(q) and is_tagged_as_implicit_q_dq(q): return cls.from_q_dq_node(q) return None diff --git a/backends/xnnpack/partition/TARGETS b/backends/xnnpack/partition/TARGETS index bed4aa3ea45..6b81558e3be 100644 --- a/backends/xnnpack/partition/TARGETS +++ b/backends/xnnpack/partition/TARGETS @@ -12,7 +12,6 @@ runtime.python_library( "@EXECUTORCH_CLIENTS", ], deps = [ - ":configs", ":partitioner_graphs", "//executorch/backends/xnnpack:xnnpack_preprocess", "//executorch/backends/xnnpack/partition/config:xnnpack_partitioner_configs", @@ -24,20 +23,6 @@ runtime.python_library( ], ) -runtime.python_library( - name = "configs", - srcs = [ - "configs.py", - ], - visibility = [ - "//executorch/...", - "@EXECUTORCH_CLIENTS", - ], - deps = [ - "//executorch/exir:lib", - ], -) - runtime.python_library( name = "partitioner_graphs", srcs = glob([ diff --git a/backends/xnnpack/partition/config/gemm_configs.py b/backends/xnnpack/partition/config/gemm_configs.py index 67bccbc52d1..f65f9cb3398 100644 --- a/backends/xnnpack/partition/config/gemm_configs.py +++ b/backends/xnnpack/partition/config/gemm_configs.py @@ -25,6 +25,7 @@ is_per_tensor, is_qparam, is_quant, + tag_as_implicit_q_dq, ) from executorch.backends.xnnpack.utils.utils import ( get_input_node, @@ -136,6 +137,11 @@ def get_deps( valid_deps = valid_bias and valid_weight and 
valid_act and valid_output deps = list(chain(bias_deps, weight_deps, act_deps, output_deps)) + # Tag q/dq nodes as implicit q/dq nodes + for dep in deps: + if is_dequant(dep) or is_quant(dep): + tag_as_implicit_q_dq(dep) + return valid_deps, deps def _get_weight_deps( @@ -268,7 +274,6 @@ def _get_act_deps( if not is_quant(q_input): why(node, "Expected dequant input to be quant node") return (False, []) - gemm_deps.append(q_input) q_input_args = q_input.args if is_affine_qdq(q_input): diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py index 68f6d6579b3..e68c0b83c02 100644 --- a/backends/xnnpack/partition/config/generic_node_configs.py +++ b/backends/xnnpack/partition/config/generic_node_configs.py @@ -15,7 +15,11 @@ ConfigPrecisionType, XNNPartitionerConfig, ) -from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant +from executorch.backends.xnnpack.utils.quant_utils import ( + is_dequant, + is_quant, + tag_as_implicit_q_dq, +) from executorch.backends.xnnpack.utils.utils import get_input_node from executorch.exir.backend.canonical_partitioners.config_partitioner import ( format_target_name, @@ -54,10 +58,12 @@ def get_node_and_deps( quantized_deps.extend(node.all_input_nodes) - # check if quantized pattern has fused activation + # ensure the node has only one user to enforce quantized pattern + # (dq -> node -> fused act (optional) -> q) if len(node.users) != 1: return deps + # check if quantized pattern has fused activation node_output = list(node.users)[0] if ( node_output.op == "call_function" @@ -72,6 +78,15 @@ def get_node_and_deps( # Expected node --> fused_act (optional) --> dequant return deps + # Tag input nodes (dq nodes) as implicit q/dq nodes + for dq_input in node.all_input_nodes: + if is_dequant(dq_input): + tag_as_implicit_q_dq(dq_input) + + # Tag node_output (q node) as an implicit q/dq node + if is_quant(node_output): + tag_as_implicit_q_dq(node_output) + quantized_deps.append(node_output) return deps + quantized_deps @@ -83,6 +98,11 @@ class QuantizedPerTensorConfig(GenericNodePartitionerConfig): def supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.STATIC_QUANT] + def get_node_and_deps( + self, node: torch.fx.Node, ep: ExportedProgram + ) -> List[torch.fx.Node]: + return [node] + class DeQuantizedPerTensorConfig(GenericNodePartitionerConfig): target_name = "dequantize_per_tensor.default" @@ -90,6 +110,11 @@ class DeQuantizedPerTensorConfig(GenericNodePartitionerConfig): def supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.STATIC_QUANT] + def get_node_and_deps( + self, node: torch.fx.Node, ep: ExportedProgram + ) -> List[torch.fx.Node]: + return [node] + class HardtanhConfig(GenericNodePartitionerConfig): target_name = "hardtanh.default" diff --git a/backends/xnnpack/partition/configs.py b/backends/xnnpack/partition/configs.py deleted file mode 100644 index eb31384c7ec..00000000000 --- a/backends/xnnpack/partition/configs.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import torch -from executorch.exir.dialects._ops import ops as exir_ops - -""" -** How to incorporate a new op into the XNNPACK Partitioner? 
** - -[1] When the new edge op being added is direct descendent of a core-aten op, -and is also supported* by XNNPACK, prefer partitioning it via SUPPORTED_OPS -mechanism e.g. torch.add - -[2] When the new op being added is not a core-aten op, - -[2.1] If the original torch op is supported* by XNNPACK, prefer partitioning it -via SUPPORTED_MODULES. This will require "recomposing" the op before lowering -it to XNNPACK e.g. torch.nn.Linear. Make sure to include all variants of the -modules in the SUPPORTED_MODULES list. - -[2.2] If the original torch op is not supported by XNNPACK, then it is assumed -that out of all the decomposed core-aten ops, SUPPORTED_OPS will be lowered to -XNNPACK. - -* - Supported fully or partially. The partial support does not mean only few -ops from the decomposition but means only some variants of the op "modes" -possible with the arg combinations. -""" - -SUPPORTED_OPS = [ - exir_ops.edge.aten.div.Tensor, - exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.clamp.default, - exir_ops.edge.aten.sub.Tensor, - exir_ops.edge.aten.floor.default, - exir_ops.edge.aten.maximum.default, - exir_ops.edge.aten.minimum.default, - exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.constant_pad_nd.default, - exir_ops.edge.aten.upsample_bilinear2d.default, - exir_ops.edge.aten.mean.dim, - exir_ops.edge.aten.max.dim, - exir_ops.edge.aten.max_pool2d_with_indices.default, - exir_ops.edge.aten.hardtanh.default, - exir_ops.edge.aten.sqrt.default, - exir_ops.edge.aten.ceil.default, - exir_ops.edge.aten.hardswish.default, - exir_ops.edge.aten.neg.default, - exir_ops.edge.aten.pow.Tensor_Scalar, - exir_ops.edge.aten.abs.default, - exir_ops.edge.aten._prelu_kernel.default, - exir_ops.edge.aten.slice_copy.Tensor, - exir_ops.edge.aten.relu.default, - exir_ops.edge.aten.hardtanh.default, - exir_ops.edge.aten.permute_copy.default, - exir_ops.edge.aten.sigmoid.default, - exir_ops.edge.aten._softmax.default, - exir_ops.edge.aten.cat.default, - exir_ops.edge.aten.elu.default, - exir_ops.edge.aten.avg_pool2d.default, - exir_ops.edge.aten.leaky_relu.default, - exir_ops.edge.aten.addmm.default, # TODO(T163877189) add constraint for addmm - exir_ops.edge.aten.rsqrt.default, - exir_ops.edge.aten.log.default, - exir_ops.edge.aten.gelu.default, - exir_ops.edge.aten.tanh.default, - exir_ops.edge.aten.exp.default, -] - -SUPPORTED_MODULES = [ - torch.nn.Conv1d, - # TODO(T161981984) recomposed hardswish into a single node - torch.nn.Hardswish, # we need to recompose - torch.nn.Hardsigmoid, # we can handle decomposition - torch.nn.BatchNorm2d, - torch.nn.BatchNorm1d, - torch.nn.Conv2d, - torch.nn.ConvTranspose2d, - torch.nn.Linear, - torch.nn.functional.linear, - torch.nn.PReLU, # Without this, the PReLU weight becomes not a get_attr -] - -# TODO delete this and should use SUPPORTED_OPS instead once we align fp32 and quant support -SUPPORTED_QUANT_OPS = [ - exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.clamp.default, - exir_ops.edge.aten.relu.default, - exir_ops.edge.aten.sub.Tensor, - exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.mean.dim, - exir_ops.edge.aten.hardtanh.default, - exir_ops.edge.aten.slice_copy.Tensor, - exir_ops.edge.aten.permute_copy.default, - exir_ops.edge.aten.hardtanh.default, - exir_ops.edge.aten.mean.dim, - exir_ops.edge.aten.cat.default, - exir_ops.edge.aten.max_pool2d_with_indices.default, - exir_ops.edge.aten.max_pool2d.default, - exir_ops.edge.aten.constant_pad_nd.default, - exir_ops.edge.aten.elu.default, - exir_ops.edge.aten.t_copy.default, - 
exir_ops.edge.aten.leaky_relu.default, - exir_ops.edge.aten.addmm.default, # TODO(T163877189) add constraint for addmm -] - -# This set is used to determine if an op is a supported Quantized Op. This is -# used to determine whether a quantization op is implicit or explicit. -SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET = { - op.name() - for op in ( - SUPPORTED_QUANT_OPS - + [ - exir_ops.edge.aten._to_copy.default, - exir_ops.edge.aten.linear.default, - exir_ops.edge.aten.convolution.default, - ] - ) -} - -UNSUPPORTED_QUANT_MODULES = [ - torch.nn.Hardswish, - torch.nn.Hardsigmoid, -] - -# TODO delete this and should use SUPPORTED_MODULES instead once we align fp32 and quant support -SUPPORTED_QUANT_MODULES = [ - torch.nn.Linear, - torch.nn.functional.linear, - # TODO - T158982884 - # torch.ao.nn.quantized.reference.modules.linear.Linear, - torch.nn.Conv1d, - torch.nn.functional.conv1d, - torch.ao.nn.quantized.reference.modules.conv.Conv1d, - torch.nn.Conv2d, - torch.nn.functional.conv2d, - torch.ao.nn.quantized.reference.modules.conv.Conv2d, - torch.nn.BatchNorm1d, - torch.nn.BatchNorm2d, -] - -SUPPORTED_IMPLICIT_Q_DQ_MODULES_SET = set(SUPPORTED_QUANT_MODULES) - -# Modules which support dynamic quantization -# These already support dynamic shape. -SUPPORTED_DYN_QUANT_LINEAR_MODULES = [ - torch.nn.Linear, - torch.nn.functional.linear, -] - -SUPPORTED_DYN_QUANT_MODULES = SUPPORTED_DYN_QUANT_LINEAR_MODULES - -# XNNPACK supports majority of shape dynamism, however some ops are -# explicitly static, so we maintain a set here to exclude them from -# dynamic shape support. -STATIC_OPS = [ - exir_ops.edge.aten.cat.default, - exir_ops.edge.aten.slice_copy.Tensor, -] - -STATIC_MODULES = [] diff --git a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py index cfc409b4596..03515d8d420 100644 --- a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py +++ b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py @@ -7,6 +7,7 @@ import unittest import torch +from executorch.backends.test.harness.stages.stage import StageType from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import ( ChannelsLastTaggedReshapePass, ) @@ -17,6 +18,11 @@ OpSequencesAddConv2d, ) from executorch.backends.xnnpack.test.tester import Quantize, RunPasses, Tester +from executorch.backends.xnnpack.utils.quant_utils import ( + is_dequant, + is_quant, + is_tagged_as_implicit_q_dq, +) class TestChannelsLastTaggedReshapePass(unittest.TestCase): @@ -382,3 +388,76 @@ def test_three_outputs_model(self): x_cl = x.to(memory_format=torch.channels_last) self.run_tester(self.ThreeOutputsModelModule.eval(), (x_cl,)) + + class ConvQDQModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, padding=1) + + def forward(self, x): + return self.conv(x) + + def _check_implicit_q_dq_tagging( + self, graph_module: torch.fx.GraphModule, expected_tagging: list[bool] + ): + q_dq_nodes = [] + for node in graph_module.graph.nodes: + if is_quant(node) or is_dequant(node): + q_dq_nodes.append(node) + + # Check that we have the expected number of nodes + self.assertEqual( + len(q_dq_nodes), + len(expected_tagging), + f"Expected {len(expected_tagging)} q/dq nodes but found {len(q_dq_nodes)}", + ) + + actual_tagging = [] + for node in q_dq_nodes: + is_tagged = is_tagged_as_implicit_q_dq(node) + actual_tagging.append(is_tagged) + + self.assertEqual( + actual_tagging, + expected_tagging, + f"Q/DQ 
node tagging mismatch. Expected: {expected_tagging}, Actual: {actual_tagging}", + ) + + def test_q_dq_nodes_around_copy_are_tagged(self): + # Create a model with conv operation + model = self.ConvQDQModule().eval() + input_tensor = torch.randn(1, 3, 8, 8) + + tester = ( + Tester(model, (input_tensor,)) + .quantize() + .export() + .to_edge() + .run_passes(self.PassStage) + .check( + [ + self.dequant_name, + self.quant_name, + self.dequant_name, + self.to_copy_name, + self.quant_name, + self.dequant_name, + self.conv_name, + self.quant_name, + self.dequant_name, + self.to_copy_name, + self.quant_name, + self.dequant_name, + ] + ) + ) + + artifact = tester.get_artifact(StageType.RUN_PASSES) + graph_module = artifact.exported_program().graph_module + + # Check implicit q/dq tagging + expected_tagging = [False, False, True, True, False, False, True, True, False] + self._check_implicit_q_dq_tagging(graph_module, expected_tagging) + + # Compare outputs + tester.run_method_and_compare_outputs() diff --git a/backends/xnnpack/test/passes/test_tag_implicit_q_dq_pass.py b/backends/xnnpack/test/passes/test_tag_implicit_q_dq_pass.py deleted file mode 100644 index 2347122a180..00000000000 --- a/backends/xnnpack/test/passes/test_tag_implicit_q_dq_pass.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import unittest - -import torch -from executorch.backends.test.harness.stages import StageType -from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import ( - TagImplicitQDqPass, -) -from executorch.backends.xnnpack.test.tester import RunPasses, Tester -from executorch.exir.backend.canonical_partitioners.duplicate_dequant_node_pass import ( - DuplicateDequantNodePass, -) -from executorch.exir.dialects._ops import ops as exir_ops - - -class TestTagImplicitQDq(unittest.TestCase): - PassStage = RunPasses([DuplicateDequantNodePass, TagImplicitQDqPass]) - - def setUp(self): - torch._dynamo.reset() - - class QDqModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - qparams = [0.12345, 0, -127, 127, torch.int8] - x = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default( - x, *qparams - ) - x = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default( - x, *qparams - ) - x = torch.add(x, x) - x = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default( - x, *qparams - ) - x = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default( - x, *qparams - ) - x = torch.mul(x, x) - x = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default( - x, *qparams - ) - x = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default( - x, *qparams - ) - x = torch.add(x, x) - x = torch.mul(x, x) - return x - - def test_tag_implicit_q_dq_test(self): - inputs = (torch.randn(2, 3),) - artifact = ( - Tester(self.QDqModule(), inputs) - .export() - .to_edge() - .run_passes(self.PassStage) - .run_method_and_compare_outputs() - .get_artifact(StageType.RUN_PASSES) - ) - - for node in artifact.exported_program().module().graph.nodes: - print( - f"{node}: {node.meta.get(TagImplicitQDqPass.IS_IMPLICIT_Q_DQ_TAG, False)}" - ) - - # The six tagged nodes are: - # 1) The dq of the first add input - # 2) The dq of the second add input - # 3) The q of the add output - # 4) The dq of the first mul input - # 5) The dq of the second mul input - # 6) The q of the 
mul output - self.assertEqual( - sum( - node.meta.get(TagImplicitQDqPass.IS_IMPLICIT_Q_DQ_TAG, False) - for node in artifact.exported_program().module().graph.nodes - ), - 6, - ) diff --git a/backends/xnnpack/utils/quant_utils.py b/backends/xnnpack/utils/quant_utils.py index 12064899a7c..491c377cb5f 100644 --- a/backends/xnnpack/utils/quant_utils.py +++ b/backends/xnnpack/utils/quant_utils.py @@ -46,6 +46,16 @@ "dequantize_per_token.default", } +IS_IMPLICIT_Q_DQ_TAG = "IS_IMPLICIT_Q_DQ_TAG" + + +def tag_as_implicit_q_dq(node: torch.fx.Node) -> None: + node.meta[IS_IMPLICIT_Q_DQ_TAG] = True + + +def is_tagged_as_implicit_q_dq(node: torch.fx.Node) -> bool: + return node.meta.get(IS_IMPLICIT_Q_DQ_TAG, False) + def is_dynamic_qdq(node: torch.fx.Node) -> bool: # check has dynamic qdq name diff --git a/backends/xnnpack/xnnpack_preprocess.py b/backends/xnnpack/xnnpack_preprocess.py index d8892b179cf..05fb53a837d 100644 --- a/backends/xnnpack/xnnpack_preprocess.py +++ b/backends/xnnpack/xnnpack_preprocess.py @@ -12,9 +12,6 @@ from executorch.backends.xnnpack._passes import XNNPACKPassManager from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass -from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import ( - TagImplicitQDqPass, -) from executorch.backends.xnnpack.operators.node_visitor import get_node_visitors from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( @@ -136,7 +133,6 @@ def preprocess( for spec in compile_specs: if spec.key == "dqlinear_partitioner": passes.append(ConvertToLinearPass) - passes.append(TagImplicitQDqPass) passes = passes if len(passes) > 0 else None # XNNPACK Delegate Specific Passes
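
For downstream config authors, a minimal sketch of the tagging contract this patch establishes. It assumes only the helpers the patch adds to backends/xnnpack/utils/quant_utils.py (`is_quant`, `is_dequant`, `tag_as_implicit_q_dq`); the `tag_quantized_pattern` function itself is hypothetical, mirroring the per-config logic in generic_node_configs.py rather than reproducing it:

from typing import List

import torch
from executorch.backends.xnnpack.utils.quant_utils import (
    is_dequant,
    is_quant,
    tag_as_implicit_q_dq,
)


def tag_quantized_pattern(node: torch.fx.Node) -> List[torch.fx.Node]:
    # Collect the dq -> node -> q pattern around a partitioned node and
    # tag its q/dq nodes as implicit, so the XNNPACK serializer treats
    # them as carriers of quant params rather than as real operators.
    deps = [node]
    for inp in node.all_input_nodes:
        if is_dequant(inp):
            tag_as_implicit_q_dq(inp)
            deps.append(inp)
    # A single user keeps the pattern unambiguous (dq -> node -> q).
    if len(node.users) == 1:
        out = next(iter(node.users))
        if is_quant(out):
            tag_as_implicit_q_dq(out)
            deps.append(out)
    return deps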