 import torch
 from executorch.backends.cadence.aot import compiler
 from executorch.backends.cadence.aot.fuse_ops import (
+    FuseCascadedTransposeOrPermuteOps,
+    FuseCascadedViewOps,
     FuseFullThenReshapePass,
+    FuseMMWithAdd,
     FuseMulScalarIntoDequantPass,
     FuseMulTensorIntoDequantPass,
     FuseQuantDequantToRequantizePass,
@@ -39,113 +42,133 @@ def check_op_counts(
 
 
 class TestFusionPasses(TestFusionPassesBase):
-    def test_addmm_fusion(self):
-        class AddmmFeasible1(torch.nn.Module):
-            def forward(self, x, y, z):
-                t1 = torch.mm(x, y)
-                return torch.add(t1, z)
-
-        x = torch.randn(3, 5)
-        y = torch.randn(5, 6)
-        z = torch.randn(6)
-
-        graph_module = (
-            compiler.export_to_cadence(AddmmFeasible1(), (x, y, z))
-            .exported_program()
-            .graph_module
+    def test_fuse_mm_with_add(self):
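+        # mm(x, y) followed by add(..., z) is exactly the pattern FuseMMWithAdd
+        # targets: addmm(z, x, y) computes z + x @ y in a single op.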
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
+        z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y),
+        )
+        output = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z))
+        builder.output([output])
+        original_graph = builder.get_graph_module()
+        converted_graph = FuseMMWithAdd()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1
         )
-        graph_module.graph.eliminate_dead_code()
-
-        # Assert that mm and add were fused to addmm
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
-
-        class AddmmFeasible2(torch.nn.Module):
-            def forward(self, x, y, z):
-                t1 = y.view((8, 6))
-                t2 = torch.mm(x, t1)
-                t3 = t2.view((2, 2, 6))
-                return torch.add(t3, z)
-
-        x = torch.randn(4, 8)
-        y = torch.randn(2, 4, 6)
-        z = torch.randn(6)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0)
 
-        graph_module = (
-            compiler.export_to_cadence(AddmmFeasible2(), (x, y, z))
-            .exported_program()
-            .graph_module
+    def test_fuse_view_mm_view_add(self):
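+        # The mm is sandwiched between view_copy ops; fusion should still yield
+        # addmm, since z (shape [6]) broadcasts against the mm output either way.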
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(2, 4, 6, dtype=torch.float32))
+        z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
+        y_view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(y, [8, 6])
         )
-        graph_module.graph.eliminate_dead_code()
-        # Assert that mm and add were fused to addmm
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
-
-        # Bias is a singleton value, broadcastable to output of mm
-        class AddmmFeasible3(torch.nn.Module):
-            def forward(self, x, y):
-                t1 = torch.mm(x, y)
-                return torch.add(t1, torch.ones(1))
-
-        x = torch.randn(3, 5)
-        y = torch.randn(5, 6)
-
-        graph_module = (
-            compiler.export_to_cadence(AddmmFeasible3(), (x, y))
-            .exported_program()
-            .graph_module
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y_view),
+        )
+        mm_view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(mm, [2, 2, 6])
         )
-        graph_module.graph.eliminate_dead_code()
-        # Assert that mm and add were fused to addmm
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor, args=(mm_view, z)
+        )
+        builder.output([output])
+        original_graph = builder.get_graph_module()
+        converted_graph = FuseMMWithAdd()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1
+        )
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0)
 
+    def test_keep_view_mm_view_add(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(2, 4, 6, dtype=torch.float32))
         # Bias is not broadcastable to the output of mm
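+        # z below is (2, 2, 1): it broadcasts against the viewed (2, 2, 6) result,
+        # but not against mm's raw (4, 6) output, so folding it into addmm would
+        # change semantics and the pass is expected to leave the graph alone.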
-        class AddmmInfeasible1(torch.nn.Module):
-            def forward(self, x, y, z):
-                t1 = y.view((8, 6))
-                t2 = torch.mm(x, t1)
-                t3 = t2.view((2, 2, 6))
-                return torch.add(t3, z)
-
-        x = torch.randn(4, 8)
-        y = torch.randn(2, 4, 6)
-        z = torch.randn(2, 2, 1)
-
-        graph_module = (
-            compiler.export_to_cadence(AddmmInfeasible1(), (x, y, z))
-            .exported_program()
-            .graph_module
+        z = builder.placeholder("z", torch.randn(2, 2, 1, dtype=torch.float32))
+        y_view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(y, [8, 6])
+        )
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y_view),
         )
-        graph_module.graph.eliminate_dead_code()
+        mm_view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(mm, [2, 2, 6])
+        )
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor, args=(mm_view, z)
+        )
+        builder.output([output])
+        original_graph = builder.get_graph_module()
+        converted_graph = FuseMMWithAdd()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
         # Assert that mm and add were not fused to addmm, since z cannot be
         # broadcast to the output of mm.
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 1)
-
-        # The add consuming the output of mm has more than one users.
-        class AddmmInfeasible2(torch.nn.Module):
-            def forward(self, x, y, z):
-                t1 = torch.mm(x, y)
-                t2 = torch.add(t1, z)
-                t3 = torch.add(t2, z)
-                return torch.add(t2, t3)
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0
+        )
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 1)
 
-        x = torch.randn(3, 5)
-        y = torch.randn(5, 6)
-        z = torch.randn(6)
+    def test_fuse_mm_add_with_bias(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y),
+        )
+        bias = builder.call_operator(op=exir_ops.edge.aten.full.default, args=([1], 1))
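+        # full([1], 1) produces a single-element bias; a singleton broadcasts
+        # against the (3, 6) mm output, so fusion into addmm should be legal here.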
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor, args=(mm, bias)
+        )
+        builder.output([output])
+        original_graph = builder.get_graph_module()
+        converted_graph = FuseMMWithAdd()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1
+        )
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0)
 
-        graph_module = (
-            compiler.export_to_cadence(AddmmInfeasible2(), (x, y, z))
-            .exported_program()
-            .graph_module
+    def test_keep_mm_add_with_multiple_users(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
+        z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y),
+        )
+        # The add consuming the output of mm has more than one user.
+        add1 = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z))
+        add2 = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(add1, z))
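+        # add1 feeds both add2 and the final add, so the pass is expected to be
+        # conservative and keep mm and add separate rather than fuse into addmm.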
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor, args=(add1, add2)
         )
-        graph_module.graph.eliminate_dead_code()
+        builder.output([output])
+        original_graph = builder.get_graph_module()
+        converted_graph = FuseMMWithAdd()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
         # Assert that mm and add were not fused to addmm, since add has multiple
         # users.
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 3)
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0
+        )
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 3)
 
     # TODO(matthiascremon): enable that pass with new flow
     @torch.no_grad()
@@ -184,63 +207,69 @@ def forward(self, x):
         )
 
     def test_permute_transpose_fusion(self):
-        class PermuteTranspose(torch.nn.Module):
-            def forward(self, x):
-                y = x.permute((0, 2, 4, 1, 3))
-                return y.transpose(0, 1)
-
-        x = torch.randn(3, 1, 3, 1, 4)
-        graph_module = (
-            compiler.export_to_cadence(PermuteTranspose(), (x,))
-            .exported_program()
-            .graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4, dtype=torch.float32))
+        permute = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 4, 1, 3])
+        )
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.transpose_copy.int,
+            args=(permute, 1, 0),
        )
-        graph_module.graph.eliminate_dead_code()
+        builder.output([output])
+        original_graph = builder.get_graph_module()
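+        # permute [0, 2, 4, 1, 3] followed by transpose(1, 0) composes into the
+        # single permutation [2, 0, 4, 1, 3], so one permute_copy should remain.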
+        converted_graph = FuseCascadedTransposeOrPermuteOps()(
+            original_graph
+        ).graph_module
+        converted_graph.graph.eliminate_dead_code()
         # Assert that permute op was fused with transpose op
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 1
+            count_node(converted_graph, exir_ops.edge.aten.permute_copy.default), 1
         )
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.transpose_copy.int), 0
+            count_node(converted_graph, exir_ops.edge.aten.transpose_copy.int), 0
         )
 
     def test_view_fusion(self):
-        class ViewFusion(torch.nn.Module):
-            def forward(self, x):
-                x = x.view([1, 8, 15])
-                x = x.view([1, 1, 120])
-                return x.view([1, 12, 10])
-
-        x = torch.randn(8, 5, 3)
-        graph_module = (
-            compiler.export_to_cadence(ViewFusion(), (x,))
-            .exported_program()
-            .graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32))
+        view1 = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15])
+        )
+        view2 = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(view1, [1, 1, 120])
+        )
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(view2, [1, 12, 10])
         )
-        graph_module.graph.eliminate_dead_code()
+        builder.output([output])
+        original_graph = builder.get_graph_module()
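+        # All three views reshape the same 120 elements, so the cascade should
+        # collapse into one view_copy producing the final [1, 12, 10] shape.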
+        converted_graph = FuseCascadedViewOps()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
         # Assert that only one view op remains
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.view_copy.default), 1
+            count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 1
         )
 
     def test_view_fusion_branched(self):
-        class ViewFusion(torch.nn.Module):
-            def forward(self, x):
-                y = x.view([1, 8, 15])
-                z = y.view([1, 1, 120])
-                t = y.view([120, 1, 1])
-                return z, t
-
-        x = torch.randn(8, 5, 3)
-        graph_module = (
-            compiler.export_to_cadence(ViewFusion(), (x,))
-            .exported_program()
-            .graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32))
+        y = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15])
+        )
+        z = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(y, [1, 1, 120])
         )
-        graph_module.graph.eliminate_dead_code()
+        t = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(y, [120, 1, 1])
+        )
+        builder.output([z, t])
+        original_graph = builder.get_graph_module()
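+        # y has two view consumers; each branch should fuse into a single
+        # view_copy reading directly from x, leaving two view_copy nodes
+        # and eliminating y.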
+        converted_graph = FuseCascadedViewOps()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
         # z and t should be fused and y should be eliminated.
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.view_copy.default), 2
+            count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 2
         )
 
     def test_force_quant_dequant_fusion(self):