|
15 | 15 | import math |
16 | 16 | import operator |
17 | 17 | from operator import neg |
18 | | -from typing import cast, Dict, Iterable, Optional, Sequence, Set, Tuple |
| 18 | +from typing import cast, Dict, Iterable, Optional, Sequence, Tuple |
19 | 19 |
|
20 | 20 | import torch |
21 | 21 | import torch.fx |
22 | 22 | from executorch.backends.cadence.aot.compiler_utils import ( |
23 | 23 | get_shape, |
24 | 24 | get_tensor_from_attr, |
25 | | - get_transposed_dims, |
26 | 25 | get_zero_point, |
27 | 26 | is_node_with_op, |
28 | | - is_quantized_tensor, |
29 | 27 | quantize_tensor_multiplier, |
30 | 28 | ) |
31 | 29 | from executorch.backends.cadence.aot.fuse_ops import ( |
@@ -772,186 +770,6 @@ def call_operator(self, op, args, kwargs, meta): |
772 | 770 | return super().call_operator(target, new_args, kwargs, meta) |
773 | 771 |
|
774 | 772 |
|
775 | | -# TODO(matthiascremon): this is a fuse op, not a replace op |
776 | | -class ReplaceConvWithChannelLastConv: |
777 | | - """ |
778 | | - Convolution op in pytorch expects NCHW layout for input, weight, and output |
779 | | - tensors. However, if the input and output to the convolution op are originally |
780 | | - in NHWC layout, and are then permuted to conform to NCHW layout, we can fuse |
781 | | - the two permute ops with the convolution op, and call the NHWC layout |
782 | | - convolution op. |
783 | | - """ |
784 | | - |
785 | | - def __init__(self): |
786 | | - self.counter = 0 |
787 | | - self.graph_module = None |
788 | | - |
789 | | - def __call__(self, graph_module: torch.fx.GraphModule): |
790 | | - self.replace_conv_with_nhwc_conv(graph_module) |
791 | | - |
792 | | - def conv_layout_is_nhwc(self, node: torch.fx.Node) -> bool: |
793 | | - """ |
794 | | - Return true if the convolution input and output are connected to permute |
795 | | - ops, and the input/output to/from the permute ops is NHWC layout tensor. |
796 | | - """ |
797 | | - # There must only be a single user of the output node (which must be a |
798 | | - # permute/transpose op). The input of the convolution must be connected |
799 | | - # to a permute op, and that permute op should have a single user. |
800 | | - conv_inp = node.args[0] |
801 | | - assert isinstance(conv_inp, torch.fx.Node) |
802 | | - if len(node.users) != 1 or len(conv_inp.users) != 1: |
803 | | - return False |
804 | | - |
805 | | - # Get the input and output (permute/transpose) nodes of the convolution |
806 | | - conv_user = list(node.users.keys())[0] |
807 | | - assert isinstance(conv_user, torch.fx.Node) |
808 | | - pt_nodes: Set[torch.fx.Node] = {conv_inp, conv_user} |
809 | | - |
810 | | - # Bail if any node in pt_nodes is a placeholder or a param. |
811 | | - if contains_placeholder_or_param(pt_nodes): |
812 | | - return False |
813 | | - |
814 | | - # Determine if the convolution is 1d or 2d. The output tensor must be |
815 | | - # 3- or 4-dimensional |
816 | | - out_shape = get_shape(self.graph_module, node) |
817 | | - assert out_shape is not None |
818 | | - out_dims = len(out_shape) |
819 | | - assert out_dims in {3, 4}, "Only supports conv1d and conv2d" |
820 | | - conv1d = out_dims == 3 |
821 | | - |
822 | | - # Get the possible targets for the nodes in pt_nodes. Since conv1d has |
823 | | - # 3-dimensional input and output tensors, the nodes in pt_nodes could |
824 | | - # be either permute or transpose op. For conv2d, the nodes in pt_nodes |
825 | | - # must be permute ops. |
826 | | - p_target = exir_ops.edge.aten.permute_copy.default |
827 | | - t_target = exir_ops.edge.aten.transpose_copy.int |
828 | | - pt_targets = [p_target] + ([t_target] if conv1d else []) |
829 | | - |
830 | | - # If any node in pt_nodes is not a permute op (or transpose op for |
831 | | - # conv1d), bail. |
832 | | - if any(x.target not in pt_targets for x in pt_nodes): |
833 | | - return False |
834 | | - |
835 | | - # Now we need to determine the dimension permutations: |
836 | | - # If the input had NHWC layout, which was then permuted/transposed |
837 | | - # by a permute/transpose op to NCHW layout, the permutation must be |
838 | | - # [0, 3, 1, 2] (or [0, 2, 1] for conv1d). |
839 | | - # If the output had NCHW layout, and was then permuted to NHWC layout, |
840 | | - # the permutation must be [0, 2, 3, 1] (or [0, 2, 1] for conv1d). |
841 | | - nhwc_permute_order = { |
842 | | - node.args[0]: [0, 2, 1] if conv1d else [0, 3, 1, 2], |
843 | | - list(node.users.keys())[0]: [0, 2, 1] if conv1d else [0, 2, 3, 1], |
844 | | - } |
845 | | - for x in pt_nodes: |
846 | | - order = ( |
847 | | - x.args[1] |
848 | | - if x.target == p_target |
849 | | - else get_transposed_dims(x, list(range(out_dims))) |
850 | | - ) |
851 | | - if order != nhwc_permute_order[x]: |
852 | | - return False |
853 | | - |
854 | | - return True |
855 | | - |
856 | | - def replace_conv_with_nhwc_conv(self, graph_module: torch.fx.GraphModule): |
857 | | - self.graph_module = graph_module |
858 | | - graph = graph_module.graph |
859 | | - for node in graph.nodes: |
860 | | - # We are only interested in convolution nodes that have NHWC layout |
861 | | - if node.target not in { |
862 | | - exir_ops.edge.cadence.quantized_conv_nchw.default, |
863 | | - exir_ops.edge.cadence.convolution.default, |
864 | | - exir_ops.edge.cadence.quantized_transposed_conv.default, |
865 | | - exir_ops.edge.cadence.transposed_convolution.default, |
866 | | - } or not self.conv_layout_is_nhwc(node): |
867 | | - continue |
868 | | - |
869 | | - # Get the args of convolution op |
870 | | - args = list(node.args) |
871 | | - # The input is connected to a permute/transpose op that converts the |
872 | | - # NHWC layout to NCHW layout. The input of the permute op will become |
873 | | - # this convolution op's input. |
874 | | - in_tp = args[0] |
875 | | - args[0] = in_tp.args[0] |
876 | | - # The weight is in NCHW layout. Permute it to NHWC layout. |
877 | | - weight_tensor = get_tensor_from_attr(graph_module, args[1]) |
878 | | - assert isinstance(weight_tensor, torch.Tensor) |
879 | | - # We cannot directly permute a per-channel quantized tensor. We will |
880 | | - # dequantize it, permute the fp32 tensor, and then requantize the |
881 | | - # permuted tensor. |
882 | | - if ( |
883 | | - is_quantized_tensor(weight_tensor) |
884 | | - and weight_tensor.qscheme() == torch.per_channel_affine |
885 | | - ): |
886 | | - # We have already asserted during quantizing conv op that the |
887 | | - # quantization axis is 0. |
888 | | - dequant_weight = weight_tensor.dequantize() |
889 | | - dequant_weight = ( |
890 | | - dequant_weight.permute([0, 2, 1]) |
891 | | - if dequant_weight.dim() == 3 |
892 | | - else dequant_weight.permute([0, 2, 3, 1]) |
893 | | - ) |
894 | | - weight_tensor = torch.quantize_per_channel( |
895 | | - dequant_weight.contiguous(), |
896 | | - weight_tensor.q_per_channel_scales(), |
897 | | - weight_tensor.q_per_channel_zero_points(), |
898 | | - 0, |
899 | | - weight_tensor.dtype, |
900 | | - ) |
901 | | - else: |
902 | | - weight_tensor = ( |
903 | | - weight_tensor.permute([0, 2, 1]) |
904 | | - if weight_tensor.dim() == 3 |
905 | | - else weight_tensor.permute([0, 2, 3, 1]) |
906 | | - ) |
907 | | - # Make the weight tensor contiguous, since we have permuted it. |
908 | | - weight_tensor = weight_tensor.contiguous() |
909 | | - # Add the permuted weight into the graph, and update the weight in |
910 | | - # args. |
911 | | - with graph.inserting_before(node): |
912 | | - weight_name = f"_weight_nhwc_{self.counter}" |
913 | | - graph_module.register_buffer(weight_name, weight_tensor) |
914 | | - weight = graph.get_attr(weight_name) |
915 | | - args[1] = weight |
916 | | - |
917 | | - # Set the 'channel_last' arg (the last arg) to True. |
918 | | - args[-1] = True |
919 | | - # Now update the convolution node args to mark it as NHWC convolution |
920 | | - node.args = tuple(args) |
921 | | - |
922 | | - # Replace all the uses of the permute op connected to the convolution's |
923 | | - # output with this convolution node. |
924 | | - out_tp = list(node.users.keys())[0] |
925 | | - out_tp.replace_all_uses_with(node) |
926 | | - node.meta = out_tp.meta |
927 | | - |
928 | | - # Erase the permute ops connected to the input and output of the |
929 | | - # convolution op. |
930 | | - graph.erase_node(in_tp) |
931 | | - graph.erase_node(out_tp) |
932 | | - self.counter += 1 |
933 | | - |
934 | | - graph_module.recompile() |
935 | | - |
936 | | - |
937 | | -# This pass needs to be reworked to be compatible with PT2. It is an optimization |
938 | | -# pass anyway, so move it to opt level 2. |
939 | | -# TODO: T213724613 update and improve this pass. |
940 | | -# @register_cadence_pass(CadencePassAttribute(opt_level=2)) |
941 | | -class ReplaceConvWithChannelLastConvPass(ExportPass): |
942 | | - """ |
943 | | - Replace the ATen convolution op with a custom conv op with NCHW or NHWC layout |
944 | | - input tensors, depending on the presence of permute/transpose ops connected |
945 | | - to the input tensor. |
946 | | - """ |
947 | | - |
948 | | - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: |
949 | | - result = ReplaceAtenConvolutionWithCadenceConvolutionPass()(graph_module) |
950 | | - assert result is not None |
951 | | - ReplaceConvWithChannelLastConv()(result.graph_module) |
952 | | - return result |
953 | | - |
954 | | - |
955 | 773 | @register_cadence_pass(CadencePassAttribute(opt_level=2)) |
956 | 774 | class ReplaceTrivialConvWithLinear(ExportPass): |
957 | 775 | """ |
@@ -1131,7 +949,7 @@ def transpose_dims( |
1131 | 949 |
|
1132 | 950 |
|
1133 | 951 | @register_cadence_pass(CadencePassAttribute(opt_level=3)) |
1134 | | -class ForceChannelLastForConvPass(ExportPassWithTransposeHelper): |
| 952 | +class ReplaceConvWithChannelLastConvPass(ExportPassWithTransposeHelper): |
1135 | 953 | def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue: |
1136 | 954 | shape = proxy.to_tensor().shape |
1137 | 955 | if len(shape) == 3: |
@@ -2441,9 +2259,8 @@ class CadenceReplaceOpsInGraph: |
2441 | 2259 | ReplaceRepeatWithCatPass, |
2442 | 2260 | ReplacePadWithCatPass, |
2443 | 2261 | ReplaceConstantPadNdWithSlicePass, |
2444 | | - ReplaceConvWithChannelLastConvPass, |
2445 | 2262 | ReplaceAtenConvolutionWithCadenceConvolutionPass, |
2446 | | - ForceChannelLastForConvPass, |
| 2263 | + ReplaceConvWithChannelLastConvPass, |
2447 | 2264 | ReplaceTrivialConvWithLinear, |
2448 | 2265 | ReplaceConvWithIm2RowAndLinear, |
2449 | 2266 | ReplaceTransposedConvWithLinearPass, |
|
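For context, a rough sketch of the NHWC layout pattern that the removed conv_layout_is_nhwc check matched: an NHWC input permuted to NCHW with [0, 3, 1, 2], convolved, and permuted back to NHWC with [0, 2, 3, 1]. The snippet assumes standard torch.nn.functional.conv2d semantics; tensor names and shapes here are invented for illustration and are not part of the diff.

import torch
import torch.nn.functional as F

# Hypothetical NHWC input and an (out_ch, in_ch, kH, kW) weight; shapes chosen only for illustration.
x_nhwc = torch.randn(1, 8, 8, 3)      # N, H, W, C
w = torch.randn(4, 3, 3, 3)           # out_channels, in_channels, kH, kW

# The pattern matched by the pass: permute NHWC -> NCHW, convolve, permute back.
x_nchw = x_nhwc.permute(0, 3, 1, 2)   # the [0, 3, 1, 2] order checked on the conv input
y_nchw = F.conv2d(x_nchw, w)
y_nhwc = y_nchw.permute(0, 2, 3, 1)   # the [0, 2, 3, 1] order checked on the conv output

# The fused channel-last form drops both permutes and runs the convolution directly
# on the NHWC tensor, with the weight permuted to [0, 2, 3, 1] as in the pass above.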