@@ -44,20 +44,20 @@ def setUp(self):
4444 )
4545 dynamic_quant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"
4646
47- # def test_fp32_channels_last_tagged_reshape_pass(self):
48- # for module, num_reshape in self.modules.items():
49- # (
50- # Tester(module, (torch.randn(1, 1, 6, 6),))
51- # .export()
52- # .to_edge()
53- # .run_passes(self.PassStage)
54- # .check_count(
55- # {
56- # self.to_copy_name: num_reshape,
57- # }
58- # )
59- # .run_method_and_compare_outputs()
60- # )
47+ def test_fp32_channels_last_tagged_reshape_pass (self ):
48+ for module , num_reshape in self .modules .items ():
49+ (
50+ Tester (module , (torch .randn (1 , 1 , 6 , 6 ),))
51+ .export ()
52+ .to_edge ()
53+ .run_passes (self .PassStage )
54+ .check_count (
55+ {
56+ self .to_copy_name : num_reshape ,
57+ }
58+ )
59+ .run_method_and_compare_outputs ()
60+ )
6161
6262 class LinearConv (torch .nn .Module ):
6363 def __init__ (self ):
@@ -141,26 +141,26 @@ def test_nchw_input_on_nhwc_op(self):
141141
142142 tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
143143
144- # def test_qs8_channels_last_tagged_reshape_pass(self):
145- # for module, num_reshape in self.modules.items():
146- # (
147- # Tester(module, (torch.randn(1, 1, 6, 6),))
148- # .quantize()
149- # .export()
150- # .to_edge()
151- # .run_passes(self.PassStage)
152- # .check(
153- # [
154- # self.quant_name,
155- # self.dequant_name,
156- # self.to_copy_name,
157- # self.quant_name,
158- # self.dequant_name,
159- # ]
160- # * num_reshape
161- # )
162- # .run_method_and_compare_outputs()
163- # )
144+ def test_qs8_channels_last_tagged_reshape_pass (self ):
145+ for module , num_reshape in self .modules .items ():
146+ (
147+ Tester (module , (torch .randn (1 , 1 , 6 , 6 ),))
148+ .quantize ()
149+ .export ()
150+ .to_edge ()
151+ .run_passes (self .PassStage )
152+ .check (
153+ [
154+ self .quant_name ,
155+ self .dequant_name ,
156+ self .to_copy_name ,
157+ self .quant_name ,
158+ self .dequant_name ,
159+ ]
160+ * num_reshape
161+ )
162+ .run_method_and_compare_outputs ()
163+ )
164164
165165 class ConvRelu (torch .nn .Module ):
166166 def __init__ (self ):
@@ -171,39 +171,39 @@ def __init__(self):
171171 def forward (self , x ):
172172 return self .relu (self .conv (x ))
173173
174- # def test_fp32_channels_last_tagged_reshape_pass_conv_relu(self):
175- # (
176- # Tester(self.ConvRelu().eval(), (torch.randn(1, 1, 6, 6),))
177- # .export()
178- # .to_edge()
179- # .run_passes(self.PassStage)
180- # .check(
181- # [self.to_copy_name, self.conv_name, self.relu_name, self.to_copy_name]
182- # )
183- # .run_method_and_compare_outputs()
184- # )
174+ def test_fp32_channels_last_tagged_reshape_pass_conv_relu (self ):
175+ (
176+ Tester (self .ConvRelu ().eval (), (torch .randn (1 , 1 , 6 , 6 ),))
177+ .export ()
178+ .to_edge ()
179+ .run_passes (self .PassStage )
180+ .check (
181+ [self .to_copy_name , self .conv_name , self .relu_name , self .to_copy_name ]
182+ )
183+ .run_method_and_compare_outputs ()
184+ )
185185
186- # def test_qs8_channels_last_tagged_reshape_pass_conv_relu(self):
187- # (
188- # Tester(self.ConvRelu().eval(), (torch.randn(1, 1, 6, 6),))
189- # .quantize()
190- # .export()
191- # .to_edge()
192- # .run_passes(self.PassStage)
193- # .check(
194- # [
195- # self.to_copy_name,
196- # self.quant_name,
197- # self.dequant_name,
198- # self.conv_name,
199- # self.relu_name,
200- # self.quant_name,
201- # self.dequant_name,
202- # self.to_copy_name,
203- # ]
204- # )
205- # .run_method_and_compare_outputs()
206- # )
186+ def test_qs8_channels_last_tagged_reshape_pass_conv_relu (self ):
187+ (
188+ Tester (self .ConvRelu ().eval (), (torch .randn (1 , 1 , 6 , 6 ),))
189+ .quantize ()
190+ .export ()
191+ .to_edge ()
192+ .run_passes (self .PassStage )
193+ .check (
194+ [
195+ self .to_copy_name ,
196+ self .quant_name ,
197+ self .dequant_name ,
198+ self .conv_name ,
199+ self .relu_name ,
200+ self .quant_name ,
201+ self .dequant_name ,
202+ self .to_copy_name ,
203+ ]
204+ )
205+ .run_method_and_compare_outputs ()
206+ )
207207
208208 class Conv2dBnHardtanhMeanSequenceModule (torch .nn .Module ):
209209 def __init__ (self ):
@@ -278,28 +278,28 @@ def __init__(self):
278278 def forward (self , x ):
279279 return self .conv (x )
280280
281- # def test_dq_conv2d_channels_last_tagged_reshape_pass(self) -> None:
282- # (
283- # Tester(self.Conv2dDynamicQuant().eval(), (torch.randn(1, 3, 8, 8),))
284- # .quantize(
285- # Quantize(
286- # quantization_config=get_symmetric_quantization_config(
287- # is_dynamic=True
288- # )
289- # )
290- # )
291- # .export()
292- # .to_edge()
293- # .run_passes(self.PassStage)
294- # .check(
295- # [
296- # self.to_copy_name,
297- # self.choose_qparams_name,
298- # self.dynamic_quant_name,
299- # self.dequant_name,
300- # self.conv_name,
301- # self.to_copy_name,
302- # ]
303- # )
304- # .run_method_and_compare_outputs()
305- # )
281+ def test_dq_conv2d_channels_last_tagged_reshape_pass (self ) -> None :
282+ (
283+ Tester (self .Conv2dDynamicQuant ().eval (), (torch .randn (1 , 3 , 8 , 8 ),))
284+ .quantize (
285+ Quantize (
286+ quantization_config = get_symmetric_quantization_config (
287+ is_dynamic = True
288+ )
289+ )
290+ )
291+ .export ()
292+ .to_edge ()
293+ .run_passes (self .PassStage )
294+ .check (
295+ [
296+ self .to_copy_name ,
297+ self .choose_qparams_name ,
298+ self .dynamic_quant_name ,
299+ self .dequant_name ,
300+ self .conv_name ,
301+ self .to_copy_name ,
302+ ]
303+ )
304+ .run_method_and_compare_outputs ()
305+ )
0 commit comments