@@ -176,6 +176,20 @@ def get_inputs(self):
         return (torch.randn(2, 2, 4, 4),)
 
 
+class DQConv2d(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(3, 10, 3)
+        self.conv.weight.requires_grad = False
+        self.conv.bias.requires_grad = False
+
+    def forward(self, x):
+        return self.conv(x)
+
+    def get_inputs(self):
+        return (torch.randn(1, 3, 8, 8),)
+
+
 class TestConv2d(unittest.TestCase):
     def setUp(self):
         torch._dynamo.reset()
@@ -230,12 +244,11 @@ def _test(
             .run_method_and_compare_outputs(qtol=1)
         )
 
-    def _test_dq_conv2d(
+    def _test_dq(
         self,
         m: torch.nn.Module,
         inputs,
         dynamic_shapes,
-        atol=5e-02,
     ):
         quant_config = get_symmetric_quantization_config(
             is_per_channel=True,
@@ -250,21 +263,15 @@ def _test_dq_conv2d(
         tester = Tester(m, inputs, dynamic_shapes=dynamic_shapes)
         tester.quantize(Quantize(quantization_config=quant_config))
         tester.export()
-
         tester.check(["torch.ops.quantized_decomposed.choose_qparams"])
-
         tester.to_edge_transform_and_lower(
             ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner])
         )
-
         tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
         tester.check_not(["executorch_exir_dialects_edge__ops_aten_conv2d_default"])
-
         tester.to_executorch()
-        # tester.serialize()
-        tester.serialize().dump_artifact("conv2d.pte")
-
-        tester.run_method_and_compare_outputs(atol=atol)
+        tester.serialize()
+        tester.run_method_and_compare_outputs(qtol=1)
 
     def test_fp16_conv2d(self) -> None:
         for transpose in (True, False):
@@ -743,30 +750,10 @@ def forward(self, x):
             .run_method_and_compare_outputs(qtol=1)
         )
 
-    def test_dq_conv2d(self) -> None:
-        class SimpleConv2d(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.conv = torch.nn.Conv2d(
-                    3,
-                    10,
-                    3,
-                )
-                self.conv.weight.requires_grad = False
-                self.conv.bias.requires_grad = False
-
-            def forward(self, x):
-                return self.conv(x)
-
-            def get_inputs(self):
-                return (torch.randn(1, 3, 8, 8),)
-
-        model = SimpleConv2d()
-        inputs = model.get_inputs()
-
-        self._test_dq_conv2d(
+    def test_qs8_dq_conv2d(self) -> None:
+        model = DQConv2d()
+        self._test_dq(
             model,
-            inputs,
+            model.get_inputs(),
             dynamic_shapes=None,
-            atol=3.0,
         )