@@ -43,17 +43,20 @@ def forward(self, x):
4343 return out1 , out2
4444
4545 def _test_add (self , inputs ):
46- (
47- Tester (self .Add (), inputs )
48- .export ()
49- .check_count ({"torch.ops.aten.add.Tensor" : 4 })
50- .to_edge_transform_and_lower ()
51- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
52- .check_not (["executorch_exir_dialects_edge__ops_aten_add_Tensor" ])
53- .to_executorch ()
54- .serialize ()
55- .run_method_and_compare_outputs ()
56- )
46+ for legacy in (True , False ):
47+ tester = Tester (self .Add (), inputs )
48+ tester .export ()
49+ tester .check_count ({"torch.ops.aten.add.Tensor" : 4 })
50+ if legacy :
51+ tester .to_edge ()
52+ tester .partition ()
53+ else :
54+ tester .to_edge_transform_and_lower ()
55+ tester .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
56+ tester .check_not (["executorch_exir_dialects_edge__ops_aten_add_Tensor" ])
57+ tester .to_executorch ()
58+ tester .serialize ()
59+ tester .run_method_and_compare_outputs ()
5760
5861 def test_fp16_add (self ):
5962 inputs = (torch .randn (1 ).to (torch .float16 ), torch .randn (1 ).to (torch .float16 ))
@@ -65,95 +68,110 @@ def test_fp32_add(self):
6568
6669 def test_fp32_add_constant (self ):
6770 inputs = (torch .randn (4 , 4 , 4 ),)
68- (
69- Tester (self .AddConstant (torch .randn (4 , 4 , 4 )), inputs )
70- .export ()
71- .check_count ({"torch.ops.aten.add.Tensor" : 4 })
72- .to_edge_transform_and_lower ()
73- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
74- .check_not (["executorch_exir_dialects_edge__ops_aten_add_Tensor" ])
75- .to_executorch ()
76- .serialize ()
77- .run_method_and_compare_outputs ()
78- )
71+ for legacy in (True , False ):
72+ tester = Tester (self .AddConstant (torch .randn (4 , 4 , 4 )), inputs )
73+ tester .export ()
74+ tester .check_count ({"torch.ops.aten.add.Tensor" : 4 })
75+ if legacy :
76+ tester .to_edge ()
77+ tester .partition ()
78+ else :
79+ tester .to_edge_transform_and_lower ()
80+ tester .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
81+ tester .check_not (["executorch_exir_dialects_edge__ops_aten_add_Tensor" ])
82+ tester .to_executorch ()
83+ tester .serialize ()
84+ tester .run_method_and_compare_outputs ()
7985
8086 def test_qs8_add_constant (self ):
8187 inputs = (torch .randn (4 , 4 , 4 ),)
82- (
83- Tester (self .AddConstant (torch .randn (4 , 4 , 4 )), inputs )
84- .quantize ()
85- .export ()
86- .check_count ({"torch.ops.aten.add.Tensor" : 4 })
87- .to_edge_transform_and_lower ()
88- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
89- .check_not (["executorch_exir_dialects_edge__ops_aten_add_Tensor" ])
90- .to_executorch ()
91- .serialize ()
92- .run_method_and_compare_outputs ()
93- )
88+ for legacy in (True , False ):
89+ tester = Tester (self .AddConstant (torch .randn (4 , 4 , 4 )), inputs )
90+ tester .quantize ()
91+ tester .export ()
92+ tester .check_count ({"torch.ops.aten.add.Tensor" : 4 })
93+ if legacy :
94+ tester .to_edge ()
95+ tester .partition ()
96+ else :
97+ tester .to_edge_transform_and_lower ()
98+ tester .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
99+ tester .check_not (["executorch_exir_dialects_edge__ops_aten_add_Tensor" ])
100+ tester .to_executorch ()
101+ tester .serialize ()
102+ tester .run_method_and_compare_outputs ()
94103
95104 def test_qs8_add (self ):
96105 inputs = (torch .randn (1 , 1 , 4 , 4 ), torch .randn (1 , 1 , 4 , 4 ))
97- (
98- Tester (self .Add (), inputs )
99- .quantize ()
100- .export ()
101- .check_count ({"torch.ops.aten.add.Tensor" : 4 })
102- .check (["torch.ops.quantized_decomposed" ])
103- .to_edge_transform_and_lower ()
104- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
105- .check_not (
106+ for legacy in (True , False ):
107+ tester = Tester (self .Add (), inputs )
108+ tester .quantize ()
109+ tester .export ()
110+ tester .check_count ({"torch.ops.aten.add.Tensor" : 4 })
111+ tester .check (["torch.ops.quantized_decomposed" ])
112+ if legacy :
113+ tester .to_edge ()
114+ tester .partition ()
115+ else :
116+ tester .to_edge_transform_and_lower ()
117+ tester .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
118+ tester .check_not (
106119 [
107120 "executorch_exir_dialects_edge__ops_aten_add_Tensor" ,
108121 "torch.ops.quantized_decomposed" ,
109122 ]
110123 )
111- .to_executorch ()
112- .serialize ()
113- .run_method_and_compare_outputs ()
114- )
124+ tester .to_executorch ()
125+ tester .serialize ()
126+ tester .run_method_and_compare_outputs ()
115127
116128 def test_qs8_add2 (self ):
117129 inputs = (torch .randn (1 , 1 , 4 , 4 ),)
118- (
119- Tester (self .Add2 (), inputs )
120- .quantize ()
121- .export ()
122- .check_count ({"torch.ops.aten.add.Tensor" : 1 })
123- .check (["torch.ops.quantized_decomposed" ])
124- .to_edge_transform_and_lower ()
125- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
126- .check_not (
130+ for legacy in (True , False ):
131+ tester = Tester (self .Add2 (), inputs )
132+ tester .quantize ()
133+ tester .export ()
134+ tester .check_count ({"torch.ops.aten.add.Tensor" : 1 })
135+ tester .check (["torch.ops.quantized_decomposed" ])
136+ if legacy :
137+ tester .to_edge ()
138+ tester .partition ()
139+ else :
140+ tester .to_edge_transform_and_lower ()
141+ tester .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
142+ tester .check_not (
127143 [
128144 "executorch_exir_dialects_edge__ops_aten_add_Tensor" ,
129145 "torch.ops.quantized_decomposed" ,
130146 ]
131147 )
132- .to_executorch ()
133- .serialize ()
134- .run_method_and_compare_outputs ()
135- )
148+ tester .to_executorch ()
149+ tester .serialize ()
150+ tester .run_method_and_compare_outputs ()
136151
137152 def test_qs8_add3 (self ):
138153 inputs = (torch .randn (1 , 1 , 4 , 4 ), torch .randn (1 , 1 , 4 , 1 ))
139- (
140- Tester (self .Add (), inputs )
141- .quantize ()
142- .export ()
143- .check_count ({"torch.ops.aten.add.Tensor" : 4 })
144- .check (["torch.ops.quantized_decomposed" ])
145- .to_edge_transform_and_lower ()
146- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
147- .check_not (
154+ for legacy in (True , False ):
155+ tester = Tester (self .Add (), inputs )
156+ tester .quantize ()
157+ tester .export ()
158+ tester .check_count ({"torch.ops.aten.add.Tensor" : 4 })
159+ tester .check (["torch.ops.quantized_decomposed" ])
160+ if legacy :
161+ tester .to_edge ()
162+ tester .partition ()
163+ else :
164+ tester .to_edge_transform_and_lower ()
165+ tester .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
166+ tester .check_not (
148167 [
149168 "executorch_exir_dialects_edge__ops_aten_add_Tensor" ,
150169 "torch.ops.quantized_decomposed" ,
151170 ]
152171 )
153- .to_executorch ()
154- .serialize ()
155- .run_method_and_compare_outputs ()
156- )
172+ tester .to_executorch ()
173+ tester .serialize ()
174+ tester .run_method_and_compare_outputs ()
157175
158176 class AddRelu (torch .nn .Module ):
159177 def forward (self , x , y ):
@@ -162,35 +180,41 @@ def forward(self, x, y):
162180
163181 def test_fp32_add_relu (self ):
164182 inputs = (torch .randn (1 , 1 , 4 , 4 ), torch .randn (1 , 1 , 4 , 4 ))
165- (
166- Tester (self .AddRelu (), inputs )
167- .export ()
168- .check_count ({"torch.ops.aten.add.Tensor" : 1 })
169- .check_count ({"torch.ops.aten.relu.default" : 1 })
170- .to_edge_transform_and_lower ()
171- .check_not (["executorch_exir_dialects_edge__ops_aten_add_Tensor" ])
172- .check_not (["executorch_exir_dialects_edge__ops_aten_relu_default" ])
173- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
174- .to_executorch ()
175- .serialize ()
176- .run_method_and_compare_outputs ()
177- )
183+ for legacy in (True , False ):
184+ tester = Tester (self .AddRelu (), inputs )
185+ tester .export ()
186+ tester .check_count ({"torch.ops.aten.add.Tensor" : 1 })
187+ tester .check_count ({"torch.ops.aten.relu.default" : 1 })
188+ if legacy :
189+ tester .to_edge ()
190+ tester .partition ()
191+ else :
192+ tester .to_edge_transform_and_lower ()
193+ tester .check_not (["executorch_exir_dialects_edge__ops_aten_add_Tensor" ])
194+ tester .check_not (["executorch_exir_dialects_edge__ops_aten_relu_default" ])
195+ tester .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
196+ tester .to_executorch ()
197+ tester .serialize ()
198+ tester .run_method_and_compare_outputs ()
178199
179200 def test_qs8_add_relu (self ):
180201 inputs = (torch .randn (1 , 1 , 4 , 4 ), torch .randn (1 , 1 , 4 , 4 ))
181- (
182- Tester (self .AddRelu (), inputs )
183- .quantize ()
184- .export ()
185- .check_count ({"torch.ops.aten.add.Tensor" : 1 })
186- .check_count ({"torch.ops.aten.relu.default" : 1 })
187- .check (["torch.ops.quantized_decomposed" ])
188- .to_edge_transform_and_lower ()
189- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
190- .to_executorch ()
191- .serialize ()
192- .run_method_and_compare_outputs ()
193- )
202+ for legacy in (True , False ):
203+ tester = Tester (self .AddRelu (), inputs )
204+ tester .quantize ()
205+ tester .export ()
206+ tester .check_count ({"torch.ops.aten.add.Tensor" : 1 })
207+ tester .check_count ({"torch.ops.aten.relu.default" : 1 })
208+ tester .check (["torch.ops.quantized_decomposed" ])
209+ if legacy :
210+ tester .to_edge ()
211+ tester .partition ()
212+ else :
213+ tester .to_edge_transform_and_lower ()
214+ tester .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
215+ tester .to_executorch ()
216+ tester .serialize ()
217+ tester .run_method_and_compare_outputs ()
194218
195219 def test_qs8_add_relu_seq (self ):
196220 class AddReLU (torch .nn .Module ):
@@ -220,17 +244,21 @@ def forward(self, x, z):
220244 ),
221245 )
222246
223- (
224- Tester (self .AddRelu (), inputs )
225- .quantize ()
226- .export ()
227- .check_count (
247+ for legacy in ( True , False ):
248+ tester = Tester (self .AddRelu (), inputs )
249+ tester .quantize ()
250+ tester .export ()
251+ tester .check_count (
228252 {"torch.ops.aten.add.Tensor" : 1 , "torch.ops.aten.relu.default" : 1 }
229253 )
230- .check (["torch.ops.quantized_decomposed" ])
231- .to_edge_transform_and_lower ()
232- .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
233- .to_executorch ()
234- .serialize ()
235- .run_method_and_compare_outputs ()
236- )
254+ tester .check (["torch.ops.quantized_decomposed" ])
255+ if legacy :
256+ tester .to_edge ()
257+ tester .partition ()
258+ else :
259+ tester .to_edge_transform_and_lower ()
260+ tester .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
261+ tester .to_executorch ()
262+ tester .serialize ()
263+ tester .run_method_and_compare_outputs ()
264+
0 commit comments