
Commit 6bbd215

Andrew Grebenisan authored and facebook-github-bot committed

Update ReplaceConvWithChannelLastConvPass and MakeSliceAndCatDimOutermostPass to correctly set modified bit

Summary: As titled

Differential Revision: D87880891
1 parent ff59e41 commit 6bbd215


2 files changed: +239 -110 lines changed


backends/cadence/aot/replace_ops.py

Lines changed: 160 additions & 95 deletions
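Context for the diff below: both passes previously subclassed ExportPass and rewrote ops inside call_operator; they are now written against RemoveOrReplacePassInterface, which ties the pass's modified bit to the boolean returned by maybe_remove_or_replace. That base class is not part of this diff, so the following is only a rough sketch, under assumed names, of the contract the rewritten passes appear to code against; the real interface lives elsewhere in the Cadence AOT passes and may differ:

    from abc import abstractmethod

    import torch
    from torch.fx.passes.infra.pass_base import PassBase, PassResult


    class RemoveOrReplacePassSketch(PassBase):
        # Hypothetical sketch, not the actual executorch interface.

        @property
        @abstractmethod
        def targets(self) -> list:
            """Op overloads this pass wants to visit."""

        @abstractmethod
        def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
            """Rewrite `node` in place; return True iff the graph was mutated."""

        def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
            modified = False
            for node in list(graph_module.graph.nodes):
                if node.op == "call_function" and node.target in self.targets:
                    # OR-ing the per-node results is the "modified bit".
                    modified |= self.maybe_remove_or_replace(node)
            if modified:
                graph_module.graph.eliminate_dead_code()
                graph_module.recompile()
            return PassResult(graph_module, modified)

Under a contract like this, returning False on the early-exit paths and True only after a real rewrite is what "correctly set modified bit" means: a pass manager can then trust PassResult.modified to decide whether the graph needs re-validation or another round of passes.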
@@ -40,8 +40,7 @@
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
-from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
-from torch.fx.node import Argument
+from executorch.exir.pass_base import ExportPass, PassResult
 
 # A map to represent ops that:
 # (a) are functionally equivalent; and
@@ -1020,131 +1019,197 @@ def canonicalize_transposed_dim(dim: int, shape: Sequence[int]) -> int:
     return dim
 
 
-class ExportPassWithTransposeHelper(ExportPass):
-    def transpose_dims(
-        self: ExportPass, proxy: ProxyValue, meta: NodeMetadata, dim0: int, dim1: int
-    ) -> ProxyValue:
-        """Helper function to transpose dims of a `proxy` with given `meta`."""
-        shape = proxy.data.shape
+@register_cadence_pass(CadencePassAttribute(opt_level=3))
+class ReplaceConvWithChannelLastConvPass(RemoveOrReplacePassInterface):
+    """
+    Replace NCHW convolutions with NHWC (channel-last) convolutions by adding
+    transpose operations before and after the convolution.
+    """
+
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
+            exir_ops.edge.cadence.conv1d.default,
+            exir_ops.edge.cadence.conv2d.default,
+            exir_ops.edge.cadence.conv3d.default,
+            exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor,
+        ]
+
+    def _transpose_dims(
+        self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
+    ) -> torch.fx.Node:
+        """Helper function to transpose dims of a node."""
+        shape = node.meta["val"].shape
         dim0, dim1 = (
             canonicalize_transposed_dim(dim0, shape),
             canonicalize_transposed_dim(dim1, shape),
         )
         dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
-        return super().call_operator(
-            exir_ops.edge.aten.transpose_copy.int, (proxy, dim0, dim1), {}, meta
+        transpose_node = graph.call_function(
+            exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {}
         )
-
-
-@register_cadence_pass(CadencePassAttribute(opt_level=3))
-class ReplaceConvWithChannelLastConvPass(ExportPassWithTransposeHelper):
-    def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
-        shape = proxy.to_tensor().shape
+        transpose_node.meta = node.meta
+        return transpose_node
+
+    def _change_nchw_to_nhwc(
+        self, graph: torch.fx.Graph, node: torch.fx.Node
+    ) -> torch.fx.Node:
+        """Convert NCHW format to NHWC format."""
+        shape = node.meta["val"].shape
         if len(shape) == 3:
-            return self.transpose_dims(proxy, meta, 1, -1)
+            return self._transpose_dims(graph, node, 1, -1)
         indices = list(range(len(shape)))
         permute_indices = [indices[0]] + indices[2:] + [indices[1]]
-        return super().call_operator(
-            exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta
+        permute_node = graph.call_function(
+            exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {}
         )
-
-    def change_nhwc_to_nchw(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
-        shape = proxy.to_tensor().shape
+        permute_node.meta = node.meta
+        return permute_node
+
+    def _change_nhwc_to_nchw(
+        self, graph: torch.fx.Graph, node: torch.fx.Node
+    ) -> torch.fx.Node:
+        """Convert NHWC format to NCHW format."""
+        shape = node.meta["val"].shape
         if len(shape) == 3:
-            return self.transpose_dims(proxy, meta, 1, -1)
+            return self._transpose_dims(graph, node, 1, -1)
         indices = list(range(len(shape)))
         permute_indices = [indices[0], indices[-1]] + indices[1:-1]
-        return super().call_operator(
-            exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta
+        permute_node = graph.call_function(
+            exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {}
         )
+        permute_node.meta = node.meta
+        return permute_node
 
-    def call_operator(
-        self,
-        op,
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op not in {
-            exir_ops.edge.cadence.conv1d.default,
-            exir_ops.edge.cadence.conv2d.default,
-            exir_ops.edge.cadence.conv3d.default,
-            exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
-
-        quantized_op = op == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        assert isinstance(node.target, EdgeOpOverload)
+        quantized_op = node.target == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor
 
-        if not quantized_op and len(args) == 8 and args[-1] is True:
-            # Already in NHWC layout.
-            return super().call_operator(op, args, kwargs, meta)
+        # Check if already in NHWC layout
+        if not quantized_op and len(node.args) == 8 and node.args[-1] is True:
+            return False
 
+        # Determine the new op target
         if quantized_op:
             new_op = exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor
         else:
-            # Determine if 1D or 2D convolution based on op
-            new_op = op
+            new_op = node.target
 
-        input_proxy = cast(ProxyValue, args[0])
-        weight_proxy = cast(ProxyValue, args[1])
-        input_proxy = self.change_nchw_to_nhwc(input_proxy, meta)
-        weight_proxy = self.change_nchw_to_nhwc(weight_proxy, meta)
+        graph = node.graph
 
-        # Non-quantized ops still need to set the last optional argument to True.
-        channel_last_arg = [] if quantized_op else [True]
+        # Get input and weight nodes
+        input_node = cast(torch.fx.Node, node.args[0])
+        weight_node = cast(torch.fx.Node, node.args[1])
 
-        new_args = (
-            # Transposed input/weights.
-            (input_proxy, weight_proxy)
-            # All other args (bias, quant params, etc)
-            + tuple(args[2:])
-            + tuple(channel_last_arg)
-        )
-        output_proxy = super().call_operator(new_op, new_args, kwargs, meta)
-        nchw_proxy = self.change_nhwc_to_nchw(output_proxy, meta)
-        return nchw_proxy
+        # Insert transpose operations before the node
+        with graph.inserting_before(node):
+            # Convert input from NCHW to NHWC
+            input_nhwc = self._change_nchw_to_nhwc(graph, input_node)
+            # Convert weight from NCHW to NHWC
+            weight_nhwc = self._change_nchw_to_nhwc(graph, weight_node)
+
+            # Non-quantized ops need to set the last optional argument to True
+            channel_last_arg = [] if quantized_op else [True]
+
+            # Create new args with transposed input/weights
+            new_args = (
+                (input_nhwc, weight_nhwc)
+                + tuple(node.args[2:])
+                + tuple(channel_last_arg)
+            )
+
+            # Create the new conv operation
+            new_conv = graph.call_function(new_op, new_args, node.kwargs)
+            new_conv.meta = node.meta
+
+            # Convert output back from NHWC to NCHW
+            nchw_output = self._change_nhwc_to_nchw(graph, new_conv)
+
+        # Replace all uses with the final output
+        node.replace_all_uses_with(nchw_output)
+        return True
 
 
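Aside (not part of the diff): the permute-index arithmetic in _change_nchw_to_nhwc / _change_nhwc_to_nchw can be checked with plain tensors. The snippet below is illustrative only; for 3-D (conv1d) inputs the pass instead uses a single transpose of dims 1 and -1, which amounts to the same permutation.

    import torch

    x = torch.randn(2, 3, 4, 5)  # NCHW: [N, C, H, W]

    indices = list(range(x.dim()))
    # NCHW -> NHWC, as in _change_nchw_to_nhwc: move C to the end.
    to_nhwc = [indices[0]] + indices[2:] + [indices[1]]   # [0, 2, 3, 1]
    # NHWC -> NCHW, as in _change_nhwc_to_nchw: move C back after N.
    to_nchw = [indices[0], indices[-1]] + indices[1:-1]   # [0, 3, 1, 2]

    nhwc = x.permute(to_nhwc)
    assert nhwc.shape == (2, 4, 5, 3)
    # The two permutations are inverses, so the round trip is lossless.
    assert torch.equal(nhwc.permute(to_nchw), x)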
 @register_cadence_pass(CadencePassAttribute(opt_level=3))
-class MakeSliceAndCatDimOutermostPass(ExportPassWithTransposeHelper):
-    def call_operator(
-        self,
-        op,
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op not in {
+class MakeSliceAndCatDimOutermostPass(RemoveOrReplacePassInterface):
+    """
+    Make the slice/cat dimension the outermost dimension by adding transpose
+    operations before and after the slice/cat operation.
+    """
+
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
             exir_ops.edge.aten.cat.default,
             exir_ops.edge.aten.slice_copy.Tensor,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
-        dim = cast(int, args[1]) if len(args) > 1 else 0
-        output_shape = meta["val"].shape
+        ]
+
+    def _transpose_dims(
+        self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
+    ) -> torch.fx.Node:
+        """Helper function to transpose dims of a node."""
+        shape = node.meta["val"].shape
+        dim0, dim1 = (
+            canonicalize_transposed_dim(dim0, shape),
+            canonicalize_transposed_dim(dim1, shape),
+        )
+        dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
+        transpose_node = graph.call_function(
+            exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {}
+        )
+        transpose_node.meta = node.meta
+        return transpose_node
+
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # Get the dimension argument
+        dim = cast(int, node.args[1]) if len(node.args) > 1 else 0
+        output_shape = node.meta["val"].shape
+
+        # Canonicalize dim to be positive
         if dim < 0:
-            # Keep dim positive.
             dim += len(output_shape)
 
+        # Not needed if dim is already outermost or all dims before it are 1
         if dim == 0 or math.prod(output_shape[:dim]) == 1:
-            # Not needed if dim is already outermost or all dims before it are 1.
-            return super().call_operator(op, (args[0], dim) + args[2:], kwargs, meta)
-
-        if op == exir_ops.edge.aten.slice_copy.Tensor:
-            # Transpose -> slice.
-            slice_args = (
-                self.transpose_dims(cast(ProxyValue, args[0]), meta, dim, 0),
-                0,
-            ) + args[2:]
-            new_op = super().call_operator(op, slice_args, kwargs, meta)
-        else:
-            # (Transpose input0, Transpose input1, ...) -> cat.
-            cat_in_tensors = [
-                self.transpose_dims(t, meta, dim, 0)
-                for t in cast(list[ProxyValue], args[0])
-            ]
-            new_op = super().call_operator(op, (cat_in_tensors, 0), kwargs, meta)
-        # slice/cat -> transpose.
-        return self.transpose_dims(new_op, meta, 0, dim)
+            return False
+
+        graph = node.graph
+
+        with graph.inserting_before(node):
+            if node.target == exir_ops.edge.aten.slice_copy.Tensor:
+                # Transpose input -> slice with dim=0 -> transpose back
+                input_node = cast(torch.fx.Node, node.args[0])
+                transposed_input = self._transpose_dims(graph, input_node, dim, 0)
+
+                # Create slice operation with dim=0
+                slice_args = (transposed_input, 0) + node.args[2:]
+                sliced = graph.call_function(
+                    exir_ops.edge.aten.slice_copy.Tensor, slice_args, node.kwargs
+                )
+                sliced.meta = node.meta
+
+                # Transpose back
+                result = self._transpose_dims(graph, sliced, 0, dim)
+            else:
+                # Cat operation: transpose all inputs -> cat with dim=0 -> transpose back
+                cat_inputs = cast(list[torch.fx.Node], node.args[0])
+                transposed_inputs = [
+                    self._transpose_dims(graph, t, dim, 0)
+                    for t in cat_inputs
+                ]
+
+                # Create cat operation with dim=0
+                catted = graph.call_function(
+                    exir_ops.edge.aten.cat.default, (transposed_inputs, 0), node.kwargs
+                )
+                catted.meta = node.meta
+
+                # Transpose back
+                result = self._transpose_dims(graph, catted, 0, dim)
+
+        # Replace all uses with the final result
+        node.replace_all_uses_with(result)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=2))
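Aside (not part of the diff): the rewrite MakeSliceAndCatDimOutermostPass performs relies on slice and cat commuting with transposition. A small numeric check, illustrative only, with torch.narrow standing in for slice_copy:

    import torch

    x = torch.randn(4, 6, 8)
    d = 2  # dimension the pass would move outermost

    # Slicing along dim d == transpose(d, 0) -> slice along dim 0 -> transpose back.
    direct = x.narrow(d, 1, 5)
    rewritten = x.transpose(d, 0).narrow(0, 1, 5).transpose(0, d)
    assert torch.equal(direct, rewritten)

    # Same for cat: transpose each input, cat along dim 0, transpose back.
    ys = [torch.randn(4, 6, 3), torch.randn(4, 6, 5)]
    direct_cat = torch.cat(ys, dim=d)
    rewritten_cat = torch.cat([y.transpose(d, 0) for y in ys], dim=0).transpose(0, d)
    assert torch.equal(direct_cat, rewritten_cat)

Note also the behavior change behind the commit title: under the old ExportPass retracing style, even the early-exit paths went through call_operator, so the pass could not signal that nothing changed. Both passes now return False from maybe_remove_or_replace on those paths (conv already NHWC, or slice/cat dim already effectively outermost) and True only after a real rewrite, so the modified bit is set only when the graph is actually touched.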
