88
99
1010import unittest
11+ from typing import List
1112
1213import executorch .backends .cadence .aot .ops_registrations # noqa
1314import torch
2021 FuseFullThenReshapePass ,
2122 FuseMulIntoDequantPass ,
2223 FuseQuantDequantToRequantizePass ,
23- FuseTransposeOpPairsPass ,
24+ FuseTransposeOrPermuteOpPairsPass ,
2425)
2526from executorch .backends .cadence .aot .graph_builder import GraphBuilder
2627from executorch .backends .cadence .aot .pass_utils import count_node , op_counts_match
@@ -509,6 +510,24 @@ def test_fuse_then_transpose_pass(self):
509510 )
510511
511512
513+ class TestFuseTransposeOrPermuteOpPairsPass (TestFusionPassesBase ):
514+ def _create_operator (
515+ self , builder : GraphBuilder , op : torch ._ops .OpOverload , x : ProxyValue
516+ ) -> ProxyValue :
517+ if op == exir_ops .edge .quantized_decomposed .quantize_per_tensor .default :
518+ return builder .call_operator (
519+ op = op ,
520+ args = (x , 1.2 , 3 , 0 , 127 , torch .int8 ),
521+ )
522+ elif op == exir_ops .edge .cadence .quantized_relu .per_tensor :
523+ return builder .call_operator (
524+ op = op ,
525+ args = (x , 0 , 0 , 0 , 0 ),
526+ )
527+ else :
528+ raise ValueError (f"Unsupported op: { op } " )
529+
530+
512531class TestFuseTransposeOpPairsPass (TestFusionPassesBase ):
513532 def _create_operator (
514533 self , builder : GraphBuilder , op : torch ._ops .OpOverload , x : ProxyValue
@@ -528,83 +547,68 @@ def _create_operator(
528547
529548 @parameterized .expand (
530549 [
531- exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
532- exir_ops .edge .cadence .quantized_relu .per_tensor ,
550+ # transpose -> quant -> same transpose => fuse
551+ (True , [0 , 1 ], True , [0 , 1 ], exir_ops .edge .quantized_decomposed .quantize_per_tensor .default , True ),
552+ # transpose -> quant -> same transpose => fuse (same with transpose dimensions in different order, and with different skip quant op)
553+ (True , [0 , 1 ], True , [1 , 0 ], exir_ops .edge .cadence .quantized_relu .per_tensor , True ),
554+ # transpose -> quant -> different transpose => don't fuse
555+ (True , [0 , 1 ], True , [0 , 2 ], exir_ops .edge .quantized_decomposed .quantize_per_tensor .default , False ),
556+ # permutation -> quant -> opposite permutation => fuse
557+ (False , [1 , 2 , 0 ], False , [2 , 0 , 1 ], exir_ops .edge .quantized_decomposed .quantize_per_tensor .default , True ),
558+ # permutation -> quant -> not the opposite permutation => don't fuse
559+ (False , [1 , 2 , 0 ], False , [1 , 2 , 0 ], exir_ops .edge .quantized_decomposed .quantize_per_tensor .default , False ),
560+ # transpose -> quant -> transpose as a permutation => fuse
561+ (True , [0 , 1 ], False , [1 , 0 , 2 ], exir_ops .edge .quantized_decomposed .quantize_per_tensor .default , True ),
562+ # transpose -> quant -> not opposite permutation => fuse
563+ (True , [0 , 1 ], False , [0 , 2 , 1 ], exir_ops .edge .quantized_decomposed .quantize_per_tensor .default , False ),
533564 ],
534565 )
535- def test_fuse_transpose_pairs (self , op : torch ._ops .OpOverload ):
536- # Create a graph with transpose -> quant -> transpose.
566+ def test_fuse_transpose_permute_pairs (self , is_op1_transpose : bool , perm1 : List [ int ], is_op2_transpose : bool , perm2 : List [ int ], quant_op : torch ._ops .OpOverload , expected_is_fused : bool ):
567+ # Create a graph with transpose/permute -> quant -> transpose/permute.
537568 builder = GraphBuilder ()
538- x = builder .placeholder ("x" , torch .randn (2 , 3 ))
539- transpose_node = builder .call_operator (
540- op = exir_ops .edge .aten .transpose_copy .int ,
541- args = (x , 0 , 1 ),
542- )
543- quant_node = self ._create_operator (builder , op , transpose_node )
544- transpose_node = builder .call_operator (
545- op = exir_ops .edge .aten .transpose_copy .int ,
546- args = (quant_node , 0 , 1 ),
547- )
548- builder .output ([transpose_node ])
569+ x = builder .placeholder ("x" , torch .randn (2 , 3 , 4 ))
570+ op1 = exir_ops .edge .aten .transpose_copy .int if is_op1_transpose else exir_ops .edge .aten .permute_copy .default
571+ node1 = builder .call_operator (
572+ op = op1 ,
573+ args = (x , perm1 [0 ], perm1 [1 ]) if is_op1_transpose else (x , list (perm1 )),
574+ )
575+ quant_node = self ._create_operator (builder , quant_op , node1 )
576+ op2 = exir_ops .edge .aten .transpose_copy .int if is_op2_transpose else exir_ops .edge .aten .permute_copy .default
577+ node2 = builder .call_operator (
578+ op = op2 ,
579+ args = (quant_node , perm2 [0 ], perm2 [1 ]) if is_op2_transpose else (quant_node , list (perm2 )),
580+ )
581+ builder .output ([node2 ])
549582 gm = builder .get_graph_module ()
583+ exp_counts = {
584+ quant_op : 1 ,
585+ }
586+ exp_counts [op1 ] = 1
587+ exp_counts [op2 ] = exp_counts .get (op2 , 0 ) + 1
550588 self .check_op_counts (
551589 gm ,
552- expected_op_counts = {
553- exir_ops .edge .aten .transpose_copy .int : 2 ,
554- op : 1 ,
555- },
590+ # pyre-fixme[6]: Incompatible parameter type
591+ expected_op_counts = exp_counts
556592 )
557593
558- # Check that the pass fuses the two transpose ops.
559- fusion_pass_result = FuseTransposeOpPairsPass ()(gm )
594+ # Check that the pass fuses the two transpose/permute ops.
595+ fusion_pass_result = FuseTransposeOrPermuteOpPairsPass ()(gm )
560596 self .assertIsNotNone (fusion_pass_result )
561597 gm_after_pass = fusion_pass_result .graph_module
598+ if expected_is_fused :
599+ exp_counts [op1 ] = 0
600+ exp_counts [op2 ] = 0
562601 self .check_op_counts (
563602 gm_after_pass ,
564- expected_op_counts = {
565- exir_ops .edge .aten .transpose_copy .int : 0 ,
566- op : 1 ,
567- },
568- )
569-
570- def test_no_fusion_for_transpose_pairs (self ):
571- # Create a graph with transpose -> quant -> transpose.
572- builder = GraphBuilder ()
573- x = builder .placeholder ("x" , torch .randn (2 , 3 , 4 ))
574- transpose_node = builder .call_operator (
575- op = exir_ops .edge .aten .transpose_copy .int ,
576- args = (x , 0 , 1 ),
577- )
578- quant_node = builder .call_operator (
579- op = exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
580- args = (transpose_node , 1.2 , 3 , 0 , 127 , torch .int8 ),
581- )
582- transpose_node = builder .call_operator (
583- op = exir_ops .edge .aten .transpose_copy .int ,
584- args = (quant_node , 1 , 2 ),
585- )
586- builder .output (transpose_node )
587- gm = builder .get_graph_module ()
588- self .check_op_counts (
589- gm ,
590- expected_op_counts = {
591- exir_ops .edge .aten .transpose_copy .int : 2 ,
592- exir_ops .edge .quantized_decomposed .quantize_per_tensor .default : 1 ,
593- },
594- )
595-
596- # No fusion.
597- gm_after_pass = FuseTransposeOpPairsPass ()(gm ).graph_module
598- self .check_op_counts (
599- gm_after_pass ,
600- expected_op_counts = {
601- exir_ops .edge .aten .transpose_copy .int : 2 ,
602- exir_ops .edge .quantized_decomposed .quantize_per_tensor .default : 1 ,
603- },
603+ # pyre-fixme[6]: Incompatible parameter type
604+ expected_op_counts = exp_counts ,
604605 )
605606
606607 def test_fusion_for_forked_transposes (self ):
607- # Create a graph with transpose -> quant -> transpose.
608+ # Create a graph with
609+ # transpose -> quant -> transpose.
610+ # -> quant -> transpose.
611+ # -> quant -> transpose.
608612 builder = GraphBuilder ()
609613 x = builder .placeholder ("x" , torch .randn (2 , 3 , 4 , dtype = torch .float32 ))
610614 transpose_node = builder .call_operator (
@@ -634,8 +638,8 @@ def test_fusion_for_forked_transposes(self):
634638 },
635639 )
636640
637- # Fuse the all the transpose ops.
638- gm_after_pass = FuseTransposeOpPairsPass ()(gm ).graph_module
641+ # Fuse all the transpose ops.
642+ gm_after_pass = FuseTransposeOrPermuteOpPairsPass ()(gm ).graph_module
639643 self .check_op_counts (
640644 gm_after_pass ,
641645 expected_op_counts = {
0 commit comments