1313import executorch .backends .cadence .aot .ops_registrations # noqa
1414import torch
1515import torch .nn as nn
16- import torch .nn .functional as F
17- from executorch .backends .cadence .aot import compiler
1816from executorch .backends .cadence .aot .compiler import export_to_edge
1917from executorch .backends .cadence .aot .fuse_ops import FuseQuantDequantToRequantizePass
2018from executorch .backends .cadence .aot .graph_builder import GraphBuilder
@@ -53,16 +51,15 @@ class TestRemoveOpsPasses(unittest.TestCase):
5351 )
5452 @torch .no_grad ()
5553 def test_remove_to_ops (self , shape : Tuple [int ]):
56- class M (torch .nn .Module ):
57- def forward (self , x : torch .Tensor ):
58- return exir_ops .edge .aten .to (x , dtype = torch .float32 )
59-
60- model = M ()
61- x = torch .randn (shape )
62- graph_module = export_to_edge (model , (x ,)).exported_program ().graph_module
63- p = RemoveToOpsPass ()
64-
65- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
54+ builder = GraphBuilder ()
55+ x = builder .placeholder ("x" , torch .randn (* shape , dtype = torch .float32 ))
56+ x = builder .call_operator (
57+ op = exir_ops .edge .aten .to .dtype ,
58+ args = (x , torch .float32 ),
59+ )
60+ builder .output ([x ])
61+ original = builder .get_graph_module ()
62+ graph_after_passes = cast (PassResult , RemoveToOpsPass ()(original )).graph_module
6663
6764 self .assertEqual (
6865 count_node (graph_after_passes , exir_ops .edge .aten .to .dtype ),
@@ -83,31 +80,24 @@ def forward(self, x: torch.Tensor):
8380 )
8481 @torch .no_grad ()
8582 def test_remove_nop_add_op_pass (self , shape : Tuple [int ]):
86- class FullX (torch .nn .Module ):
87- def forward (self , t : torch .Tensor ):
88- return torch .add (torch .full (shape , 0 ), t )
89-
90- class FullY (torch .nn .Module ):
91- def forward (self , t : torch .Tensor ):
92- return torch .add (t , torch .full (shape , 0 ))
93-
94- model = FullX ()
95- t = torch .full (shape , 3 )
96- graph_module = export_to_edge (model , (t ,)).exported_program ().graph_module
97-
98- p = RemoveNopAddOpPass ()
99-
100- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
101- self .assertEqual (
102- count_node (graph_after_passes , exir_ops .edge .aten .add .Tensor ),
103- 0 ,
104- )
105-
106- model = FullY ()
107- graph_module = export_to_edge (model , (t ,)).exported_program ().graph_module
108-
109- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
110-
83+ builder = GraphBuilder ()
84+ x = builder .placeholder ("x" , torch .randn (* shape , dtype = torch .float32 ))
85+ zeros = builder .call_operator (
86+ op = exir_ops .edge .aten .full .default , args = (shape , 0 )
87+ )
88+ left_add = builder .call_operator (
89+ op = exir_ops .edge .aten .add .Tensor ,
90+ args = (zeros , x ),
91+ )
92+ right_add = builder .call_operator (
93+ op = exir_ops .edge .aten .add .Tensor ,
94+ args = (left_add , zeros ),
95+ )
96+ builder .output ([right_add ])
97+ original = builder .get_graph_module ()
98+ graph_after_passes = cast (
99+ PassResult , RemoveNopAddOpPass ()(original )
100+ ).graph_module
111101 self .assertEqual (
112102 count_node (graph_after_passes , exir_ops .edge .aten .add .Tensor ),
113103 0 ,
@@ -122,31 +112,24 @@ def forward(self, t: torch.Tensor):
122112 )
123113 @torch .no_grad ()
124114 def test_remove_nop_mul_op_pass (self , shape : Tuple [int ]):
125- class FullX (torch .nn .Module ):
126- def forward (self , t : torch .Tensor ):
127- return torch .mul (torch .full (shape , 0 ), t )
128-
129- class FullY (torch .nn .Module ):
130- def forward (self , t : torch .Tensor ):
131- return torch .mul (t , torch .full (shape , 0 ))
132-
133- model = FullX ()
134- t = torch .full (shape , 3 )
135- graph_module = export_to_edge (model , (t ,)).exported_program ().graph_module
136-
137- p = RemoveNopMulOpPass ()
138-
139- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
140- self .assertEqual (
141- count_node (graph_after_passes , exir_ops .edge .aten .mul .Tensor ),
142- 0 ,
143- )
144-
145- model = FullY ()
146- graph_module = export_to_edge (model , (t ,)).exported_program ().graph_module
147-
148- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
149-
115+ builder = GraphBuilder ()
116+ x = builder .placeholder ("x" , torch .randn (* shape , dtype = torch .float32 ))
117+ zeros = builder .call_operator (
118+ op = exir_ops .edge .aten .full .default , args = (shape , 0 )
119+ )
120+ left_mul = builder .call_operator (
121+ op = exir_ops .edge .aten .mul .Tensor ,
122+ args = (zeros , x ),
123+ )
124+ right_mul = builder .call_operator (
125+ op = exir_ops .edge .aten .mul .Tensor ,
126+ args = (left_mul , zeros ),
127+ )
128+ builder .output ([right_mul ])
129+ original = builder .get_graph_module ()
130+ graph_after_passes = cast (
131+ PassResult , RemoveNopMulOpPass ()(original )
132+ ).graph_module
150133 self .assertEqual (
151134 count_node (graph_after_passes , exir_ops .edge .aten .mul .Tensor ),
152135 0 ,
@@ -159,18 +142,16 @@ def forward(self, t: torch.Tensor):
159142 )
160143 @torch .no_grad ()
161144 def test_remove_alias_copy (self , shape : Tuple [int ]):
162- class M (torch .nn .Module ):
163- def forward (self , x : torch .Tensor ):
164- return exir_ops .edge .aten .alias_copy (x )
165-
166- model = M ()
167- x = torch .randn (shape )
168- graph_module = export_to_edge (model , (x ,)).exported_program ().graph_module
169-
170- p = RemoveAliasCopyOpPass ()
171-
172- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
173-
145+ builder = GraphBuilder ()
146+ x = builder .placeholder ("x" , torch .randn (* shape , dtype = torch .float32 ))
147+ alias = builder .call_operator (
148+ op = exir_ops .edge .aten .alias_copy .default , args = (x ,)
149+ )
150+ builder .output ([alias ])
151+ original = builder .get_graph_module ()
152+ graph_after_passes = cast (
153+ PassResult , RemoveAliasCopyOpPass ()(original )
154+ ).graph_module
174155 self .assertEqual (
175156 count_node (graph_after_passes , exir_ops .edge .aten .alias_copy .default ),
176157 0 ,
@@ -183,19 +164,16 @@ def forward(self, x: torch.Tensor):
183164 )
184165 @torch .no_grad ()
185166 def test_remove_detach_copy (self , shape : Tuple [int ]):
186- # aten::detach is converted to aten::alias_copy after functionalization & decomposition.
187- class M (torch .nn .Module ):
188- def forward (self , x : torch .Tensor ):
189- return exir_ops .edge .aten .detach_copy (x )
190-
191- model = M ()
192- x = torch .randn (shape )
193- graph_module = export_to_edge (model , (x ,)).exported_program ().graph_module
194-
195- p = RemoveDetachCopyPass ()
196-
197- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
198-
167+ builder = GraphBuilder ()
168+ x = builder .placeholder ("x" , torch .randn (* shape , dtype = torch .float32 ))
169+ detach = builder .call_operator (
170+ op = exir_ops .edge .aten .detach_copy .default , args = (x ,)
171+ )
172+ builder .output ([detach ])
173+ original = builder .get_graph_module ()
174+ graph_after_passes = cast (
175+ PassResult , RemoveDetachCopyPass ()(original )
176+ ).graph_module
199177 self .assertEqual (
200178 count_node (graph_after_passes , exir_ops .edge .aten .detach_copy .default ),
201179 0 ,
@@ -210,95 +188,51 @@ def forward(self, x: torch.Tensor):
210188 def test_remove_zero_sized_constant_pad_nd (
211189 self , shape : Tuple [int ], padding : Tuple [int ]
212190 ):
213- # F.pad is converted to aten::constant_pad_nd after functionalization & decomposition.
214- class Padding (torch .nn .Module ):
215- def __init__ (self ):
216- super ().__init__ ()
217- self .padding = padding
218-
219- def forward (self , x : torch .Tensor ):
220- return F .pad (x , self .padding )
221-
222- model = Padding ()
223- x = torch .randn (shape )
224- graph_module = export_to_edge (model , (x ,)).exported_program ().graph_module
225-
226- p = RemoveZeroSizedConstantPadNd ()
227-
228- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
229-
191+ builder = GraphBuilder ()
192+ x = builder .placeholder ("x" , torch .randn (* shape , dtype = torch .float32 ))
193+ pad = builder .call_operator (
194+ op = exir_ops .edge .aten .constant_pad_nd .default , args = (x , padding )
195+ )
196+ builder .output ([pad ])
197+ original = builder .get_graph_module ()
198+ graph_after_passes = cast (
199+ PassResult , RemoveZeroSizedConstantPadNd ()(original )
200+ ).graph_module
230201 self .assertEqual (
231202 count_node (graph_after_passes , exir_ops .edge .aten .constant_pad_nd .default ),
232203 0 ,
233204 )
234205
235206 def test_remove_expand (self ):
236- class Expand (torch .nn .Module ):
237- def forward (self , x ):
238- return torch .ops .aten .expand_copy (x , [2 , 3 , 5 ])
239-
240- x = torch .ones (2 , 3 , 5 )
241- p = RemoveNopExpandOpPass ()
242- graph_module = export_to_edge (Expand (), (x ,)).exported_program ().graph_module
243- graph_module = p (graph_module ).graph_module
244- # Assert that expand op is optimized away, since it is a nop
207+ builder = GraphBuilder ()
208+ x = builder .placeholder ("x" , torch .randn ([2 , 3 , 5 ], dtype = torch .float32 ))
209+ expand = builder .call_operator (
210+ op = exir_ops .edge .aten .expand_copy .default , args = (x , [2 , 3 , 5 ])
211+ )
212+ builder .output ([expand ])
213+ original = builder .get_graph_module ()
214+ graph_after_passes = cast (
215+ PassResult , RemoveNopExpandOpPass ()(original )
216+ ).graph_module
245217 self .assertEqual (
246- count_node (graph_module , exir_ops .edge .aten .expand_copy .default ), 0
218+ count_node (graph_after_passes , exir_ops .edge .aten .expand_copy .default ), 0
247219 )
248220
249221 def test_remove_zero_arg_cat (self ):
250- class Cat (torch .nn .Module ):
251- def forward (self , x , y ):
252- return torch .ops .aten .cat ((x , y ), 0 )
253-
254- x = torch .ones (1 , 0 , 3 , 5 )
255- y = torch .ones (2 , 0 , 3 , 5 )
256- graph_module = (
257- compiler .export_to_cadence (Cat (), (x , y )).exported_program ().graph_module
258- )
259- # Assert that cat op is optimized away, since it concatenates
260- # two zero-sized tensors
261- self .assertEqual (count_node (graph_module , exir_ops .edge .aten .cat .default ), 0 )
262-
263- def test_remove_single_arg_cat (self ):
264- class Cat (torch .nn .Module ):
265- def forward (self , x , y ):
266- z = torch .ones (0 , 5 )
267- # z is an empty tensor, and concatenation of x with z will
268- # be x. So we can safely eliminate the following cat op.
269- x1 = torch .ops .aten .cat ((x , z ))
270- x2 = torch .add (x1 , 2.4 , 3.1 )
271- y1 = torch .add (y , 1 , 2 )
272- return torch .add (x2 , y1 )
273-
274- x = torch .ones (3 , 5 )
275- y = torch .ones (3 , 5 )
276- graph_module = export_to_edge (Cat (), (x , y )).exported_program ().graph_module
277- new_graph_module = RemoveZeroSizedCatArgsPass ()(graph_module ).graph_module
278- new_graph_module .graph .eliminate_dead_code ()
279- # Assert that x1 is optimized away
280- self .assertEqual (count_node (new_graph_module , torch .ops .aten .cat .out ), 0 )
281-
282- def test_remove_zero_sized_cat (self ):
283- class Cat (torch .nn .Module ):
284- def __init__ (self , dim : int ):
285- super ().__init__ ()
286- self .dim = dim
287-
288- def forward (self , tensors ):
289- return torch .cat (tensors , self .dim )
290-
291- shapes , dim , dtype , _max = [(1 , 0 , 3 ), (2 , 0 , 3 )], 0 , torch .float32 , 127
292-
293- in_tensors = [(torch .rand (shape ) * _max ).to (dtype = dtype ) for shape in shapes ]
294-
295- model = Cat (dim )
296- graph_module = (
297- export_to_edge (model , (in_tensors ,)).exported_program ().graph_module
222+ builder = GraphBuilder ()
223+ x = builder .placeholder ("x" , torch .randn ([1 , 0 , 3 , 5 ], dtype = torch .float32 ))
224+ y = builder .placeholder ("y" , torch .randn ([2 , 0 , 3 , 5 ], dtype = torch .float32 ))
225+ concat = builder .call_operator (
226+ op = exir_ops .edge .aten .cat .default , args = ([x , y ], 0 )
227+ )
228+ builder .output ([concat ])
229+ original = builder .get_graph_module ()
230+ graph_after_passes = cast (
231+ PassResult , RemoveZeroSizedCatArgsPass ()(original )
232+ ).graph_module
233+ self .assertEqual (
234+ count_node (graph_after_passes , exir_ops .edge .aten .cat .default ), 0
298235 )
299- new_graph_module = RemoveZeroSizedCatArgsPass ()(graph_module ).graph_module
300- new_graph_module .graph .eliminate_dead_code ()
301- self .assertEqual (count_node (new_graph_module , torch .ops .aten .cat .out ), 0 )
302236
303237 def test_remove_clone (self ):
304238 class Clone (torch .nn .Module ):
0 commit comments