1515
1616class FuseConsecutiveTranspose (ExportPass ):
1717 """
18- This pass fuses consecutive transpose / permute into one or none to reduce runtime
19- overhead.
20- To simplify the fuse logic, we ensure each permute node's output has at most 1 permute node
21- by cloning transpose.
22- Example:
23- Before clone transpose:
24- relu -> permute1 ─> permute2
25- |──────> permute3
26-
27- After clone transpose:
28- relu ─> permute1 ──────> permute2
29- |───> permute4(new) ─> permute3
18+ This pass fuses consecutive transpose / permute into one to reduce runtime
19+ overhead
3020 """
3121
3222 def __init__ (self ):
@@ -37,81 +27,54 @@ def __init__(self):
3727 self .visited = set ()
3828 self .nodes = []
3929
40- def _clone_transpose (
41- self , graph_module : torch .fx .GraphModule
42- ) -> torch .fx .GraphModule :
43- graph = graph_module .graph
44- for n in graph_module .graph .nodes :
45- if n .target in self .op_map :
46- users = [user for user in list (n .users ) if user .target in self .op_map ]
47- if len (users ) > 1 :
48- for i in range (1 , len (users )):
49- with graph .inserting_after (n ):
50- clone_permute_node = graph .create_node (
51- "call_function" ,
52- exir_ops .edge .aten .permute_copy .default ,
53- (n .args [0 ], n .args [1 ]),
54- )
55- clone_permute_node .meta = n .meta
56- users [i ].replace_input_with (n , clone_permute_node )
57-
58- def _is_dispensable (self , axis_order ):
59- for index , value in enumerate (axis_order ):
60- if index != value :
61- return False
62- return True
63-
6430 def _traverse (self , node ):
6531 if node in self .visited or node .target not in self .op_map :
6632 return
6733
6834 self .nodes .append (node )
6935 self .visited .add (node )
7036 next_users = [n for n in list (node .users ) if n .target in self .op_map ]
71-
72- assert (
73- len (next_users ) <= 1
74- ), "Each permute node should have at most 1 permute output node after _clone_transpose"
7537 if not next_users :
7638 return
77- else :
39+
40+ if len (next_users ) == 1 :
7841 self ._traverse (list (node .users )[0 ])
42+ else :
43+ raise NotImplementedError (
44+ f"Check the node { node } , wich encounter mutilple permute output case"
45+ )
7946
8047 def _fuse (self , graph_module : torch .fx .GraphModule ) -> torch .fx .GraphModule :
8148 graph = graph_module .graph
8249 for n in graph_module .graph .nodes :
8350 self ._traverse (n )
8451 if len (self .nodes ) > 1 :
52+ permute_order = []
8553 input_node , output_node = self .nodes [0 ].args [0 ], self .nodes [- 1 ]
8654 input_shape = input_node .meta ["val" ].shape
8755 axis_order = torch .arange (len (input_shape )).tolist ()
8856 for node in self .nodes :
57+ permute_order .append (node .args [1 ])
8958 axis_order = [axis_order [i ] for i in node .args [1 ]]
90- # If axis order is just [0,1,2,3], we ignore permute node
91- if self ._is_dispensable (axis_order ):
92- for user in output_node .users .copy ():
93- user .replace_input_with (output_node , n .args [0 ])
94- else :
95- with graph .inserting_after (input_node ):
96- permute_op = exir_ops .edge .aten .permute_copy .default
97- permute_node = graph .create_node (
98- "call_function" , permute_op , (input_node , axis_order )
99- )
100- users = output_node .users .copy ()
101- for user in users :
102- user .replace_input_with (output_node , permute_node )
103-
104- # copy metadata
105- permute_node .meta = output_node .meta
106- # Without "qnn_permute", we might obtain wrong input shape
107- if [pn .meta .get (QCOM_INSERTED_PERMUTE ) for pn in self .nodes ]:
108- permute_node .meta [QCOM_INSERTED_PERMUTE ] = True
59+ with graph .inserting_after (input_node ):
60+ permute_op = exir_ops .edge .aten .permute_copy .default
61+ permute_node = graph .create_node (
62+ "call_function" , permute_op , (input_node , axis_order )
63+ )
64+ users = output_node .users .copy ()
65+ for user in users :
66+ user .replace_input_with (output_node , permute_node )
67+
68+ # copy metadata
69+ permute_node .meta = output_node .meta
70+ # Without "qnn_permute", we might obtain wrong input shape
71+ if [pn .meta .get (QCOM_INSERTED_PERMUTE ) for pn in self .nodes ]:
72+ permute_node .meta [QCOM_INSERTED_PERMUTE ] = True
10973
11074 # clear current stack
11175 self .nodes = []
11276
11377 def call (self , graph_module : torch .fx .GraphModule ):
114- self ._clone_transpose (graph_module )
11578 self ._fuse (graph_module )
11679 graph_module .recompile ()
11780 dead_code_elimination_pass (graph_module )
0 commit comments