40 | 40 | ) |
41 | 41 | from executorch.exir.dialects._ops import ops as exir_ops |
42 | 42 | from executorch.exir.dialects.edge._ops import EdgeOpOverload |
43 | | -from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue |
44 | | -from torch.fx.node import Argument |
| 43 | +from executorch.exir.pass_base import ExportPass, PassResult |
45 | 44 |
|
46 | 45 | # A map to represent ops that: |
47 | 46 | # (a) are functionally equivalent; and |
@@ -1020,131 +1019,197 @@ def canonicalize_transposed_dim(dim: int, shape: Sequence[int]) -> int: |
1020 | 1019 | return dim |
1021 | 1020 |
|
1022 | 1021 |
|
1023 | | -class ExportPassWithTransposeHelper(ExportPass): |
1024 | | - def transpose_dims( |
1025 | | - self: ExportPass, proxy: ProxyValue, meta: NodeMetadata, dim0: int, dim1: int |
1026 | | - ) -> ProxyValue: |
1027 | | - """Helper function to transpose dims of a `proxy` with given `meta`.""" |
1028 | | - shape = proxy.data.shape |
| 1022 | +@register_cadence_pass(CadencePassAttribute(opt_level=3)) |
| 1023 | +class ReplaceConvWithChannelLastConvPass(RemoveOrReplacePassInterface): |
| 1024 | + """ |
| 1025 | + Replace NCHW convolutions with NHWC (channel-last) convolutions by adding |
| 1026 | + transpose operations before and after the convolution. |
| 1027 | + """ |
| 1028 | + |
| 1029 | + @property |
| 1030 | + def targets(self) -> list[EdgeOpOverload]: |
| 1031 | + return [ |
| 1032 | + exir_ops.edge.cadence.conv1d.default, |
| 1033 | + exir_ops.edge.cadence.conv2d.default, |
| 1034 | + exir_ops.edge.cadence.conv3d.default, |
| 1035 | + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, |
| 1036 | + ] |
| 1037 | + |
| 1038 | + def _transpose_dims( |
| 1039 | + self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int |
| 1040 | + ) -> torch.fx.Node: |
| 1041 | + """Helper function to transpose dims of a node.""" |
| 1042 | + shape = node.meta["val"].shape |
1029 | 1043 | dim0, dim1 = ( |
1030 | 1044 | canonicalize_transposed_dim(dim0, shape), |
1031 | 1045 | canonicalize_transposed_dim(dim1, shape), |
1032 | 1046 | ) |
1033 | 1047 | dim0, dim1 = min(dim0, dim1), max(dim0, dim1) |
1034 | | - return super().call_operator( |
1035 | | - exir_ops.edge.aten.transpose_copy.int, (proxy, dim0, dim1), {}, meta |
| 1048 | + transpose_node = graph.call_function( |
| 1049 | + exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {} |
1036 | 1050 | ) |
1037 | | - |
1038 | | - |
1039 | | -@register_cadence_pass(CadencePassAttribute(opt_level=3)) |
1040 | | -class ReplaceConvWithChannelLastConvPass(ExportPassWithTransposeHelper): |
1041 | | - def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue: |
1042 | | - shape = proxy.to_tensor().shape |
| 1051 | + transpose_node.meta = node.meta |
| 1052 | + return transpose_node |
| 1053 | + |
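Note on the canonicalization step above: canonicalize_transposed_dim is expected to map a possibly negative dim into [0, rank), after which the pair is ordered so that dim0 < dim1. A tiny standalone illustration; the modulo form is an assumption, since the function body is outside this hunk:

    # Hypothetical equivalent of the canonicalization above (assumption:
    # canonicalize_transposed_dim maps a negative dim into [0, rank)).
    shape = (2, 3, 4, 5)
    dim0, dim1 = -1, 1
    dim0, dim1 = dim0 % len(shape), dim1 % len(shape)  # -> (3, 1)
    dim0, dim1 = min(dim0, dim1), max(dim0, dim1)      # -> (1, 3)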
| 1054 | + def _change_nchw_to_nhwc( |
| 1055 | + self, graph: torch.fx.Graph, node: torch.fx.Node |
| 1056 | + ) -> torch.fx.Node: |
| 1057 | + """Convert NCHW format to NHWC format.""" |
| 1058 | + shape = node.meta["val"].shape |
1043 | 1059 | if len(shape) == 3: |
1044 | | - return self.transpose_dims(proxy, meta, 1, -1) |
| 1060 | + return self._transpose_dims(graph, node, 1, -1) |
1045 | 1061 | indices = list(range(len(shape))) |
1046 | 1062 | permute_indices = [indices[0]] + indices[2:] + [indices[1]] |
1047 | | - return super().call_operator( |
1048 | | - exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta |
| 1063 | + permute_node = graph.call_function( |
| 1064 | + exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {} |
1049 | 1065 | ) |
1050 | | - |
1051 | | - def change_nhwc_to_nchw(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue: |
1052 | | - shape = proxy.to_tensor().shape |
| 1066 | + permute_node.meta = node.meta |
| 1067 | + return permute_node |
| 1068 | + |
| 1069 | + def _change_nhwc_to_nchw( |
| 1070 | + self, graph: torch.fx.Graph, node: torch.fx.Node |
| 1071 | + ) -> torch.fx.Node: |
| 1072 | + """Convert NHWC format to NCHW format.""" |
| 1073 | + shape = node.meta["val"].shape |
1053 | 1074 | if len(shape) == 3: |
1054 | | - return self.transpose_dims(proxy, meta, 1, -1) |
| 1075 | + return self._transpose_dims(graph, node, 1, -1) |
1055 | 1076 | indices = list(range(len(shape))) |
1056 | 1077 | permute_indices = [indices[0], indices[-1]] + indices[1:-1] |
1057 | | - return super().call_operator( |
1058 | | - exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta |
| 1078 | + permute_node = graph.call_function( |
| 1079 | + exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {} |
1059 | 1080 | ) |
| 1081 | + permute_node.meta = node.meta |
| 1082 | + return permute_node |
1060 | 1083 |
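To make the index arithmetic in the two helpers above concrete, here is a standalone sketch (not part of the diff) checking the NCHW-to-NHWC permutation and its inverse on a rank-4 tensor:

    import torch

    # Same index construction as _change_nchw_to_nhwc / _change_nhwc_to_nchw.
    indices = list(range(4))                             # NCHW: [0, 1, 2, 3]
    to_nhwc = [indices[0]] + indices[2:] + [indices[1]]  # [0, 2, 3, 1]
    to_nchw = [indices[0], indices[-1]] + indices[1:-1]  # [0, 3, 1, 2]

    x = torch.randn(1, 8, 4, 6)                          # N=1, C=8, H=4, W=6
    assert x.permute(*to_nhwc).shape == (1, 4, 6, 8)     # NHWC
    assert x.permute(*to_nhwc).permute(*to_nchw).shape == x.shape  # round trip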
|
1061 | | - def call_operator( |
1062 | | - self, |
1063 | | - op, |
1064 | | - args: tuple[Argument, ...], |
1065 | | - kwargs: dict[str, Argument], |
1066 | | - meta: NodeMetadata, |
1067 | | - ) -> ProxyValue: |
1068 | | - if op not in { |
1069 | | - exir_ops.edge.cadence.conv1d.default, |
1070 | | - exir_ops.edge.cadence.conv2d.default, |
1071 | | - exir_ops.edge.cadence.conv3d.default, |
1072 | | - exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, |
1073 | | - }: |
1074 | | - return super().call_operator(op, args, kwargs, meta) |
1075 | | - |
1076 | | - quantized_op = op == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor |
| 1084 | + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: |
| 1085 | + assert isinstance(node.target, EdgeOpOverload) |
| 1086 | + quantized_op = node.target == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor |
1077 | 1087 |
|
1078 | | - if not quantized_op and len(args) == 8 and args[-1] is True: |
1079 | | - # Already in NHWC layout. |
1080 | | - return super().call_operator(op, args, kwargs, meta) |
| 1088 | +        # Already in NHWC layout (optional channel_last arg is True): nothing to do |
| 1089 | + if not quantized_op and len(node.args) == 8 and node.args[-1] is True: |
| 1090 | + return False |
1081 | 1091 |
|
| 1092 | + # Determine the new op target |
1082 | 1093 | if quantized_op: |
1083 | 1094 | new_op = exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor |
1084 | 1095 | else: |
1085 | | - # Determine if 1D or 2D convolution based on op |
1086 | | - new_op = op |
| 1096 | + new_op = node.target |
1087 | 1097 |
|
1088 | | - input_proxy = cast(ProxyValue, args[0]) |
1089 | | - weight_proxy = cast(ProxyValue, args[1]) |
1090 | | - input_proxy = self.change_nchw_to_nhwc(input_proxy, meta) |
1091 | | - weight_proxy = self.change_nchw_to_nhwc(weight_proxy, meta) |
| 1098 | + graph = node.graph |
1092 | 1099 |
|
1093 | | - # Non-quantized ops still need to set the last optional argument to True. |
1094 | | - channel_last_arg = [] if quantized_op else [True] |
| 1100 | + # Get input and weight nodes |
| 1101 | + input_node = cast(torch.fx.Node, node.args[0]) |
| 1102 | + weight_node = cast(torch.fx.Node, node.args[1]) |
1095 | 1103 |
|
1096 | | - new_args = ( |
1097 | | - # Transposed input/weights. |
1098 | | - (input_proxy, weight_proxy) |
1099 | | - # All other args (bias, quant params, etc) |
1100 | | - + tuple(args[2:]) |
1101 | | - + tuple(channel_last_arg) |
1102 | | - ) |
1103 | | - output_proxy = super().call_operator(new_op, new_args, kwargs, meta) |
1104 | | - nchw_proxy = self.change_nhwc_to_nchw(output_proxy, meta) |
1105 | | - return nchw_proxy |
| 1104 | + # Insert transpose operations before the node |
| 1105 | + with graph.inserting_before(node): |
| 1106 | + # Convert input from NCHW to NHWC |
| 1107 | + input_nhwc = self._change_nchw_to_nhwc(graph, input_node) |
| 1108 | + # Convert weight from NCHW to NHWC |
| 1109 | + weight_nhwc = self._change_nchw_to_nhwc(graph, weight_node) |
| 1110 | + |
| 1111 | + # Non-quantized ops need to set the last optional argument to True |
| 1112 | + channel_last_arg = [] if quantized_op else [True] |
| 1113 | + |
| 1114 | + # Create new args with transposed input/weights |
| 1115 | + new_args = ( |
| 1116 | + (input_nhwc, weight_nhwc) |
| 1117 | + + tuple(node.args[2:]) |
| 1118 | + + tuple(channel_last_arg) |
| 1119 | + ) |
| 1120 | + |
| 1121 | + # Create the new conv operation |
| 1122 | + new_conv = graph.call_function(new_op, new_args, node.kwargs) |
| 1123 | + new_conv.meta = node.meta |
| 1124 | + |
| 1125 | + # Convert output back from NHWC to NCHW |
| 1126 | + nchw_output = self._change_nhwc_to_nchw(graph, new_conv) |
| 1127 | + |
| 1128 | + # Replace all uses with the final output |
| 1129 | + node.replace_all_uses_with(nchw_output) |
| 1130 | + return True |
1106 | 1131 |
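RemoveOrReplacePassInterface itself is not part of this hunk; the contract the rewritten pass relies on can be inferred from its use of the targets property and maybe_remove_or_replace. A minimal sketch of that assumed contract follows (the class body and driver loop are assumptions inferred from usage, not the actual executorch implementation):

    from abc import ABC, abstractmethod

    import torch
    from executorch.exir.dialects.edge._ops import EdgeOpOverload
    from executorch.exir.pass_base import PassResult


    class RemoveOrReplacePassInterface(ABC):
        @property
        @abstractmethod
        def targets(self) -> list[EdgeOpOverload]:
            """Ops this pass may remove or replace."""

        @abstractmethod
        def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
            """Rewrite `node`; return True iff the graph was changed."""

        def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
            modified = False
            for node in list(graph_module.graph.nodes):
                if node.op == "call_function" and node.target in self.targets:
                    if self.maybe_remove_or_replace(node):
                        # The rewrite re-pointed all users, so the original
                        # node is dead and can be erased.
                        graph_module.graph.erase_node(node)
                        modified = True
            if modified:
                graph_module.graph.eliminate_dead_code()
                graph_module.recompile()
            return PassResult(graph_module, modified)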
|
1107 | 1132 |
|
1108 | 1133 | @register_cadence_pass(CadencePassAttribute(opt_level=3)) |
1109 | | -class MakeSliceAndCatDimOutermostPass(ExportPassWithTransposeHelper): |
1110 | | - def call_operator( |
1111 | | - self, |
1112 | | - op, |
1113 | | - args: tuple[Argument, ...], |
1114 | | - kwargs: dict[str, Argument], |
1115 | | - meta: NodeMetadata, |
1116 | | - ) -> ProxyValue: |
1117 | | - if op not in { |
| 1134 | +class MakeSliceAndCatDimOutermostPass(RemoveOrReplacePassInterface): |
| 1135 | + """ |
| 1136 | + Make the slice/cat dimension the outermost dimension by adding transpose |
| 1137 | + operations before and after the slice/cat operation. |
| 1138 | + """ |
| 1139 | + |
| 1140 | + @property |
| 1141 | + def targets(self) -> list[EdgeOpOverload]: |
| 1142 | + return [ |
1118 | 1143 | exir_ops.edge.aten.cat.default, |
1119 | 1144 | exir_ops.edge.aten.slice_copy.Tensor, |
1120 | | - }: |
1121 | | - return super().call_operator(op, args, kwargs, meta) |
1122 | | - dim = cast(int, args[1]) if len(args) > 1 else 0 |
1123 | | - output_shape = meta["val"].shape |
| 1145 | + ] |
| 1146 | + |
| 1147 | + def _transpose_dims( |
| 1148 | + self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int |
| 1149 | + ) -> torch.fx.Node: |
| 1150 | + """Helper function to transpose dims of a node.""" |
| 1151 | + shape = node.meta["val"].shape |
| 1152 | + dim0, dim1 = ( |
| 1153 | + canonicalize_transposed_dim(dim0, shape), |
| 1154 | + canonicalize_transposed_dim(dim1, shape), |
| 1155 | + ) |
| 1156 | + dim0, dim1 = min(dim0, dim1), max(dim0, dim1) |
| 1157 | + transpose_node = graph.call_function( |
| 1158 | + exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {} |
| 1159 | + ) |
| 1160 | + transpose_node.meta = node.meta |
| 1161 | + return transpose_node |
| 1162 | + |
| 1163 | + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: |
| 1164 | + # Get the dimension argument |
| 1165 | + dim = cast(int, node.args[1]) if len(node.args) > 1 else 0 |
| 1166 | + output_shape = node.meta["val"].shape |
| 1167 | + |
| 1168 | + # Canonicalize dim to be positive |
1124 | 1169 | if dim < 0: |
1125 | | - # Keep dim positive. |
1126 | 1170 | dim += len(output_shape) |
1127 | 1171 |
|
| 1172 | + # Not needed if dim is already outermost or all dims before it are 1 |
1128 | 1173 | if dim == 0 or math.prod(output_shape[:dim]) == 1: |
1129 | | - # Not needed if dim is already outermost or all dims before it are 1. |
1130 | | - return super().call_operator(op, (args[0], dim) + args[2:], kwargs, meta) |
1131 | | - |
1132 | | - if op == exir_ops.edge.aten.slice_copy.Tensor: |
1133 | | - # Transpose -> slice. |
1134 | | - slice_args = ( |
1135 | | - self.transpose_dims(cast(ProxyValue, args[0]), meta, dim, 0), |
1136 | | - 0, |
1137 | | - ) + args[2:] |
1138 | | - new_op = super().call_operator(op, slice_args, kwargs, meta) |
1139 | | - else: |
1140 | | - # (Transpose input0, Transpose input1, ...) -> cat. |
1141 | | - cat_in_tensors = [ |
1142 | | - self.transpose_dims(t, meta, dim, 0) |
1143 | | - for t in cast(list[ProxyValue], args[0]) |
1144 | | - ] |
1145 | | - new_op = super().call_operator(op, (cat_in_tensors, 0), kwargs, meta) |
1146 | | - # slice/cat -> transpose. |
1147 | | - return self.transpose_dims(new_op, meta, 0, dim) |
| 1174 | + return False |
| 1175 | + |
| 1176 | + graph = node.graph |
| 1177 | + |
| 1178 | + with graph.inserting_before(node): |
| 1179 | + if node.target == exir_ops.edge.aten.slice_copy.Tensor: |
| 1180 | + # Transpose input -> slice with dim=0 -> transpose back |
| 1181 | + input_node = cast(torch.fx.Node, node.args[0]) |
| 1182 | + transposed_input = self._transpose_dims(graph, input_node, dim, 0) |
| 1183 | + |
| 1184 | + # Create slice operation with dim=0 |
| 1185 | + slice_args = (transposed_input, 0) + node.args[2:] |
| 1186 | + sliced = graph.call_function( |
| 1187 | + exir_ops.edge.aten.slice_copy.Tensor, slice_args, node.kwargs |
| 1188 | + ) |
| 1189 | + sliced.meta = node.meta |
| 1190 | + |
| 1191 | + # Transpose back |
| 1192 | + result = self._transpose_dims(graph, sliced, 0, dim) |
| 1193 | + else: |
| 1194 | + # Cat operation: transpose all inputs -> cat with dim=0 -> transpose back |
| 1195 | + cat_inputs = cast(list[torch.fx.Node], node.args[0]) |
| 1196 | + transposed_inputs = [ |
| 1197 | + self._transpose_dims(graph, t, dim, 0) |
| 1198 | + for t in cat_inputs |
| 1199 | + ] |
| 1200 | + |
| 1201 | + # Create cat operation with dim=0 |
| 1202 | + catted = graph.call_function( |
| 1203 | + exir_ops.edge.aten.cat.default, (transposed_inputs, 0), node.kwargs |
| 1204 | + ) |
| 1205 | + catted.meta = node.meta |
| 1206 | + |
| 1207 | + # Transpose back |
| 1208 | + result = self._transpose_dims(graph, catted, 0, dim) |
| 1209 | + |
| 1210 | + # Replace all uses with the final result |
| 1211 | + node.replace_all_uses_with(result) |
| 1212 | + return True |
1148 | 1213 |
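The rewrite above only moves the slice/cat dimension to the outermost position; the values are unchanged. A standalone sketch (not part of the diff) mirroring the produced graph with eager tensor ops:

    import torch

    # slice_copy(x, dim=2, start=0, end=5) becomes
    # transpose(0, 2) -> slice on dim 0 -> transpose(0, 2) back.
    x = torch.randn(2, 3, 10, 4)
    direct = x[:, :, 0:5, :]                          # (2, 3, 5, 4)
    rewritten = x.transpose(0, 2)[0:5].transpose(0, 2)
    assert torch.equal(direct, rewritten)

    # cat gets the same treatment: transpose every input,
    # cat on dim 0, then transpose the result back.
    a, b = torch.randn(2, 3, 4), torch.randn(2, 5, 4)
    direct = torch.cat([a, b], dim=1)                 # (2, 8, 4)
    rewritten = torch.cat(
        [a.transpose(0, 1), b.transpose(0, 1)], dim=0
    ).transpose(0, 1)
    assert torch.equal(direct, rewritten)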
|
1149 | 1214 |
|
1150 | 1215 | @register_cadence_pass(CadencePassAttribute(opt_level=2)) |
|