diff --git a/backends/cadence/aot/compiler_utils.py b/backends/cadence/aot/compiler_utils.py
index cabfb120341..b55d388691f 100644
--- a/backends/cadence/aot/compiler_utils.py
+++ b/backends/cadence/aot/compiler_utils.py
@@ -201,13 +201,6 @@ def contains_node_with_matching_target(
     return any(node.target == op_target for node in nodes)
 
 
-def is_quantized_tensor(x: torch.Tensor) -> bool:
-    """
-    Return true if the tensor x is quantized
-    """
-    return x.is_quantized
-
-
 def get_scale(x: torch.Tensor) -> torch.Tensor:
     """
     Return the scale of a quantized tensor as a float32 tensor.
diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py
index dcfc5fb82e4..7f493e1645d 100644
--- a/backends/cadence/aot/replace_ops.py
+++ b/backends/cadence/aot/replace_ops.py
@@ -15,17 +15,15 @@
 import math
 import operator
 from operator import neg
-from typing import cast, Dict, Iterable, Optional, Sequence, Set, Tuple
+from typing import cast, Dict, Iterable, Optional, Sequence, Tuple
 
 import torch
 import torch.fx
 from executorch.backends.cadence.aot.compiler_utils import (
     get_shape,
     get_tensor_from_attr,
-    get_transposed_dims,
     get_zero_point,
     is_node_with_op,
-    is_quantized_tensor,
     quantize_tensor_multiplier,
 )
 from executorch.backends.cadence.aot.fuse_ops import (
@@ -772,186 +770,6 @@ def call_operator(self, op, args, kwargs, meta):
         return super().call_operator(target, new_args, kwargs, meta)
 
 
-# TODO(matthiascremon): this is a fuse op, not a replace op
-class ReplaceConvWithChannelLastConv:
-    """
-    Convolution op in pytorch expects NCHW layout for input, weight, and output
-    tensors. However, if the input and output to the convolution op are originally
-    in NWHC layout, and are then permuted to conform to NCHW layout, we can fuse
-    the two permute ops with the convolution op, and call the NHWC layout
-    convolution op.
-    """
-
-    def __init__(self):
-        self.counter = 0
-        self.graph_module = None
-
-    def __call__(self, graph_module: torch.fx.GraphModule):
-        self.replace_conv_with_nhwc_conv(graph_module)
-
-    def conv_layout_is_nhwc(self, node: torch.fx.Node) -> bool:
-        """
-        Return true if the convolution input and output are connected to permute
-        ops, and the input/output to/from the permute ops is NHWC layout tensor.
-        """
-        # There must only be a single user of the output node (which must be a
-        # permute/tranpsose op). The input of the convolution must be connected
-        # to a permute op, and that permute op should have a single user.
-        conv_inp = node.args[0]
-        assert isinstance(conv_inp, torch.fx.Node)
-        if len(node.users) != 1 or len(conv_inp.users) != 1:
-            return False
-
-        # Get the input and output (permute/transpose) nodes of the convolution
-        conv_user = list(node.users.keys())[0]
-        assert isinstance(conv_user, torch.fx.Node)
-        pt_nodes: Set[torch.fx.Node] = {conv_inp, conv_user}
-
-        # Any node in pt_nodes must not be a placeholder.
-        if contains_placeholder_or_param(pt_nodes):
-            return False
-
-        # Determine if the convolution is 1d or 2d. The output tensor must be
-        # 3- or 4-dimensional
-        out_shape = get_shape(self.graph_module, node)
-        assert out_shape is not None
-        out_dims = len(out_shape)
-        assert out_dims in {3, 4}, "Only supports conv1d and conv2d"
-        conv1d = out_dims == 3
-
-        # Get the possible targets for the nodes in pt_nodes. Since conv1d has
-        # 3-dimensional input and output tensors, the nodes in pt_nodes could
-        # be either permute or transpose op. For conv2d, the nodes in pt_nodes
-        # must be permute ops.
-        p_target = exir_ops.edge.aten.permute_copy.default
-        t_target = exir_ops.edge.aten.transpose_copy.int
-        pt_targets = [p_target] + ([t_target] if conv1d else [])
-
-        # If any node in pt_nodes is not permute op (or tranpose op for conv1d),
-        # bail.
-        if any(x.target not in pt_targets for x in pt_nodes):
-            return False
-
-        # Now we need to determine the dimension permutations:
-        # If the input had NHWC layout, which was then permuted/transposed
-        # by a permute/transpose op to NCHW layout, the permutation must be
-        # [0, 3, 2, 1] (or [0, 2, 1] for conv1d).
-        # If the output had NCHW layout, and was then permuted to NHWC layout,
-        # the permutation must be [0, 2, 3, 1] (or [0, 2, 1] for conv1d).
-        nhwc_permute_order = {
-            node.args[0]: [0, 2, 1] if conv1d else [0, 3, 1, 2],
-            list(node.users.keys())[0]: [0, 2, 1] if conv1d else [0, 2, 3, 1],
-        }
-        for x in pt_nodes:
-            order = (
-                x.args[1]
-                if x.target == p_target
-                else get_transposed_dims(x, list(range(out_dims)))
-            )
-            if order != nhwc_permute_order[x]:
-                return False
-
-        return True
-
-    def replace_conv_with_nhwc_conv(self, graph_module: torch.fx.GraphModule):
-        self.graph_module = graph_module
-        graph = graph_module.graph
-        for node in graph.nodes:
-            # We are only interested in convolution nodes that have NHWC layout
-            if node.target not in {
-                exir_ops.edge.cadence.quantized_conv_nchw.default,
-                exir_ops.edge.cadence.convolution.default,
-                exir_ops.edge.cadence.quantized_transposed_conv.default,
-                exir_ops.edge.cadence.transposed_convolution.default,
-            } or not self.conv_layout_is_nhwc(node):
-                continue
-
-            # Get the args of convolution op
-            args = list(node.args)
-            # The input is connected to a permute/transpose op that converts the
-            # NHWC layout to NCHW layout. The input of the permute op will become
-            # this convolution op's input.
-            in_tp = args[0]
-            args[0] = in_tp.args[0]
-            # The weight is in NHWC layout. Permute it to NHWC layout.
-            weight_tensor = get_tensor_from_attr(graph_module, args[1])
-            assert isinstance(weight_tensor, torch.Tensor)
-            # We cannot directly permute a per-channel quantized tensor. We will
-            # dequantize it, permute the fp32 tensor, and then requantize the
-            # permuted tensor.
-            if (
-                is_quantized_tensor(weight_tensor)
-                and weight_tensor.qscheme() == torch.per_channel_affine
-            ):
-                # We have already asserted during quantizing conv op that the
-                # quantization axis is 0.
-                dequant_weight = weight_tensor.dequantize()
-                dequant_weight = (
-                    dequant_weight.permute([0, 2, 1])
-                    if dequant_weight.dim() == 3
-                    else dequant_weight.permute([0, 2, 3, 1])
-                )
-                weight_tensor = torch.quantize_per_channel(
-                    dequant_weight.contiguous(),
-                    weight_tensor.q_per_channel_scales(),
-                    weight_tensor.q_per_channel_zero_points(),
-                    0,
-                    weight_tensor.dtype,
-                )
-            else:
-                weight_tensor = (
-                    weight_tensor.permute([0, 2, 1])
-                    if weight_tensor.dim() == 3
-                    else weight_tensor.permute([0, 2, 3, 1])
-                )
-            # Make the weight tensor contiguous, since we have permuted it.
-            weight_tensor = weight_tensor.contiguous()
-            # Add the permuted weight into the graph, and update the weight in
-            # args.
-            with graph.inserting_before(node):
-                weight_name = f"_weight_nhwc_{self.counter}"
-                graph_module.register_buffer(weight_name, weight_tensor)
-                weight = graph.get_attr(weight_name)
-            args[1] = weight
-
-            # The 'channel_last' arg is True. It is the last arg.
-            args[-1] = True
-            # Now update the convolution node args to mark it as NHWC convolution
-            node.args = tuple(args)
-
-            # Replace all the uses of the permute op connected to the output op
-            # with this convolution.
-            out_tp = list(node.users.keys())[0]
-            out_tp.replace_all_uses_with(node)
-            node.meta = out_tp.meta
-
-            # Erase the permute ops connected to the input and output of the
-            # convolution op.
-            graph.erase_node(in_tp)
-            graph.erase_node(out_tp)
-            self.counter += 1
-
-        graph_module.recompile()
-
-
-# This pass needs to be reworked to be compatible with PT2. It is an optimization
-# pass anyway, so move it to opt level 2.
-# TODO: T213724613 update and improve this pass.
-# @register_cadence_pass(CadencePassAttribute(opt_level=2))
-class ReplaceConvWithChannelLastConvPass(ExportPass):
-    """
-    Replace the ATen convolution op with custom conv op with NCHW or NHWC layout
-    input tensors, depending on the presence of permute/transpose ops connected
-    to the input tensor.
-    """
-
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        result = ReplaceAtenConvolutionWithCadenceConvolutionPass()(graph_module)
-        assert result is not None
-        ReplaceConvWithChannelLastConv()(result.graph_module)
-        return result
-
-
 @register_cadence_pass(CadencePassAttribute(opt_level=2))
 class ReplaceTrivialConvWithLinear(ExportPass):
     """
@@ -1131,7 +949,7 @@ def transpose_dims(
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=3))
-class ForceChannelLastForConvPass(ExportPassWithTransposeHelper):
+class ReplaceConvWithChannelLastConvPass(ExportPassWithTransposeHelper):
     def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
         shape = proxy.to_tensor().shape
         if len(shape) == 3:
@@ -2441,9 +2259,8 @@ class CadenceReplaceOpsInGraph:
         ReplaceRepeatWithCatPass,
         ReplacePadWithCatPass,
         ReplaceConstantPadNdWithSlicePass,
-        ReplaceConvWithChannelLastConvPass,
         ReplaceAtenConvolutionWithCadenceConvolutionPass,
-        ForceChannelLastForConvPass,
+        ReplaceConvWithChannelLastConvPass,
         ReplaceTrivialConvWithLinear,
         ReplaceConvWithIm2RowAndLinear,
         ReplaceTransposedConvWithLinearPass,
diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py
index 11c90492da1..bd02cb0ae11 100644
--- a/backends/cadence/aot/tests/test_replace_ops_passes.py
+++ b/backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -17,7 +17,6 @@
 )
 from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
 from executorch.backends.cadence.aot.replace_ops import (
-    ForceChannelLastForConvPass,
     MakeSliceAndCatDimOutermostPass,
     ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
     ReplaceAddMMWithLinearPass,
@@ -25,6 +24,7 @@
     ReplaceAtenConvolutionWithCadenceConvolutionPass,
     ReplaceConstantPadNdWithSlicePass,
     ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
+    ReplaceConvWithChannelLastConvPass,
     ReplaceConvWithIm2RowAndLinear,
     ReplaceEmptyTensorsWithFullPass,
     ReplaceFunctionallyEquivalentOpTargets,
@@ -1454,7 +1454,7 @@ def test_replace_linear_like_conv(self) -> None:
         )
 
 
-class TestForceChannelLastForConvPass(unittest.TestCase):
+class TestReplaceConvWithChannelLastConvPass(unittest.TestCase):
     def create_conv1d_graphmodule(
         self, channels_last: Optional[bool] = None
     ) -> torch.fx.GraphModule:
@@ -1489,7 +1489,7 @@ def test_conv1d_default_channel_last(self) -> None:
         self.assertEqual(count_node(gm, exir_ops.edge.aten.transpose_copy.int), 0)
 
         # Apply replacement pass.
-        p = ForceChannelLastForConvPass()
+        p = ReplaceConvWithChannelLastConvPass()
         gm_after_replacement = p.call(gm).graph_module
         # Check that no replacement was made.
         self.assertEqual(
@@ -1514,7 +1514,7 @@ def test_conv1d_no_transpose_if_already_channel_last(self) -> None:
         self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1)
 
         # Apply replacement pass.
-        p = ForceChannelLastForConvPass()
+        p = ReplaceConvWithChannelLastConvPass()
         gm_after_replacement = p.call(gm).graph_module
         # Check that no replacement was made.
         self.assertEqual(
@@ -1566,7 +1566,7 @@ def test_convolution_default_channel_last(self) -> None:
         self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)
 
         # Apply replacement pass.
-        p = ForceChannelLastForConvPass()
+        p = ReplaceConvWithChannelLastConvPass()
         gm_after_replacement = p.call(gm).graph_module
         # Check that no replacement was made.
         self.assertEqual(
@@ -1591,7 +1591,7 @@ def test_no_transpose_if_already_channel_last(self) -> None:
         self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1)
 
         # Apply replacement pass.
-        p = ForceChannelLastForConvPass()
+        p = ReplaceConvWithChannelLastConvPass()
         gm_after_replacement = p.call(gm).graph_module
         # Check that no replacement was made.
         self.assertEqual(
@@ -1692,7 +1692,7 @@ def test_quantized_convolution_default_channel_last(self) -> None:
         self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)
 
         # Apply replacement pass.
-        p = ForceChannelLastForConvPass()
+        p = ReplaceConvWithChannelLastConvPass()
         gm_after_replacement = p.call(gm).graph_module
         # Check that no replacement was made.
         self.assertEqual(
@@ -1717,7 +1717,7 @@ def test_no_transpose_if_already_quantized_conv_channel_last(self) -> None:
         )
 
         # Apply replacement pass.
-        p = ForceChannelLastForConvPass()
+        p = ReplaceConvWithChannelLastConvPass()
         gm_after_replacement = p.call(gm).graph_module
         # Check that no replacement was made.
         self.assertEqual(
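
After this change, the opt-level-3 ReplaceConvWithChannelLastConvPass (formerly ForceChannelLastForConvPass) is the only channel-last conversion pass left in the CadenceReplaceOpsInGraph schedule; the legacy ReplaceConvWithChannelLastConv helper and its wrapper pass are removed. For reference, a minimal sketch of how the renamed pass is driven, mirroring the updated tests; the wrapper name run_channel_last_pass, the incoming gm (a graph module that already contains cadence convolution nodes, e.g. built like create_conv1d_graphmodule), and the exir_ops import spelling are illustrative assumptions, not part of this diff:

    import torch
    import torch.fx

    from executorch.backends.cadence.aot.pass_utils import count_node
    from executorch.backends.cadence.aot.replace_ops import (
        ReplaceConvWithChannelLastConvPass,
    )
    from executorch.exir.dialects._ops import ops as exir_ops


    def run_channel_last_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # Instantiate the renamed pass and run it; call() returns a PassResult
        # whose graph_module holds the rewritten graph, as in the updated tests.
        p = ReplaceConvWithChannelLastConvPass()
        gm_after = p.call(gm).graph_module
        # The pass only rearranges layout (permute/transpose) handling around
        # convolutions, so the convolution op count should be unchanged.
        assert count_node(
            gm_after, exir_ops.edge.cadence.convolution.default
        ) == count_node(gm, exir_ops.edge.cadence.convolution.default)
        return gm_after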