1717)
1818from executorch .backends .cadence .aot .pass_utils import count_node , op_counts_match
1919from executorch .backends .cadence .aot .replace_ops import (
20- ForceChannelLastForConvPass ,
20+ ReplaceConvWithChannelLastConvPass ,
2121 MakeSliceAndCatDimOutermostPass ,
2222 ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass ,
2323 ReplaceAddMMWithLinearPass ,
@@ -1454,7 +1454,7 @@ def test_replace_linear_like_conv(self) -> None:
14541454 )
14551455
14561456
1457- class TestForceChannelLastForConvPass (unittest .TestCase ):
1457+ class TestReplaceConvWithChannelLastConvPass (unittest .TestCase ):
14581458 def create_conv1d_graphmodule (
14591459 self , channels_last : Optional [bool ] = None
14601460 ) -> torch .fx .GraphModule :
@@ -1489,7 +1489,7 @@ def test_conv1d_default_channel_last(self) -> None:
14891489 self .assertEqual (count_node (gm , exir_ops .edge .aten .transpose_copy .int ), 0 )
14901490
14911491 # Apply replacement pass.
1492- p = ForceChannelLastForConvPass ()
1492+ p = ReplaceConvWithChannelLastConvPass ()
14931493 gm_after_replacement = p .call (gm ).graph_module
14941494 # Check that no replacement was made.
14951495 self .assertEqual (
@@ -1514,7 +1514,7 @@ def test_conv1d_no_transpose_if_already_channel_last(self) -> None:
15141514 self .assertEqual (count_node (gm , exir_ops .edge .cadence .convolution .default ), 1 )
15151515
15161516 # Apply replacement pass.
1517- p = ForceChannelLastForConvPass ()
1517+ p = ReplaceConvWithChannelLastConvPass ()
15181518 gm_after_replacement = p .call (gm ).graph_module
15191519 # Check that no replacement was made.
15201520 self .assertEqual (
@@ -1566,7 +1566,7 @@ def test_convolution_default_channel_last(self) -> None:
15661566 self .assertEqual (count_node (gm , exir_ops .edge .aten .permute_copy .default ), 0 )
15671567
15681568 # Apply replacement pass.
1569- p = ForceChannelLastForConvPass ()
1569+ p = ReplaceConvWithChannelLastConvPass ()
15701570 gm_after_replacement = p .call (gm ).graph_module
15711571 # Check that no replacement was made.
15721572 self .assertEqual (
@@ -1591,7 +1591,7 @@ def test_no_transpose_if_already_channel_last(self) -> None:
15911591 self .assertEqual (count_node (gm , exir_ops .edge .cadence .convolution .default ), 1 )
15921592
15931593 # Apply replacement pass.
1594- p = ForceChannelLastForConvPass ()
1594+ p = ReplaceConvWithChannelLastConvPass ()
15951595 gm_after_replacement = p .call (gm ).graph_module
15961596 # Check that no replacement was made.
15971597 self .assertEqual (
@@ -1671,7 +1671,7 @@ def test_quantized_convolution_default_channel_last(self) -> None:
16711671 self .assertEqual (count_node (gm , exir_ops .edge .aten .permute_copy .default ), 0 )
16721672
16731673 # Apply replacement pass.
1674- p = ForceChannelLastForConvPass ()
1674+ p = ReplaceConvWithChannelLastConvPass ()
16751675 gm_after_replacement = p .call (gm ).graph_module
16761676 # Check that no replacement was made.
16771677 self .assertEqual (
@@ -1702,7 +1702,7 @@ def test_no_transpose_if_already_quantized_conv_channel_last(self) -> None:
17021702 )
17031703
17041704 # Apply replacement pass.
1705- p = ForceChannelLastForConvPass ()
1705+ p = ReplaceConvWithChannelLastConvPass ()
17061706 gm_after_replacement = p .call (gm ).graph_module
17071707 # Check that no replacement was made.
17081708 self .assertEqual (
0 commit comments