Skip to content

Commit 6ff2e29

Browse files
mcremon-metafacebook-github-bot
authored andcommitted
Remove outdated NCHW to NHWC pass and rename the current one to ReplaceConvWithChannelLastConvPass
Summary: The existing `ReplaceConvWithChannelLastConvPass` is a PT1 pass from years past, and currently does not do anything. The correct version of it is `ForceChannelLastForConvPass`, so we rename that one to `ReplaceConvWithChannelLastConvPass`. This should be a non-functional change. Differential Revision: D80185231
1 parent 50dc817 commit 6ff2e29

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,7 +1127,7 @@ def transpose_dims(
11271127

11281128

11291129
@register_cadence_pass(CadencePassAttribute(opt_level=3))
1130-
class ForceChannelLastForConvPass(ExportPassWithTransposeHelper):
1130+
class ReplaceConvWithChannelLastConvPass(ExportPassWithTransposeHelper):
11311131
def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
11321132
shape = proxy.to_tensor().shape
11331133
if len(shape) == 3:
@@ -2425,7 +2425,7 @@ class CadenceReplaceOpsInGraph:
24252425
ReplaceConstantPadNdWithSlicePass,
24262426
ReplaceConvWithChannelLastConvPass,
24272427
ReplaceAtenConvolutionWithCadenceConvolutionPass,
2428-
ForceChannelLastForConvPass,
2428+
ReplaceConvWithChannelLastConvPass,
24292429
ReplaceTrivialConvWithLinear,
24302430
ReplaceConvWithIm2RowAndLinear,
24312431
ReplaceTransposedConvWithLinearPass,

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818
from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
1919
from executorch.backends.cadence.aot.replace_ops import (
20-
ForceChannelLastForConvPass,
20+
ReplaceConvWithChannelLastConvPass,
2121
MakeSliceAndCatDimOutermostPass,
2222
ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
2323
ReplaceAddMMWithLinearPass,
@@ -1454,7 +1454,7 @@ def test_replace_linear_like_conv(self) -> None:
14541454
)
14551455

14561456

1457-
class TestForceChannelLastForConvPass(unittest.TestCase):
1457+
class TestReplaceConvWithChannelLastConvPass(unittest.TestCase):
14581458
def create_conv1d_graphmodule(
14591459
self, channels_last: Optional[bool] = None
14601460
) -> torch.fx.GraphModule:
@@ -1489,7 +1489,7 @@ def test_conv1d_default_channel_last(self) -> None:
14891489
self.assertEqual(count_node(gm, exir_ops.edge.aten.transpose_copy.int), 0)
14901490

14911491
# Apply replacement pass.
1492-
p = ForceChannelLastForConvPass()
1492+
p = ReplaceConvWithChannelLastConvPass()
14931493
gm_after_replacement = p.call(gm).graph_module
14941494
# Check that no replacement was made.
14951495
self.assertEqual(
@@ -1514,7 +1514,7 @@ def test_conv1d_no_transpose_if_already_channel_last(self) -> None:
15141514
self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1)
15151515

15161516
# Apply replacement pass.
1517-
p = ForceChannelLastForConvPass()
1517+
p = ReplaceConvWithChannelLastConvPass()
15181518
gm_after_replacement = p.call(gm).graph_module
15191519
# Check that no replacement was made.
15201520
self.assertEqual(
@@ -1566,7 +1566,7 @@ def test_convolution_default_channel_last(self) -> None:
15661566
self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)
15671567

15681568
# Apply replacement pass.
1569-
p = ForceChannelLastForConvPass()
1569+
p = ReplaceConvWithChannelLastConvPass()
15701570
gm_after_replacement = p.call(gm).graph_module
15711571
# Check that no replacement was made.
15721572
self.assertEqual(
@@ -1591,7 +1591,7 @@ def test_no_transpose_if_already_channel_last(self) -> None:
15911591
self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1)
15921592

15931593
# Apply replacement pass.
1594-
p = ForceChannelLastForConvPass()
1594+
p = ReplaceConvWithChannelLastConvPass()
15951595
gm_after_replacement = p.call(gm).graph_module
15961596
# Check that no replacement was made.
15971597
self.assertEqual(
@@ -1671,7 +1671,7 @@ def test_quantized_convolution_default_channel_last(self) -> None:
16711671
self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)
16721672

16731673
# Apply replacement pass.
1674-
p = ForceChannelLastForConvPass()
1674+
p = ReplaceConvWithChannelLastConvPass()
16751675
gm_after_replacement = p.call(gm).graph_module
16761676
# Check that no replacement was made.
16771677
self.assertEqual(
@@ -1702,7 +1702,7 @@ def test_no_transpose_if_already_quantized_conv_channel_last(self) -> None:
17021702
)
17031703

17041704
# Apply replacement pass.
1705-
p = ForceChannelLastForConvPass()
1705+
p = ReplaceConvWithChannelLastConvPass()
17061706
gm_after_replacement = p.call(gm).graph_module
17071707
# Check that no replacement was made.
17081708
self.assertEqual(

0 commit comments

Comments
 (0)