@@ -173,12 +173,10 @@ def get_inputs(self):
173173 return (torch .randn (2 , 2 , 4 , 4 ),)
174174
175175
176- class Conv2dDynamicQuant (torch .nn .Module ):
176+ class Conv2dDQ (torch .nn .Module ):
177177 def __init__ (self ):
178178 super ().__init__ ()
179- self .conv = torch .nn .Conv2d (3 , 10 , 3 )
180- self .conv .weight .requires_grad = False
181- self .conv .bias .requires_grad = False
179+ self .conv = torch .nn .Conv2d (in_channels = 3 , out_channels = 10 , kernel_size = 3 )
182180
183181 def forward (self , x ):
184182 return self .conv (x )
@@ -187,6 +185,43 @@ def get_inputs(self):
187185 return (torch .randn (1 , 3 , 8 , 8 ),)
188186
189187
188+ class Conv2dDQSeq (torch .nn .Module ):
189+ def __init__ (self ):
190+ super ().__init__ ()
191+ self .first = torch .nn .Conv2d (
192+ in_channels = 3 , out_channels = 8 , kernel_size = 3 , padding = 1
193+ )
194+ self .second = torch .nn .Conv2d (
195+ in_channels = 8 , out_channels = 10 , kernel_size = 3 , padding = 1
196+ )
197+
198+ def forward (self , x ):
199+ y = self .first (x )
200+ return self .second (y )
201+
202+ def get_inputs (self ):
203+ return (torch .randn (1 , 3 , 8 , 8 ),)
204+
205+
206+ class Conv2dDQParallel (torch .nn .Module ):
207+ def __init__ (self ):
208+ super ().__init__ ()
209+ self .first = torch .nn .Conv2d (
210+ in_channels = 3 , out_channels = 8 , kernel_size = 3 , padding = 1
211+ )
212+ self .second = torch .nn .Conv2d (
213+ in_channels = 3 , out_channels = 10 , kernel_size = 3 , padding = 1
214+ )
215+
216+ def forward (self , x ):
217+ first = self .first (x )
218+ second = self .second (x )
219+ return first , second
220+
221+ def get_inputs (self ):
222+ return (torch .randn (1 , 3 , 8 , 8 ),)
223+
224+
190225class TestConv2d (unittest .TestCase ):
191226 def setUp (self ):
192227 torch ._dynamo .reset ()
@@ -244,8 +279,8 @@ def _test(
244279 def _test_dq (
245280 self ,
246281 m : torch .nn .Module ,
247- inputs ,
248- dynamic_shapes ,
282+ conv_count = 1 ,
283+ dynamic_shapes = None ,
249284 ):
250285 quant_config = get_symmetric_quantization_config (
251286 is_per_channel = True ,
@@ -257,14 +292,16 @@ def _test_dq(
257292 per_op_mode = True ,
258293 )
259294
260- tester = Tester (m , inputs , dynamic_shapes = dynamic_shapes )
295+ tester = Tester (m , m . get_inputs () , dynamic_shapes = dynamic_shapes )
261296 tester .quantize (Quantize (quantization_config = quant_config ))
262297 tester .export ()
263298 tester .check (["torch.ops.quantized_decomposed.choose_qparams" ])
264299 tester .to_edge_transform_and_lower (
265300 ToEdgeTransformAndLower ([DynamicallyQuantizedPartitioner ])
266301 )
267- tester .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
302+ tester .check_count (
303+ {"torch.ops.higher_order.executorch_call_delegate" : conv_count }
304+ )
268305 tester .check_not (["executorch_exir_dialects_edge__ops_aten_conv2d_default" ])
269306 tester .to_executorch ()
270307 tester .serialize ()
@@ -748,9 +785,13 @@ def forward(self, x):
748785 )
749786
750787 def test_dq_conv2d (self ) -> None :
751- model = Conv2dDynamicQuant ()
752- self ._test_dq (
753- model ,
754- model .get_inputs (),
755- dynamic_shapes = None ,
756- )
788+ model = Conv2dDQ ()
789+ self ._test_dq (model )
790+
791+ def test_dq_conv2d_seq (self ) -> None :
792+ model = Conv2dDQSeq ()
793+ self ._test_dq (model , conv_count = 2 )
794+
795+ def test_dq_conv2d_parallel (self ) -> None :
796+ model = Conv2dDQParallel ()
797+ self ._test_dq (model , conv_count = 2 )
0 commit comments