1010from executorch .backends .xnnpack ._passes .channels_last_tagged_reshape_pass import (
1111 ChannelsLastTaggedReshapePass ,
1212)
13+ from executorch .backends .xnnpack .quantizer .xnnpack_quantizer import (
14+ get_symmetric_quantization_config ,
15+ )
1316from executorch .backends .xnnpack .test .test_xnnpack_utils_classes import (
1417 OpSequencesAddConv2d ,
1518)
16- from executorch .backends .xnnpack .test .tester import RunPasses , Tester
19+ from executorch .backends .xnnpack .test .tester import Quantize , RunPasses , Tester
1720
1821
1922class TestChannelsLastTaggedReshapePass (unittest .TestCase ):
@@ -35,6 +38,10 @@ def setUp(self):
3538 dequant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
3639 conv_name = "executorch_exir_dialects_edge__ops_aten_convolution_default"
3740 relu_name = "executorch_exir_dialects_edge__ops_aten_relu_default"
41+ choose_qparams_name = (
42+ "executorch_exir_dialects_edge__ops_quantized_decomposed_choose_qparams_tensor"
43+ )
44+ dynamic_quant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"
3845
3946 def test_fp32_channels_last_tagged_reshape_pass (self ):
4047 for module , num_reshape in self .modules .items ():
@@ -179,3 +186,37 @@ def test_fp32_channels_last_tagged_reshape_pass_conv_bn_hardtanh_mean_seq(self):
179186 )
180187 .run_method_and_compare_outputs ()
181188 )
189+
190+ class Conv2dDynamicQuant (torch .nn .Module ):
191+ def __init__ (self ):
192+ super ().__init__ ()
193+ self .conv = torch .nn .Conv2d (3 , 10 , 3 )
194+
195+ def forward (self , x ):
196+ return self .conv (x )
197+
198+ def test_dq_conv2d_channels_last_tagged_reshape_pass (self ) -> None :
199+ (
200+ Tester (self .Conv2dDynamicQuant ().eval (), (torch .randn (1 , 3 , 8 , 8 ),))
201+ .quantize (
202+ Quantize (
203+ quantization_config = get_symmetric_quantization_config (
204+ is_dynamic = True
205+ )
206+ )
207+ )
208+ .export ()
209+ .to_edge ()
210+ .run_passes (self .PassStage )
211+ .check (
212+ [
213+ self .to_copy_name ,
214+ self .choose_qparams_name ,
215+ self .dynamic_quant_name ,
216+ self .dequant_name ,
217+ self .conv_name ,
218+ self .to_copy_name ,
219+ ]
220+ )
221+ .run_method_and_compare_outputs ()
222+ )
0 commit comments