2020 FuseFullThenReshapePass ,
2121 FuseMulIntoDequantPass ,
2222 FuseQuantDequantToRequantizePass ,
23- FuseTransposeOpPairsPass ,
23+ FuseTransposeOrPermuteOpPairsPass ,
2424)
2525from executorch .backends .cadence .aot .graph_builder import GraphBuilder
2626from executorch .backends .cadence .aot .pass_utils import count_node , op_counts_match
@@ -510,7 +510,7 @@ def test_fuse_then_transpose_pass(self):
510510 )
511511
512512
513- class TestFuseTransposeOpPairsPass (TestFusionPassesBase ):
513+ class TestFuseTransposeOrPermuteOpPairsPass (TestFusionPassesBase ):
514514 def _create_operator (
515515 self , builder : GraphBuilder , op : torch ._ops .OpOverload , x : ProxyValue
516516 ) -> ProxyValue :
@@ -536,17 +536,17 @@ def _create_operator(
536536 def test_fuse_transpose_pairs (self , op : torch ._ops .OpOverload ):
537537 # Create a graph with transpose -> quant -> transpose.
538538 builder = GraphBuilder ()
539- x = builder .placeholder ("x" , torch .randn (2 , 3 ))
540- transpose_node = builder .call_operator (
539+ x = builder .placeholder ("x" , torch .randn (2 , 3 , 4 ))
540+ transpose_node0 = builder .call_operator (
541541 op = exir_ops .edge .aten .transpose_copy .int ,
542542 args = (x , 0 , 1 ),
543543 )
544- quant_node = self ._create_operator (builder , op , transpose_node )
545- transpose_node = builder .call_operator (
544+ quant_node = self ._create_operator (builder , op , transpose_node0 )
545+ transpose_node1 = builder .call_operator (
546546 op = exir_ops .edge .aten .transpose_copy .int ,
547- args = (quant_node , 0 , 1 ),
547+ args = (quant_node , 1 , 2 ),
548548 )
549- builder .output ([transpose_node ])
549+ builder .output ([transpose_node1 ])
550550 gm = builder .get_graph_module ()
551551 self .check_op_counts (
552552 gm ,
@@ -557,7 +557,7 @@ def test_fuse_transpose_pairs(self, op: torch._ops.OpOverload):
557557 )
558558
559559 # Check that the pass fuses the two transpose ops.
560- fusion_pass_result = FuseTransposeOpPairsPass ()(gm )
560+ fusion_pass_result = FuseTransposeOrPermuteOpPairsPass ()(gm )
561561 self .assertIsNotNone (fusion_pass_result )
562562 gm_after_pass = fusion_pass_result .graph_module
563563 self .check_op_counts (
@@ -568,6 +568,47 @@ def test_fuse_transpose_pairs(self, op: torch._ops.OpOverload):
568568 },
569569 )
570570
571+ @parameterized .expand (
572+ [
573+ exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
574+ exir_ops .edge .cadence .quantized_relu .per_tensor ,
575+ ],
576+ )
577+ def test_fuse_permute_pairs (self , op : torch ._ops .OpOverload ):
578+ # Create a graph with permute -> quant -> permute.
579+ builder = GraphBuilder ()
580+ x = builder .placeholder ("x" , torch .randn (8 , 2 , 3 , 4 ))
581+ permute_node0 = builder .call_operator (
582+ op = exir_ops .edge .aten .permute_copy .default ,
583+ args = (x , [0 , 3 , 1 , 2 ]),
584+ )
585+ quant_node = self ._create_operator (builder , op , permute_node0 )
586+ permute_node1 = builder .call_operator (
587+ op = exir_ops .edge .aten .permute_copy .default ,
588+ args = (quant_node , [0 , 2 , 3 , 1 ]),
589+ )
590+ builder .output ([permute_node1 ])
591+ gm = builder .get_graph_module ()
592+ self .check_op_counts (
593+ gm ,
594+ expected_op_counts = {
595+ exir_ops .edge .aten .permute_copy .default : 2 ,
596+ op : 1 ,
597+ },
598+ )
599+
600+ # Check that the pass fuses the two transpose ops.
601+ fusion_pass_result = FuseTransposeOrPermuteOpPairsPass ()(gm )
602+ self .assertIsNotNone (fusion_pass_result )
603+ gm_after_pass = fusion_pass_result .graph_module
604+ self .check_op_counts (
605+ gm_after_pass ,
606+ expected_op_counts = {
607+ exir_ops .edge .aten .permute_copy .default : 0 ,
608+ op : 1 ,
609+ },
610+ )
611+
571612 def test_no_fusion_for_transpose_pairs (self ):
572613 # Create a graph with transpose -> quant -> transpose.
573614 builder = GraphBuilder ()
@@ -595,7 +636,7 @@ def test_no_fusion_for_transpose_pairs(self):
595636 )
596637
597638 # No fusion.
598- gm_after_pass = FuseTransposeOpPairsPass ()(gm ).graph_module
639+ gm_after_pass = FuseTransposeOrPermuteOpPairsPass ()(gm ).graph_module
599640 self .check_op_counts (
600641 gm_after_pass ,
601642 expected_op_counts = {
@@ -604,6 +645,42 @@ def test_no_fusion_for_transpose_pairs(self):
604645 },
605646 )
606647
648+ def test_no_fusion_for_permute_pairs (self ):
649+ # Create a graph with permute -> quant -> permute.
650+ builder = GraphBuilder ()
651+ x = builder .placeholder ("x" , torch .randn (2 , 3 , 4 ))
652+ permute_node = builder .call_operator (
653+ op = exir_ops .edge .aten .permute_copy .default ,
654+ args = (x , [2 , 0 , 1 ]),
655+ )
656+ quant_node = builder .call_operator (
657+ op = exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
658+ args = (permute_node , 1.2 , 3 , 0 , 127 , torch .int8 ),
659+ )
660+ permute_node = builder .call_operator (
661+ op = exir_ops .edge .aten .permute_copy .default ,
662+ args = (quant_node , [2 , 0 , 1 ]),
663+ )
664+ builder .output (permute_node )
665+ gm = builder .get_graph_module ()
666+ self .check_op_counts (
667+ gm ,
668+ expected_op_counts = {
669+ exir_ops .edge .aten .permute_copy .default : 2 ,
670+ exir_ops .edge .quantized_decomposed .quantize_per_tensor .default : 1 ,
671+ },
672+ )
673+
674+ # No fusion.
675+ gm_after_pass = FuseTransposeOrPermuteOpPairsPass ()(gm ).graph_module
676+ self .check_op_counts (
677+ gm_after_pass ,
678+ expected_op_counts = {
679+ exir_ops .edge .aten .permute_copy .default : 2 ,
680+ exir_ops .edge .quantized_decomposed .quantize_per_tensor .default : 1 ,
681+ },
682+ )
683+
607684 def test_fusion_for_forked_transposes (self ):
608685 # Create a graph with transpose -> quant -> transpose.
609686 builder = GraphBuilder ()
@@ -636,7 +713,7 @@ def test_fusion_for_forked_transposes(self):
636713 )
637714
638715 # Fuse the all the transpose ops.
639- gm_after_pass = FuseTransposeOpPairsPass ()(gm ).graph_module
716+ gm_after_pass = FuseTransposeOrPermuteOpPairsPass ()(gm ).graph_module
640717 self .check_op_counts (
641718 gm_after_pass ,
642719 expected_op_counts = {
0 commit comments