
Commit 49805dd

Remove outdated NCHW to NHWC pass and rename the current one to ReplaceConvWithChannelLastConvPass
Differential Revision: D80185231
Pull Request resolved: #13420
1 parent 8ef9595 commit 49805dd
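
For context, the renamed ReplaceConvWithChannelLastConvPass keeps the ExportPass-style interface exercised by the updated tests below. A minimal usage sketch (the graph module gm is a hypothetical placeholder, not part of this commit):

    from executorch.backends.cadence.aot.replace_ops import ReplaceConvWithChannelLastConvPass

    # gm: a torch.fx.GraphModule containing cadence convolution nodes (hypothetical input)
    p = ReplaceConvWithChannelLastConvPass()
    gm_after = p.call(gm).graph_module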

File tree: 3 files changed, +11 -201 lines changed


backends/cadence/aot/compiler_utils.py

Lines changed: 0 additions & 7 deletions
@@ -201,13 +201,6 @@ def contains_node_with_matching_target(
     return any(node.target == op_target for node in nodes)


-def is_quantized_tensor(x: torch.Tensor) -> bool:
-    """
-    Return true if the tensor x is quantized
-    """
-    return x.is_quantized
-
-
 def get_scale(x: torch.Tensor) -> torch.Tensor:
     """
     Return the scale of a quantized tensor as a float32 tensor.
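
The removed is_quantized_tensor helper was a thin wrapper over the tensor attribute, so call sites can check x.is_quantized directly. A minimal sketch of the equivalence, assuming a standard PyTorch quantized tensor:

    import torch

    w = torch.quantize_per_tensor(torch.randn(4, 3), scale=0.1, zero_point=0, dtype=torch.qint8)
    assert w.is_quantized  # the same check is_quantized_tensor(w) performed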

backends/cadence/aot/replace_ops.py

Lines changed: 3 additions & 186 deletions
@@ -15,17 +15,15 @@
 import math
 import operator
 from operator import neg
-from typing import cast, Dict, Iterable, Optional, Sequence, Set, Tuple
+from typing import cast, Dict, Iterable, Optional, Sequence, Tuple

 import torch
 import torch.fx
 from executorch.backends.cadence.aot.compiler_utils import (
     get_shape,
     get_tensor_from_attr,
-    get_transposed_dims,
     get_zero_point,
     is_node_with_op,
-    is_quantized_tensor,
     quantize_tensor_multiplier,
 )
 from executorch.backends.cadence.aot.fuse_ops import (
@@ -772,186 +770,6 @@ def call_operator(self, op, args, kwargs, meta):
         return super().call_operator(target, new_args, kwargs, meta)


-# TODO(matthiascremon): this is a fuse op, not a replace op
-class ReplaceConvWithChannelLastConv:
-    """
-    Convolution op in pytorch expects NCHW layout for input, weight, and output
-    tensors. However, if the input and output to the convolution op are originally
-    in NWHC layout, and are then permuted to conform to NCHW layout, we can fuse
-    the two permute ops with the convolution op, and call the NHWC layout
-    convolution op.
-    """
-
-    def __init__(self):
-        self.counter = 0
-        self.graph_module = None
-
-    def __call__(self, graph_module: torch.fx.GraphModule):
-        self.replace_conv_with_nhwc_conv(graph_module)
-
-    def conv_layout_is_nhwc(self, node: torch.fx.Node) -> bool:
-        """
-        Return true if the convolution input and output are connected to permute
-        ops, and the input/output to/from the permute ops is NHWC layout tensor.
-        """
-        # There must only be a single user of the output node (which must be a
-        # permute/tranpsose op). The input of the convolution must be connected
-        # to a permute op, and that permute op should have a single user.
-        conv_inp = node.args[0]
-        assert isinstance(conv_inp, torch.fx.Node)
-        if len(node.users) != 1 or len(conv_inp.users) != 1:
-            return False
-
-        # Get the input and output (permute/transpose) nodes of the convolution
-        conv_user = list(node.users.keys())[0]
-        assert isinstance(conv_user, torch.fx.Node)
-        pt_nodes: Set[torch.fx.Node] = {conv_inp, conv_user}
-
-        # Any node in pt_nodes must not be a placeholder.
-        if contains_placeholder_or_param(pt_nodes):
-            return False
-
-        # Determine if the convolution is 1d or 2d. The output tensor must be
-        # 3- or 4-dimensional
-        out_shape = get_shape(self.graph_module, node)
-        assert out_shape is not None
-        out_dims = len(out_shape)
-        assert out_dims in {3, 4}, "Only supports conv1d and conv2d"
-        conv1d = out_dims == 3
-
-        # Get the possible targets for the nodes in pt_nodes. Since conv1d has
-        # 3-dimensional input and output tensors, the nodes in pt_nodes could
-        # be either permute or transpose op. For conv2d, the nodes in pt_nodes
-        # must be permute ops.
-        p_target = exir_ops.edge.aten.permute_copy.default
-        t_target = exir_ops.edge.aten.transpose_copy.int
-        pt_targets = [p_target] + ([t_target] if conv1d else [])
-
-        # If any node in pt_nodes is not permute op (or tranpose op for conv1d),
-        # bail.
-        if any(x.target not in pt_targets for x in pt_nodes):
-            return False
-
-        # Now we need to determine the dimension permutations:
-        # If the input had NHWC layout, which was then permuted/transposed
-        # by a permute/transpose op to NCHW layout, the permutation must be
-        # [0, 3, 2, 1] (or [0, 2, 1] for conv1d).
-        # If the output had NCHW layout, and was then permuted to NHWC layout,
-        # the permutation must be [0, 2, 3, 1] (or [0, 2, 1] for conv1d).
-        nhwc_permute_order = {
-            node.args[0]: [0, 2, 1] if conv1d else [0, 3, 1, 2],
-            list(node.users.keys())[0]: [0, 2, 1] if conv1d else [0, 2, 3, 1],
-        }
-        for x in pt_nodes:
-            order = (
-                x.args[1]
-                if x.target == p_target
-                else get_transposed_dims(x, list(range(out_dims)))
-            )
-            if order != nhwc_permute_order[x]:
-                return False
-
-        return True
-
-    def replace_conv_with_nhwc_conv(self, graph_module: torch.fx.GraphModule):
-        self.graph_module = graph_module
-        graph = graph_module.graph
-        for node in graph.nodes:
-            # We are only interested in convolution nodes that have NHWC layout
-            if node.target not in {
-                exir_ops.edge.cadence.quantized_conv_nchw.default,
-                exir_ops.edge.cadence.convolution.default,
-                exir_ops.edge.cadence.quantized_transposed_conv.default,
-                exir_ops.edge.cadence.transposed_convolution.default,
-            } or not self.conv_layout_is_nhwc(node):
-                continue
-
-            # Get the args of convolution op
-            args = list(node.args)
-            # The input is connected to a permute/transpose op that converts the
-            # NHWC layout to NCHW layout. The input of the permute op will become
-            # this convolution op's input.
-            in_tp = args[0]
-            args[0] = in_tp.args[0]
-            # The weight is in NHWC layout. Permute it to NHWC layout.
-            weight_tensor = get_tensor_from_attr(graph_module, args[1])
-            assert isinstance(weight_tensor, torch.Tensor)
-            # We cannot directly permute a per-channel quantized tensor. We will
-            # dequantize it, permute the fp32 tensor, and then requantize the
-            # permuted tensor.
-            if (
-                is_quantized_tensor(weight_tensor)
-                and weight_tensor.qscheme() == torch.per_channel_affine
-            ):
-                # We have already asserted during quantizing conv op that the
-                # quantization axis is 0.
-                dequant_weight = weight_tensor.dequantize()
-                dequant_weight = (
-                    dequant_weight.permute([0, 2, 1])
-                    if dequant_weight.dim() == 3
-                    else dequant_weight.permute([0, 2, 3, 1])
-                )
-                weight_tensor = torch.quantize_per_channel(
-                    dequant_weight.contiguous(),
-                    weight_tensor.q_per_channel_scales(),
-                    weight_tensor.q_per_channel_zero_points(),
-                    0,
-                    weight_tensor.dtype,
-                )
-            else:
-                weight_tensor = (
-                    weight_tensor.permute([0, 2, 1])
-                    if weight_tensor.dim() == 3
-                    else weight_tensor.permute([0, 2, 3, 1])
-                )
-            # Make the weight tensor contiguous, since we have permuted it.
-            weight_tensor = weight_tensor.contiguous()
-            # Add the permuted weight into the graph, and update the weight in
-            # args.
-            with graph.inserting_before(node):
-                weight_name = f"_weight_nhwc_{self.counter}"
-                graph_module.register_buffer(weight_name, weight_tensor)
-                weight = graph.get_attr(weight_name)
-            args[1] = weight
-
-            # The 'channel_last' arg is True. It is the last arg.
-            args[-1] = True
-            # Now update the convolution node args to mark it as NHWC convolution
-            node.args = tuple(args)
-
-            # Replace all the uses of the permute op connected to the output op
-            # with this convolution.
-            out_tp = list(node.users.keys())[0]
-            out_tp.replace_all_uses_with(node)
-            node.meta = out_tp.meta
-
-            # Erase the permute ops connected to the input and output of the
-            # convolution op.
-            graph.erase_node(in_tp)
-            graph.erase_node(out_tp)
-            self.counter += 1
-
-        graph_module.recompile()
-
-
-# This pass needs to be reworked to be compatible with PT2. It is an optimization
-# pass anyway, so move it to opt level 2.
-# TODO: T213724613 update and improve this pass.
-# @register_cadence_pass(CadencePassAttribute(opt_level=2))
-class ReplaceConvWithChannelLastConvPass(ExportPass):
-    """
-    Replace the ATen convolution op with custom conv op with NCHW or NHWC layout
-    input tensors, depending on the presence of permute/transpose ops connected
-    to the input tensor.
-    """
-
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        result = ReplaceAtenConvolutionWithCadenceConvolutionPass()(graph_module)
-        assert result is not None
-        ReplaceConvWithChannelLastConv()(result.graph_module)
-        return result
-
-
 @register_cadence_pass(CadencePassAttribute(opt_level=2))
 class ReplaceTrivialConvWithLinear(ExportPass):
     """
@@ -1131,7 +949,7 @@ def transpose_dims(


 @register_cadence_pass(CadencePassAttribute(opt_level=3))
-class ForceChannelLastForConvPass(ExportPassWithTransposeHelper):
+class ReplaceConvWithChannelLastConvPass(ExportPassWithTransposeHelper):
     def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
         shape = proxy.to_tensor().shape
         if len(shape) == 3:
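
The renamed pass converts convolution operands between NCHW and NHWC (channel-last) layouts; the permutation orders are the same ones noted in the removed code above. A small illustrative sketch of those permutations on plain tensors (an illustration only, not the pass implementation):

    import torch

    x_nchw = torch.randn(1, 8, 16, 16)     # N, C, H, W
    x_nhwc = x_nchw.permute(0, 2, 3, 1)    # NCHW -> NHWC, order [0, 2, 3, 1]
    x_back = x_nhwc.permute(0, 3, 1, 2)    # NHWC -> NCHW, order [0, 3, 1, 2]

    x_ncl = torch.randn(1, 8, 16)          # conv1d case: N, C, L
    x_nlc = x_ncl.permute(0, 2, 1)         # NCL -> NLC, order [0, 2, 1]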
@@ -2441,9 +2259,8 @@ class CadenceReplaceOpsInGraph:
         ReplaceRepeatWithCatPass,
         ReplacePadWithCatPass,
         ReplaceConstantPadNdWithSlicePass,
-        ReplaceConvWithChannelLastConvPass,
         ReplaceAtenConvolutionWithCadenceConvolutionPass,
-        ForceChannelLastForConvPass,
+        ReplaceConvWithChannelLastConvPass,
         ReplaceTrivialConvWithLinear,
         ReplaceConvWithIm2RowAndLinear,
         ReplaceTransposedConvWithLinearPass,
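
Note the ordering: ReplaceAtenConvolutionWithCadenceConvolutionPass now runs before the renamed ReplaceConvWithChannelLastConvPass. A rough sketch of applying the passes in that order, assuming each class follows the call(...).graph_module protocol shown in the tests (gm is a hypothetical torch.fx.GraphModule):

    passes = [
        ReplaceAtenConvolutionWithCadenceConvolutionPass,
        ReplaceConvWithChannelLastConvPass,  # formerly ForceChannelLastForConvPass
        ReplaceTrivialConvWithLinear,
    ]
    for pass_cls in passes:
        gm = pass_cls().call(gm).graph_module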

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 8 additions & 8 deletions
@@ -17,14 +17,14 @@
 )
 from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
 from executorch.backends.cadence.aot.replace_ops import (
-    ForceChannelLastForConvPass,
     MakeSliceAndCatDimOutermostPass,
     ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
     ReplaceAddMMWithLinearPass,
     ReplaceAtenApproxGeluWithApproxGeluPass,
     ReplaceAtenConvolutionWithCadenceConvolutionPass,
     ReplaceConstantPadNdWithSlicePass,
     ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
+    ReplaceConvWithChannelLastConvPass,
     ReplaceConvWithIm2RowAndLinear,
     ReplaceEmptyTensorsWithFullPass,
     ReplaceFunctionallyEquivalentOpTargets,
@@ -1454,7 +1454,7 @@ def test_replace_linear_like_conv(self) -> None:
         )


-class TestForceChannelLastForConvPass(unittest.TestCase):
+class TestReplaceConvWithChannelLastConvPass(unittest.TestCase):
     def create_conv1d_graphmodule(
         self, channels_last: Optional[bool] = None
     ) -> torch.fx.GraphModule:
@@ -1489,7 +1489,7 @@ def test_conv1d_default_channel_last(self) -> None:
         self.assertEqual(count_node(gm, exir_ops.edge.aten.transpose_copy.int), 0)

         # Apply replacement pass.
-        p = ForceChannelLastForConvPass()
+        p = ReplaceConvWithChannelLastConvPass()
         gm_after_replacement = p.call(gm).graph_module
         # Check that no replacement was made.
         self.assertEqual(
@@ -1514,7 +1514,7 @@ def test_conv1d_no_transpose_if_already_channel_last(self) -> None:
         self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1)

         # Apply replacement pass.
-        p = ForceChannelLastForConvPass()
+        p = ReplaceConvWithChannelLastConvPass()
         gm_after_replacement = p.call(gm).graph_module
         # Check that no replacement was made.
         self.assertEqual(
@@ -1566,7 +1566,7 @@ def test_convolution_default_channel_last(self) -> None:
         self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)

         # Apply replacement pass.
-        p = ForceChannelLastForConvPass()
+        p = ReplaceConvWithChannelLastConvPass()
         gm_after_replacement = p.call(gm).graph_module
         # Check that no replacement was made.
         self.assertEqual(
@@ -1591,7 +1591,7 @@ def test_no_transpose_if_already_channel_last(self) -> None:
         self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1)

         # Apply replacement pass.
-        p = ForceChannelLastForConvPass()
+        p = ReplaceConvWithChannelLastConvPass()
         gm_after_replacement = p.call(gm).graph_module
         # Check that no replacement was made.
         self.assertEqual(
@@ -1692,7 +1692,7 @@ def test_quantized_convolution_default_channel_last(self) -> None:
         self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)

         # Apply replacement pass.
-        p = ForceChannelLastForConvPass()
+        p = ReplaceConvWithChannelLastConvPass()
         gm_after_replacement = p.call(gm).graph_module
         # Check that no replacement was made.
         self.assertEqual(
@@ -1717,7 +1717,7 @@ def test_no_transpose_if_already_quantized_conv_channel_last(self) -> None:
         )

         # Apply replacement pass.
-        p = ForceChannelLastForConvPass()
+        p = ReplaceConvWithChannelLastConvPass()
         gm_after_replacement = p.call(gm).graph_module
         # Check that no replacement was made.
         self.assertEqual(
