 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-from executorch.backends.cadence.aot import compiler
 from executorch.backends.cadence.aot.compiler import export_to_edge
 from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass
 from executorch.backends.cadence.aot.graph_builder import GraphBuilder
@@ -53,16 +51,15 @@ class TestRemoveOpsPasses(unittest.TestCase):
     )
     @torch.no_grad()
     def test_remove_to_ops(self, shape: Tuple[int]):
-        class M(torch.nn.Module):
-            def forward(self, x: torch.Tensor):
-                return exir_ops.edge.aten.to(x, dtype=torch.float32)
-
-        model = M()
-        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-        p = RemoveToOpsPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
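+        # Casting to the dtype the input already has is a nop, so the pass should remove it.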
+        x = builder.call_operator(
+            op=exir_ops.edge.aten.to.dtype,
+            args=(x, torch.float32),
+        )
+        builder.output([x])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(PassResult, RemoveToOpsPass()(original)).graph_module
 
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.to.dtype),
@@ -83,31 +80,24 @@ def forward(self, x: torch.Tensor):
     )
     @torch.no_grad()
     def test_remove_nop_add_op_pass(self, shape: Tuple[int]):
-        class FullX(torch.nn.Module):
-            def forward(self, t: torch.Tensor):
-                return torch.add(torch.full(shape, 0), t)
-
-        class FullY(torch.nn.Module):
-            def forward(self, t: torch.Tensor):
-                return torch.add(t, torch.full(shape, 0))
-
-        model = FullX()
-        t = torch.full(shape, 3)
-        graph_module = export_to_edge(model, (t,)).exported_program().graph_module
-
-        p = RemoveNopAddOpPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-        self.assertEqual(
-            count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
-            0,
-        )
-
-        model = FullY()
-        graph_module = export_to_edge(model, (t,)).exported_program().graph_module
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
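+        # Adding an all-zeros tensor on either side of the input is a nop.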
+        zeros = builder.call_operator(
+            op=exir_ops.edge.aten.full.default, args=(shape, 0)
+        )
+        left_add = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(zeros, x),
+        )
+        right_add = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(left_add, zeros),
+        )
+        builder.output([right_add])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveNopAddOpPass()(original)
+        ).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
             0,
@@ -122,31 +112,24 @@ def forward(self, t: torch.Tensor):
     )
     @torch.no_grad()
     def test_remove_nop_mul_op_pass(self, shape: Tuple[int]):
-        class FullX(torch.nn.Module):
-            def forward(self, t: torch.Tensor):
-                return torch.mul(torch.full(shape, 0), t)
-
-        class FullY(torch.nn.Module):
-            def forward(self, t: torch.Tensor):
-                return torch.mul(t, torch.full(shape, 0))
-
-        model = FullX()
-        t = torch.full(shape, 3)
-        graph_module = export_to_edge(model, (t,)).exported_program().graph_module
-
-        p = RemoveNopMulOpPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-        self.assertEqual(
-            count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
-            0,
-        )
-
-        model = FullY()
-        graph_module = export_to_edge(model, (t,)).exported_program().graph_module
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
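+        # Multiplying by an all-zeros tensor yields zeros, so the mul nodes can be removed.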
+        zeros = builder.call_operator(
+            op=exir_ops.edge.aten.full.default, args=(shape, 0)
+        )
+        left_mul = builder.call_operator(
+            op=exir_ops.edge.aten.mul.Tensor,
+            args=(zeros, x),
+        )
+        right_mul = builder.call_operator(
+            op=exir_ops.edge.aten.mul.Tensor,
+            args=(left_mul, zeros),
+        )
+        builder.output([right_mul])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveNopMulOpPass()(original)
+        ).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
             0,
@@ -159,18 +142,16 @@ def forward(self, t: torch.Tensor):
     )
     @torch.no_grad()
     def test_remove_alias_copy(self, shape: Tuple[int]):
-        class M(torch.nn.Module):
-            def forward(self, x: torch.Tensor):
-                return exir_ops.edge.aten.alias_copy(x)
-
-        model = M()
-        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
-        p = RemoveAliasCopyOpPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
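+        # alias_copy just forwards its input, so the pass should remove it.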
+        alias = builder.call_operator(
+            op=exir_ops.edge.aten.alias_copy.default, args=(x,)
+        )
+        builder.output([alias])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveAliasCopyOpPass()(original)
+        ).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.alias_copy.default),
             0,
@@ -183,19 +164,16 @@ def forward(self, x: torch.Tensor):
     )
     @torch.no_grad()
     def test_remove_detach_copy(self, shape: Tuple[int]):
-        # aten::detach is converted to aten::alias_copy after functionalization & decomposition.
-        class M(torch.nn.Module):
-            def forward(self, x: torch.Tensor):
-                return exir_ops.edge.aten.detach_copy(x)
-
-        model = M()
-        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
-        p = RemoveDetachCopyPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
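+        # detach_copy has no effect at inference time, so the pass should remove it.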
+        detach = builder.call_operator(
+            op=exir_ops.edge.aten.detach_copy.default, args=(x,)
+        )
+        builder.output([detach])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveDetachCopyPass()(original)
+        ).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.detach_copy.default),
             0,
@@ -210,95 +188,51 @@ def forward(self, x: torch.Tensor):
     def test_remove_zero_sized_constant_pad_nd(
         self, shape: Tuple[int], padding: Tuple[int]
     ):
-        # F.pad is converted to aten::constant_pad_nd after functionalization & decomposition.
-        class Padding(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.padding = padding
-
-            def forward(self, x: torch.Tensor):
-                return F.pad(x, self.padding)
-
-        model = Padding()
-        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
-        p = RemoveZeroSizedConstantPadNd()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
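+        # Zero-sized padding leaves the tensor unchanged, so the pad op is removable.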
+        pad = builder.call_operator(
+            op=exir_ops.edge.aten.constant_pad_nd.default, args=(x, padding)
+        )
+        builder.output([pad])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveZeroSizedConstantPadNd()(original)
+        ).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.constant_pad_nd.default),
             0,
         )
 
     def test_remove_expand(self):
-        class Expand(torch.nn.Module):
-            def forward(self, x):
-                return torch.ops.aten.expand_copy(x, [2, 3, 5])
-
-        x = torch.ones(2, 3, 5)
-        p = RemoveNopExpandOpPass()
-        graph_module = export_to_edge(Expand(), (x,)).exported_program().graph_module
-        graph_module = p(graph_module).graph_module
-        # Assert that expand op is optimized away, since it is a nop
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn([2, 3, 5], dtype=torch.float32))
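+        # Expanding a tensor to its own shape is a nop.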
+        expand = builder.call_operator(
+            op=exir_ops.edge.aten.expand_copy.default, args=(x, [2, 3, 5])
+        )
+        builder.output([expand])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveNopExpandOpPass()(original)
+        ).graph_module
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.expand_copy.default), 0
+            count_node(graph_after_passes, exir_ops.edge.aten.expand_copy.default), 0
         )
 
     def test_remove_zero_arg_cat(self):
-        class Cat(torch.nn.Module):
-            def forward(self, x, y):
-                return torch.ops.aten.cat((x, y), 0)
-
-        x = torch.ones(1, 0, 3, 5)
-        y = torch.ones(2, 0, 3, 5)
-        graph_module = (
-            compiler.export_to_cadence(Cat(), (x, y)).exported_program().graph_module
-        )
-        # Assert that cat op is optimized away, since it concatenates
-        # two zero-sized tensors
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)
-
-    def test_remove_single_arg_cat(self):
-        class Cat(torch.nn.Module):
-            def forward(self, x, y):
-                z = torch.ones(0, 5)
-                # z is an empty tensor, and concatenation of x with z will
-                # be x. So we can safely eliminate the following cat op.
-                x1 = torch.ops.aten.cat((x, z))
-                x2 = torch.add(x1, 2.4, 3.1)
-                y1 = torch.add(y, 1, 2)
-                return torch.add(x2, y1)
-
-        x = torch.ones(3, 5)
-        y = torch.ones(3, 5)
-        graph_module = export_to_edge(Cat(), (x, y)).exported_program().graph_module
-        new_graph_module = RemoveZeroSizedCatArgsPass()(graph_module).graph_module
-        new_graph_module.graph.eliminate_dead_code()
-        # Assert that x1 is optimized away
-        self.assertEqual(count_node(new_graph_module, torch.ops.aten.cat.out), 0)
-
-    def test_remove_zero_sized_cat(self):
-        class Cat(torch.nn.Module):
-            def __init__(self, dim: int):
-                super().__init__()
-                self.dim = dim
-
-            def forward(self, tensors):
-                return torch.cat(tensors, self.dim)
-
-        shapes, dim, dtype, _max = [(1, 0, 3), (2, 0, 3)], 0, torch.float32, 127
-
-        in_tensors = [(torch.rand(shape) * _max).to(dtype=dtype) for shape in shapes]
-
-        model = Cat(dim)
-        graph_module = (
-            export_to_edge(model, (in_tensors,)).exported_program().graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn([1, 0, 3, 5], dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn([2, 0, 3, 5], dtype=torch.float32))
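+        # Concatenating two zero-sized tensors yields a zero-sized result, so the cat is removable.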
+        concat = builder.call_operator(
+            op=exir_ops.edge.aten.cat.default, args=([x, y], 0)
+        )
+        builder.output([concat])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveZeroSizedCatArgsPass()(original)
+        ).graph_module
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 0
         )
-        new_graph_module = RemoveZeroSizedCatArgsPass()(graph_module).graph_module
-        new_graph_module.graph.eliminate_dead_code()
-        self.assertEqual(count_node(new_graph_module, torch.ops.aten.cat.out), 0)
 
     def test_remove_clone(self):
         class Clone(torch.nn.Module):