 import torch
 from executorch.backends.cadence.aot import compiler
 from executorch.backends.cadence.aot.fuse_ops import (
+    FuseCascadedViewOps,
     FuseFullThenReshapePass,
+    FuseMMWithAdd,
     FuseMulScalarIntoDequantPass,
     FuseMulTensorIntoDequantPass,
     FuseQuantDequantToRequantizePass,
@@ -39,113 +41,133 @@ def check_op_counts(


 class TestFusionPasses(TestFusionPassesBase):
-    def test_addmm_fusion(self):
-        class AddmmFeasible1(torch.nn.Module):
-            def forward(self, x, y, z):
-                t1 = torch.mm(x, y)
-                return torch.add(t1, z)
-
-        x = torch.randn(3, 5)
-        y = torch.randn(5, 6)
-        z = torch.randn(6)
-
-        graph_module = (
-            compiler.export_to_cadence(AddmmFeasible1(), (x, y, z))
-            .exported_program()
-            .graph_module
+    def test_fuse_mm_with_add(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
+        z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y),
+        )
+        output = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z))
+        builder.output([output])
+        original_graph = builder.get_graph_module()
+        converted_graph = FuseMMWithAdd()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
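+        # Assert that mm and add were fused to addmm.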
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1
         )
-        graph_module.graph.eliminate_dead_code()
-
-        # Assert that mm and add were fused to addmm
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
-
-        class AddmmFeasible2(torch.nn.Module):
-            def forward(self, x, y, z):
-                t1 = y.view((8, 6))
-                t2 = torch.mm(x, t1)
-                t3 = t2.view((2, 2, 6))
-                return torch.add(t3, z)
-
-        x = torch.randn(4, 8)
-        y = torch.randn(2, 4, 6)
-        z = torch.randn(6)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0)

-        graph_module = (
-            compiler.export_to_cadence(AddmmFeasible2(), (x, y, z))
-            .exported_program()
-            .graph_module
+    def test_fuse_view_mm_view_add(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(2, 4, 6, dtype=torch.float32))
+        z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
+        y_view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(y, [8, 6])
         )
-        graph_module.graph.eliminate_dead_code()
-        # Assert that mm and add were fused to addmm
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
-
-        # Bias is a singleton value, broadcastable to output of mm
-        class AddmmFeasible3(torch.nn.Module):
-            def forward(self, x, y):
-                t1 = torch.mm(x, y)
-                return torch.add(t1, torch.ones(1))
-
-        x = torch.randn(3, 5)
-        y = torch.randn(5, 6)
-
-        graph_module = (
-            compiler.export_to_cadence(AddmmFeasible3(), (x, y))
-            .exported_program()
-            .graph_module
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y_view),
+        )
+        mm_view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(mm, [2, 2, 6])
         )
-        graph_module.graph.eliminate_dead_code()
-        # Assert that mm and add were fused to addmm
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor, args=(mm_view, z)
+        )
+        builder.output([output])
+        original_graph = builder.get_graph_module()
+        converted_graph = FuseMMWithAdd()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
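+        # Assert that mm and add were fused to addmm.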
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1
+        )
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0)

+    def test_keep_view_mm_view_add(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(2, 4, 6, dtype=torch.float32))
         # Bias is not broadcastable to output of mm
-        class AddmmInfeasible1(torch.nn.Module):
-            def forward(self, x, y, z):
-                t1 = y.view((8, 6))
-                t2 = torch.mm(x, t1)
-                t3 = t2.view((2, 2, 6))
-                return torch.add(t3, z)
-
-        x = torch.randn(4, 8)
-        y = torch.randn(2, 4, 6)
-        z = torch.randn(2, 2, 1)
-
-        graph_module = (
-            compiler.export_to_cadence(AddmmInfeasible1(), (x, y, z))
-            .exported_program()
-            .graph_module
+        z = builder.placeholder("z", torch.randn(2, 2, 1, dtype=torch.float32))
+        y_view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(y, [8, 6])
+        )
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y_view),
         )
-        graph_module.graph.eliminate_dead_code()
+        mm_view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(mm, [2, 2, 6])
+        )
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor, args=(mm_view, z)
+        )
+        builder.output([output])
+        original_graph = builder.get_graph_module()
+        converted_graph = FuseMMWithAdd()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
         # Assert that mm and add were not fused to addmm, since z cannot be
         # broadcasted to the out of mm.
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 1)
-
-        # The add consuming the output of mm has more than one users.
-        class AddmmInfeasible2(torch.nn.Module):
-            def forward(self, x, y, z):
-                t1 = torch.mm(x, y)
-                t2 = torch.add(t1, z)
-                t3 = torch.add(t2, z)
-                return torch.add(t2, t3)
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0
+        )
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 1)

-        x = torch.randn(3, 5)
-        y = torch.randn(5, 6)
-        z = torch.randn(6)
+    def test_fuse_mm_add_with_bias(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y),
+        )
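+        # Bias is a singleton value, broadcastable to the output of mm.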
+        bias = builder.call_operator(op=exir_ops.edge.aten.full.default, args=([1], 1))
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor, args=(mm, bias)
+        )
+        builder.output([output])
+        original_graph = builder.get_graph_module()
+        converted_graph = FuseMMWithAdd()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
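+        # Assert that mm and add were fused to addmm.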
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1
+        )
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0)

-        graph_module = (
-            compiler.export_to_cadence(AddmmInfeasible2(), (x, y, z))
-            .exported_program()
-            .graph_module
+    def test_keep_mm_add_with_multiple_users(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
+        z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y),
+        )
+        # The add consuming the output of mm has more than one user.
+        add1 = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z))
+        add2 = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(add1, z))
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor, args=(add1, add2)
         )
-        graph_module.graph.eliminate_dead_code()
+        builder.output([output])
+        original_graph = builder.get_graph_module()
+        converted_graph = FuseMMWithAdd()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
         # Assert that mm and add were not fused to addmm, since add has multiple
         # users.
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 3)
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0
+        )
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 3)

     # TODO(matthiascremon): enable that pass with new flow
     @torch.no_grad()
@@ -184,63 +206,70 @@ def forward(self, x):
         )

     def test_permute_transpose_fusion(self):
-        class PermuteTranspose(torch.nn.Module):
-            def forward(self, x):
-                y = x.permute((0, 2, 4, 1, 3))
-                return y.transpose(0, 1)
-
-        x = torch.randn(3, 1, 3, 1, 4)
-        graph_module = (
-            compiler.export_to_cadence(PermuteTranspose(), (x,))
-            .exported_program()
-            .graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4, dtype=torch.float32))
+        permute = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 4, 1, 3])
+        )
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.transpose_copy.int,
+            args=(permute, 0, 1),
         )
-        graph_module.graph.eliminate_dead_code()
+        builder.output([output])
+        original_graph = builder.get_graph_module()
+        # Question: this pass cannot be applied because [0, 2, 4] != [2, 0, 4] in
+        # can_fuse_for_chain. Is this the right pass to use?
+        converted_graph = FuseTransposeOrPermuteOpPairsPass()(
+            original_graph
+        ).graph_module
+        converted_graph.graph.eliminate_dead_code()
         # Assert that permute op was fused with transpose op
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 1
+            count_node(converted_graph, exir_ops.edge.aten.permute_copy.default), 1
         )
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.transpose_copy.int), 0
+            count_node(converted_graph, exir_ops.edge.aten.transpose_copy.int), 0
         )

     def test_view_fusion(self):
-        class ViewFusion(torch.nn.Module):
-            def forward(self, x):
-                x = x.view([1, 8, 15])
-                x = x.view([1, 1, 120])
-                return x.view([1, 12, 10])
-
-        x = torch.randn(8, 5, 3)
-        graph_module = (
-            compiler.export_to_cadence(ViewFusion(), (x,))
-            .exported_program()
-            .graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32))
+        view1 = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15])
+        )
+        view2 = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(view1, [1, 1, 120])
+        )
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(view2, [1, 12, 10])
         )
-        graph_module.graph.eliminate_dead_code()
+        builder.output([output])
+        original_graph = builder.get_graph_module()
+        converted_graph = FuseCascadedViewOps()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
         # Assert that only one view op remains
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.view_copy.default), 1
+            count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 1
         )

     def test_view_fusion_branched(self):
-        class ViewFusion(torch.nn.Module):
-            def forward(self, x):
-                y = x.view([1, 8, 15])
-                z = y.view([1, 1, 120])
-                t = y.view([120, 1, 1])
-                return z, t
-
-        x = torch.randn(8, 5, 3)
-        graph_module = (
-            compiler.export_to_cadence(ViewFusion(), (x,))
-            .exported_program()
-            .graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32))
+        y = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15])
+        )
+        z = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(y, [1, 1, 120])
         )
-        graph_module.graph.eliminate_dead_code()
+        t = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(y, [120, 1, 1])
+        )
+        builder.output([z, t])
+        original_graph = builder.get_graph_module()
+        converted_graph = FuseCascadedViewOps()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
         # z and t should be fused and y should be eliminated.
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.view_copy.default), 2
+            count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 2
         )

     def test_force_quant_dequant_fusion(self):