 import math
 import operator
 from operator import neg
-from typing import cast, Dict, Iterable, Optional, Sequence, Set, Tuple
+from typing import cast, Dict, Iterable, Optional, Sequence, Tuple
 
 import torch
 import torch.fx
 from executorch.backends.cadence.aot.compiler_utils import (
     get_shape,
     get_tensor_from_attr,
-    get_transposed_dims,
     get_zero_point,
     is_node_with_op,
-    is_quantized_tensor,
     quantize_tensor_multiplier,
 )
 from executorch.backends.cadence.aot.fuse_ops import (
@@ -772,186 +770,6 @@ def call_operator(self, op, args, kwargs, meta):
         return super().call_operator(target, new_args, kwargs, meta)
 
 
-# TODO(matthiascremon): this is a fuse op, not a replace op
-class ReplaceConvWithChannelLastConv:
-    """
-    Convolution op in pytorch expects NCHW layout for input, weight, and output
-    tensors. However, if the input and output to the convolution op are originally
-    in NWHC layout, and are then permuted to conform to NCHW layout, we can fuse
-    the two permute ops with the convolution op, and call the NHWC layout
-    convolution op.
-    """
-
-    def __init__(self):
-        self.counter = 0
-        self.graph_module = None
-
-    def __call__(self, graph_module: torch.fx.GraphModule):
-        self.replace_conv_with_nhwc_conv(graph_module)
-
-    def conv_layout_is_nhwc(self, node: torch.fx.Node) -> bool:
-        """
-        Return true if the convolution input and output are connected to permute
-        ops, and the input/output to/from the permute ops is NHWC layout tensor.
-        """
-        # There must only be a single user of the output node (which must be a
-        # permute/tranpsose op). The input of the convolution must be connected
-        # to a permute op, and that permute op should have a single user.
-        conv_inp = node.args[0]
-        assert isinstance(conv_inp, torch.fx.Node)
-        if len(node.users) != 1 or len(conv_inp.users) != 1:
-            return False
-
-        # Get the input and output (permute/transpose) nodes of the convolution
-        conv_user = list(node.users.keys())[0]
-        assert isinstance(conv_user, torch.fx.Node)
-        pt_nodes: Set[torch.fx.Node] = {conv_inp, conv_user}
-
-        # Any node in pt_nodes must not be a placeholder.
-        if contains_placeholder_or_param(pt_nodes):
-            return False
-
-        # Determine if the convolution is 1d or 2d. The output tensor must be
-        # 3- or 4-dimensional
-        out_shape = get_shape(self.graph_module, node)
-        assert out_shape is not None
-        out_dims = len(out_shape)
-        assert out_dims in {3, 4}, "Only supports conv1d and conv2d"
-        conv1d = out_dims == 3
-
-        # Get the possible targets for the nodes in pt_nodes. Since conv1d has
-        # 3-dimensional input and output tensors, the nodes in pt_nodes could
-        # be either permute or transpose op. For conv2d, the nodes in pt_nodes
-        # must be permute ops.
-        p_target = exir_ops.edge.aten.permute_copy.default
-        t_target = exir_ops.edge.aten.transpose_copy.int
-        pt_targets = [p_target] + ([t_target] if conv1d else [])
-
-        # If any node in pt_nodes is not permute op (or tranpose op for conv1d),
-        # bail.
-        if any(x.target not in pt_targets for x in pt_nodes):
-            return False
-
-        # Now we need to determine the dimension permutations:
-        # If the input had NHWC layout, which was then permuted/transposed
-        # by a permute/transpose op to NCHW layout, the permutation must be
-        # [0, 3, 2, 1] (or [0, 2, 1] for conv1d).
-        # If the output had NCHW layout, and was then permuted to NHWC layout,
-        # the permutation must be [0, 2, 3, 1] (or [0, 2, 1] for conv1d).
-        nhwc_permute_order = {
-            node.args[0]: [0, 2, 1] if conv1d else [0, 3, 1, 2],
-            list(node.users.keys())[0]: [0, 2, 1] if conv1d else [0, 2, 3, 1],
-        }
-        for x in pt_nodes:
-            order = (
-                x.args[1]
-                if x.target == p_target
-                else get_transposed_dims(x, list(range(out_dims)))
-            )
-            if order != nhwc_permute_order[x]:
-                return False
-
-        return True
-
-    def replace_conv_with_nhwc_conv(self, graph_module: torch.fx.GraphModule):
-        self.graph_module = graph_module
-        graph = graph_module.graph
-        for node in graph.nodes:
-            # We are only interested in convolution nodes that have NHWC layout
-            if node.target not in {
-                exir_ops.edge.cadence.quantized_conv_nchw.default,
-                exir_ops.edge.cadence.convolution.default,
-                exir_ops.edge.cadence.quantized_transposed_conv.default,
-                exir_ops.edge.cadence.transposed_convolution.default,
-            } or not self.conv_layout_is_nhwc(node):
-                continue
-
-            # Get the args of convolution op
-            args = list(node.args)
-            # The input is connected to a permute/transpose op that converts the
-            # NHWC layout to NCHW layout. The input of the permute op will become
-            # this convolution op's input.
-            in_tp = args[0]
-            args[0] = in_tp.args[0]
-            # The weight is in NHWC layout. Permute it to NHWC layout.
-            weight_tensor = get_tensor_from_attr(graph_module, args[1])
-            assert isinstance(weight_tensor, torch.Tensor)
-            # We cannot directly permute a per-channel quantized tensor. We will
-            # dequantize it, permute the fp32 tensor, and then requantize the
-            # permuted tensor.
-            if (
-                is_quantized_tensor(weight_tensor)
-                and weight_tensor.qscheme() == torch.per_channel_affine
-            ):
-                # We have already asserted during quantizing conv op that the
-                # quantization axis is 0.
-                dequant_weight = weight_tensor.dequantize()
-                dequant_weight = (
-                    dequant_weight.permute([0, 2, 1])
-                    if dequant_weight.dim() == 3
-                    else dequant_weight.permute([0, 2, 3, 1])
-                )
-                weight_tensor = torch.quantize_per_channel(
-                    dequant_weight.contiguous(),
-                    weight_tensor.q_per_channel_scales(),
-                    weight_tensor.q_per_channel_zero_points(),
-                    0,
-                    weight_tensor.dtype,
-                )
-            else:
-                weight_tensor = (
-                    weight_tensor.permute([0, 2, 1])
-                    if weight_tensor.dim() == 3
-                    else weight_tensor.permute([0, 2, 3, 1])
-                )
-            # Make the weight tensor contiguous, since we have permuted it.
-            weight_tensor = weight_tensor.contiguous()
-            # Add the permuted weight into the graph, and update the weight in
-            # args.
-            with graph.inserting_before(node):
-                weight_name = f"_weight_nhwc_{self.counter}"
-                graph_module.register_buffer(weight_name, weight_tensor)
-                weight = graph.get_attr(weight_name)
-            args[1] = weight
-
-            # The 'channel_last' arg is True. It is the last arg.
-            args[-1] = True
-            # Now update the convolution node args to mark it as NHWC convolution
-            node.args = tuple(args)
-
-            # Replace all the uses of the permute op connected to the output op
-            # with this convolution.
-            out_tp = list(node.users.keys())[0]
-            out_tp.replace_all_uses_with(node)
-            node.meta = out_tp.meta
-
-            # Erase the permute ops connected to the input and output of the
-            # convolution op.
-            graph.erase_node(in_tp)
-            graph.erase_node(out_tp)
-            self.counter += 1
-
-        graph_module.recompile()
-
-
-# This pass needs to be reworked to be compatible with PT2. It is an optimization
-# pass anyway, so move it to opt level 2.
-# TODO: T213724613 update and improve this pass.
-# @register_cadence_pass(CadencePassAttribute(opt_level=2))
-class ReplaceConvWithChannelLastConvPass(ExportPass):
-    """
-    Replace the ATen convolution op with custom conv op with NCHW or NHWC layout
-    input tensors, depending on the presence of permute/transpose ops connected
-    to the input tensor.
-    """
-
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        result = ReplaceAtenConvolutionWithCadenceConvolutionPass()(graph_module)
-        assert result is not None
-        ReplaceConvWithChannelLastConv()(result.graph_module)
-        return result
-
-
 @register_cadence_pass(CadencePassAttribute(opt_level=2))
 class ReplaceTrivialConvWithLinear(ExportPass):
     """
@@ -1131,7 +949,7 @@ def transpose_dims(
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=3))
-class ForceChannelLastForConvPass(ExportPassWithTransposeHelper):
+class ReplaceConvWithChannelLastConvPass(ExportPassWithTransposeHelper):
     def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
         shape = proxy.to_tensor().shape
         if len(shape) == 3:
@@ -2441,9 +2259,8 @@ class CadenceReplaceOpsInGraph:
         ReplaceRepeatWithCatPass,
         ReplacePadWithCatPass,
         ReplaceConstantPadNdWithSlicePass,
-        ReplaceConvWithChannelLastConvPass,
         ReplaceAtenConvolutionWithCadenceConvolutionPass,
-        ForceChannelLastForConvPass,
+        ReplaceConvWithChannelLastConvPass,
         ReplaceTrivialConvWithLinear,
         ReplaceConvWithIm2RowAndLinear,
         ReplaceTransposedConvWithLinearPass,
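
Note on the removed matcher: `conv_layout_is_nhwc` only fired when the convolution's input came from an NHWC-to-NCHW permute and its single user was an NCHW-to-NHWC permute (or the equivalent transpose for conv1d). The following standalone sketch is illustrative only, not part of this patch; the tensor shapes are made up. It shows the permute orders that check looked for, using plain torch.permute:

import torch

# Input side: NHWC -> NCHW before the convolution uses order [0, 3, 1, 2].
x_nhwc = torch.randn(1, 8, 8, 3)        # N, H, W, C
x_nchw = x_nhwc.permute([0, 3, 1, 2])   # -> (1, 3, 8, 8)

# Output side: NCHW -> NHWC after the convolution uses order [0, 2, 3, 1].
y_nchw = torch.randn(1, 16, 8, 8)       # N, C, H, W
y_nhwc = y_nchw.permute([0, 2, 3, 1])   # -> (1, 8, 8, 16)

# conv1d works on 3-D tensors and uses [0, 2, 1] in both directions.
z_ncl = torch.randn(1, 16, 100)         # N, C, L
z_nlc = z_ncl.permute([0, 2, 1])        # -> (1, 100, 16)

assert x_nchw.shape == (1, 3, 8, 8)
assert y_nhwc.shape == (1, 8, 8, 16)
assert z_nlc.shape == (1, 100, 16)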
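Similarly, the removed `replace_conv_with_nhwc_conv` could not reorder a per-channel quantized weight directly, so it dequantized the tensor, permuted the float values, and requantized along the same axis (axis 0, the output channels). A minimal standalone sketch of that round trip, assuming stock PyTorch quantization APIs and made-up shapes and scales:

import torch

# Per-channel quantized conv weight in OIHW layout (hypothetical values).
w = torch.randn(16, 3, 5, 5)
scales = torch.full((16,), 0.05, dtype=torch.float64)
zero_points = torch.zeros(16, dtype=torch.int64)
w_q = torch.quantize_per_channel(w, scales, zero_points, 0, torch.qint8)
assert w_q.qscheme() == torch.per_channel_affine

# Dequantize, permute OIHW -> OHWI, then requantize along axis 0 with the
# original per-channel scales and zero points.
w_dq = w_q.dequantize().permute([0, 2, 3, 1]).contiguous()
w_q_channel_last = torch.quantize_per_channel(
    w_dq,
    w_q.q_per_channel_scales(),
    w_q.q_per_channel_zero_points(),
    0,
    w_q.dtype,
)
assert w_q_channel_last.shape == (16, 5, 5, 3)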