diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py
index ed14574a8c8..0461c03ccb7 100644
--- a/backends/cadence/aot/quantizer/fusion_pass.py
+++ b/backends/cadence/aot/quantizer/fusion_pass.py
@@ -471,7 +471,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                 pattern.partition_types(),
             )
             for fused_partition in fused_partitions:
-                anchors = pattern.get_anchors(graph_module, fused_partition)
+                anchors, op_node = pattern.get_anchors(graph_module, fused_partition)
                 if not anchors or anchors.empty:
                     continue
                 if any(self.is_fused(p.nodes) for p in fused_partition):
@@ -512,13 +512,10 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                 bias_inputs = [node.args[0] for node in dequants_biases]
                 other_inputs = [node.args[idx] for node, idx in anchors.others]

-                # The node is the first index of the list and first of the tuple
-                anchor_output_node = anchors.output[0][0]
+                assert op_node is not None, "op_node is None"
+                quant_node = list(op_node.users.keys())[0]

-                assert len(anchor_output_node.users) == 1
-                quant_node = list(anchor_output_node.users.keys())[0]
-
-                with graph_module.graph.inserting_after(anchor_output_node):
+                with graph_module.graph.inserting_after(op_node):
                     args = tuple(
                         inputs_inputs + weights_inputs + other_inputs + bias_inputs
                     )
@@ -532,7 +529,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         )
                     elif isinstance(pattern, CatPattern):
                         args, kwargs = get_args_and_kwargs_cat(
-                            inputs_inputs, other_inputs, anchor_output_node
+                            inputs_inputs, other_inputs, op_node
                         )
                     elif isinstance(pattern, ConvReluPatterns):
                         # For ConvReLU, we are fusing Conv+ReLU
@@ -563,7 +560,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                             dequants_weights,
                             bias_inputs,
                             quant_node,
-                            anchor_output_node,
+                            op_node,
                         )
                     elif isinstance(pattern, LinearPattern):
                         args, kwargs = get_args_and_kwargs_linear(
@@ -618,20 +615,28 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                             inputs_inputs,
                             dequants_inputs,
                             quant_node,
-                            anchor_output_node,
+                            op_node,
                         )
+
                     fused = graph_module.graph.call_function(
                         pattern.replacement_op(),
                         args,
                         kwargs,
                     )
-                    fused.meta = quant_node.meta
-                    quant_node.replace_all_uses_with(fused)
+
+                    if len(anchors.output) > 0:
+                        fused.meta = quant_node.meta
+                        quant_node.replace_all_uses_with(fused)
+                    else:
+                        fused.meta = op_node.meta
+                        op_node.replace_all_uses_with(fused)
+                        if op_node.op == "output":
+                            _ = graph_module.graph.output((fused,))

         legalize_graph(graph_module)
         graph_module.graph.eliminate_dead_code()
-        # pyre-fixme[7]: Incompatible return type
         graph_module.recompile()
+        return PassResult(graph_module, True)

     @classmethod
     # pyre-ignore[2]: Parameter `nodes` has no type specified
diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py
index 48458ba468a..4eae55502d7 100644
--- a/backends/cadence/aot/quantizer/patterns.py
+++ b/backends/cadence/aot/quantizer/patterns.py
@@ -8,7 +8,7 @@

 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import List, Optional, Tuple, Union
+from typing import List, Tuple, Union

 import torch
 from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams
@@ -67,7 +67,7 @@ def partition_types(self) -> list[OpOverload]:
     @abstractmethod
     def get_anchors(
         self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> Optional[PartitionAnchors]:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         pass

     @abstractmethod
@@ -85,7 +85,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         addmm_node = fused_partition[0].nodes[-1]

@@ -101,11 +101,14 @@ def get_anchors(
             qscheme=torch.per_tensor_affine,
         )

-        return PartitionAnchors(
-            inputs=[(addmm_node, 1)],
-            weights=[(addmm_node, 2)],
-            biases=[(addmm_node, 0, bias_qspec)],
-            output=[(addmm_node,)],
+        return (
+            PartitionAnchors(
+                inputs=[(addmm_node, 1)],
+                weights=[(addmm_node, 2)],
+                biases=[(addmm_node, 0, bias_qspec)],
+                output=[(addmm_node,)],
+            ),
+            addmm_node,
         )

     def replacement_op(self) -> OpOverload:
@@ -118,7 +121,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         add_node = fused_partition[0].nodes[-1]

@@ -129,15 +132,21 @@ def get_anchors(
             add_node.args[1], fx.Node
         )
         if not is_tensor_add or len(add_node.kwargs) > 0:
-            return PartitionAnchors(
-                empty=True,
+            return (
+                PartitionAnchors(
+                    empty=True,
+                ),
+                add_node,
             )

-        return PartitionAnchors(
-            inputs=[(add_node, 0), (add_node, 1)],
-            weights=[],
-            biases=[],
-            output=[(add_node,)],
+        return (
+            PartitionAnchors(
+                inputs=[(add_node, 0), (add_node, 1)],
+                weights=[],
+                biases=[],
+                output=[(add_node,)],
+            ),
+            add_node,
         )

     def replacement_op(self) -> OpOverload:
@@ -150,15 +159,18 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         bmm_node = fused_partition[0].nodes[-1]

-        return PartitionAnchors(
-            inputs=[(bmm_node, 0), (bmm_node, 1)],
-            weights=[],
-            biases=[],
-            output=[(bmm_node,)],
+        return (
+            PartitionAnchors(
+                inputs=[(bmm_node, 0), (bmm_node, 1)],
+                weights=[],
+                biases=[],
+                output=[(bmm_node,)],
+            ),
+            bmm_node,
         )

     def replacement_op(self) -> OpOverload:
@@ -171,7 +183,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         cat_node = fused_partition[0].nodes[-1]

@@ -198,13 +210,16 @@ def get_anchors(
                 )
             )

-        return PartitionAnchors(
-            inputs=args,
-            weights=[],
-            biases=[],
-            output=[
-                (cat_node, SharedQuantizationSpec((cat_node.args[0][0], cat_node)))
-            ],
+        return (
+            PartitionAnchors(
+                inputs=args,
+                weights=[],
+                biases=[],
+                output=[
+                    (cat_node, SharedQuantizationSpec((cat_node.args[0][0], cat_node)))
+                ],
+            ),
+            cat_node,
         )

     def replacement_op(self) -> OpOverload:
@@ -217,7 +232,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         conv1d_node = fused_partition[0].nodes[-1]

@@ -238,12 +253,15 @@ def get_anchors(
         if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None:
             bias = [(conv1d_node, 2, bias_qspec)]

-        return PartitionAnchors(
-            inputs=[(conv1d_node, 0)],
-            weights=[(conv1d_node, 1)],
-            # pyre-fixme[6]: Incompatible parameter type
-            biases=bias,
-            output=[(conv1d_node,)],
+        return (
+            PartitionAnchors(
+                inputs=[(conv1d_node, 0)],
+                weights=[(conv1d_node, 1)],
+                # pyre-fixme[6]: Incompatible parameter type
+                biases=bias,
+                output=[(conv1d_node,)],
+            ),
+            conv1d_node,
         )

     def replacement_op(self) -> OpOverload:
@@ -256,7 +274,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         conv2d_node = fused_partition[0].nodes[-1]

@@ -277,12 +295,15 @@ def get_anchors(
         if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None:
             bias = [(conv2d_node, 2, bias_qspec)]

-        return PartitionAnchors(
-            inputs=[(conv2d_node, 0)],
-            weights=[(conv2d_node, 1)],
-            # pyre-fixme[6]: Incompatible parameter type
-            biases=bias,
-            output=[(conv2d_node,)],
+        return (
+            PartitionAnchors(
+                inputs=[(conv2d_node, 0)],
+                weights=[(conv2d_node, 1)],
+                # pyre-fixme[6]: Incompatible parameter type
+                biases=bias,
+                output=[(conv2d_node,)],
+            ),
+            conv2d_node,
         )

     def replacement_op(self) -> OpOverload:
@@ -295,7 +316,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         layer_norm_node = fused_partition[0].nodes[-1]

@@ -311,13 +332,16 @@ def get_anchors(

         # Weights are used in quantized mode by our kernel, so they are
         # passed in as others here along with the normalized shape.
-        return PartitionAnchors(
-            inputs=[(layer_norm_node, 0)],
-            weights=[],
-            biases=[],
-            # Ordering: normalized_shape, weights, bias
-            others=others,
-            output=[(layer_norm_node,)],
+        return (
+            PartitionAnchors(
+                inputs=[(layer_norm_node, 0)],
+                weights=[],
+                biases=[],
+                # Ordering: normalized_shape, weights, bias
+                others=others,
+                output=[(layer_norm_node,)],
+            ),
+            layer_norm_node,
         )

     def replacement_op(self) -> OpOverload:
@@ -330,7 +354,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         linear_node = fused_partition[0].nodes[-1]

@@ -351,12 +375,15 @@ def get_anchors(
         if len(linear_node.args) > 2:
             bias = [(linear_node, 2, bias_qspec)]

-        return PartitionAnchors(
-            inputs=[(linear_node, 0)],
-            weights=[(linear_node, 1)],
-            # pyre-fixme[6]: Incompatible parameter type
-            biases=bias,
-            output=[(linear_node,)],
+        return (
+            PartitionAnchors(
+                inputs=[(linear_node, 0)],
+                weights=[(linear_node, 1)],
+                # pyre-fixme[6]: Incompatible parameter type
+                biases=bias,
+                output=[(linear_node,)],
+            ),
+            linear_node,
         )

     def replacement_op(self) -> OpOverload:
@@ -369,15 +396,18 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         matmul_node = fused_partition[0].nodes[-1]

-        return PartitionAnchors(
-            inputs=[(matmul_node, 0), (matmul_node, 1)],
-            weights=[],
-            biases=[],
-            output=[(matmul_node,)],
+        return (
+            PartitionAnchors(
+                inputs=[(matmul_node, 0), (matmul_node, 1)],
+                weights=[],
+                biases=[],
+                output=[(matmul_node,)],
+            ),
+            matmul_node,
         )

     def replacement_op(self) -> OpOverload:
@@ -392,15 +422,18 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         relu_node = fused_partition[0].nodes[-1]

-        return PartitionAnchors(
-            inputs=[(relu_node, 0)],
-            weights=[],
-            biases=[],
-            output=[(relu_node,)],
+        return (
+            PartitionAnchors(
+                inputs=[(relu_node, 0)],
+                weights=[],
+                biases=[],
+                output=[(relu_node,)],
+            ),
+            relu_node,
         )

     def replacement_op(self) -> OpOverload:
@@ -427,7 +460,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # The first node should be conv, the second should be relu
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         conv_node = fused_partition[0].nodes[-1]  # Second to last node
@@ -451,12 +484,15 @@ def get_anchors(
         if len(conv_node.args) > 2 and conv_node.args[2] is not None:
             bias = [(conv_node, 2, bias_qspec)]

-        return PartitionAnchors(
-            inputs=[(conv_node, 0)],
-            weights=[(conv_node, 1)],
-            # pyre-fixme[6]: Incompatible parameter type
-            biases=bias,
-            output=[(relu_node,)],  # Output is from the relu node
+        return (
+            PartitionAnchors(
+                inputs=[(conv_node, 0)],
+                weights=[(conv_node, 1)],
+                # pyre-fixme[6]: Incompatible parameter type
+                biases=bias,
+                output=[(relu_node,)],  # Output is from the relu node
+            ),
+            relu_node,
         )

     def replacement_op(self) -> OpOverload:
@@ -494,15 +530,18 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         softmax_node = fused_partition[0].nodes[-1]

-        return PartitionAnchors(
-            inputs=[(softmax_node, 0)],
-            weights=[],
-            biases=[],
-            output=[(softmax_node,)],
+        return (
+            PartitionAnchors(
+                inputs=[(softmax_node, 0)],
+                weights=[],
+                biases=[],
+                output=[(softmax_node,)],
+            ),
+            softmax_node,
         )

     def replacement_op(self) -> OpOverload:
diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py
index ad5f935173e..536b28f5cec 100644
--- a/backends/cadence/aot/quantizer/quantizer.py
+++ b/backends/cadence/aot/quantizer/quantizer.py
@@ -133,7 +133,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
             if not no_outside_users(fused_partition):
                 continue

-            anchors = self.pattern.get_anchors(model, fused_partition)
+            anchors, _ = self.pattern.get_anchors(model, fused_partition)
             if not anchors or anchors.empty:
                 continue
             if is_annotated(
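
For context on the new contract: `get_anchors` now returns a `(PartitionAnchors, fx.Node)` pair rather than `Optional[PartitionAnchors]`, with the second element being the pattern's terminal op node; "no match" is signaled via `PartitionAnchors(empty=True)` instead of `None`. A minimal sketch of a conforming pattern follows, assuming the existing `QuantizationPattern`, `PartitionAnchors`, and `OpOverload` names from patterns.py; `MulPattern` and its replacement op are hypothetical, for illustration only:

```python
# Illustrative sketch only: MulPattern and torch.ops.cadence.quantized_mul are
# hypothetical; QuantizationPattern and PartitionAnchors are the existing
# names defined in patterns.py.
from typing import List, Tuple

import torch
from torch import fx
from torch._ops import OpOverload

from executorch.backends.cadence.aot.quantizer.patterns import (
    PartitionAnchors,
    QuantizationPattern,
)


class MulPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.mul.Tensor]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        mul_node = fused_partition[0].nodes[-1]

        # The terminal op node rides along with the anchors so callers no
        # longer have to recover it from anchors.output (which may be empty).
        return (
            PartitionAnchors(
                inputs=[(mul_node, 0), (mul_node, 1)],
                weights=[],
                biases=[],
                output=[(mul_node,)],
            ),
            mul_node,
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_mul.default  # hypothetical op
```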
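On the consuming side, the fusion pass now keys its rewrite off `anchors.output`, as the fusion_pass.py hunks above show: when a quantized output anchor exists, the fused op takes over the trailing quantize node; otherwise it stands in for the op node itself. A condensed sketch of that control flow, with the per-pattern args/kwargs plumbing omitted:

```python
# Condensed consumer-side flow (sketch; mirrors QuantFusion.call above, with
# args/kwargs assumed to have been built by the per-pattern helpers).
anchors, op_node = pattern.get_anchors(graph_module, fused_partition)
if anchors.empty:
    return  # nothing to fuse in this partition

with graph_module.graph.inserting_after(op_node):
    fused = graph_module.graph.call_function(
        pattern.replacement_op(), args, kwargs
    )

if len(anchors.output) > 0:
    # Quantized output: splice the fused op over the trailing quantize node.
    quant_node = list(op_node.users.keys())[0]
    fused.meta = quant_node.meta
    quant_node.replace_all_uses_with(fused)
else:
    # No quantized output: the fused op replaces the op node directly.
    fused.meta = op_node.meta
    op_node.replace_all_uses_with(fused)
```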