Skip to content

Commit b29030e

Browse files
committed
Add unit tests for dynamic quant sequential and parallel convs
1 parent 228dc0b commit b29030e

File tree

1 file changed

+55
-14
lines changed

1 file changed

+55
-14
lines changed

backends/xnnpack/test/ops/test_conv2d.py

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,10 @@ def get_inputs(self):
173173
return (torch.randn(2, 2, 4, 4),)
174174

175175

176-
class Conv2dDynamicQuant(torch.nn.Module):
176+
class Conv2dDQ(torch.nn.Module):
177177
def __init__(self):
178178
super().__init__()
179-
self.conv = torch.nn.Conv2d(3, 10, 3)
180-
self.conv.weight.requires_grad = False
181-
self.conv.bias.requires_grad = False
179+
self.conv = torch.nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3)
182180

183181
def forward(self, x):
184182
return self.conv(x)
@@ -187,6 +185,43 @@ def get_inputs(self):
187185
return (torch.randn(1, 3, 8, 8),)
188186

189187

188+
class Conv2dDQSeq(torch.nn.Module):
189+
def __init__(self):
190+
super().__init__()
191+
self.first = torch.nn.Conv2d(
192+
in_channels=3, out_channels=8, kernel_size=3, padding=1
193+
)
194+
self.second = torch.nn.Conv2d(
195+
in_channels=8, out_channels=10, kernel_size=3, padding=1
196+
)
197+
198+
def forward(self, x):
199+
y = self.first(x)
200+
return self.second(y)
201+
202+
def get_inputs(self):
203+
return (torch.randn(1, 3, 8, 8),)
204+
205+
206+
class Conv2dDQParallel(torch.nn.Module):
207+
def __init__(self):
208+
super().__init__()
209+
self.first = torch.nn.Conv2d(
210+
in_channels=3, out_channels=8, kernel_size=3, padding=1
211+
)
212+
self.second = torch.nn.Conv2d(
213+
in_channels=3, out_channels=10, kernel_size=3, padding=1
214+
)
215+
216+
def forward(self, x):
217+
first = self.first(x)
218+
second = self.second(x)
219+
return first, second
220+
221+
def get_inputs(self):
222+
return (torch.randn(1, 3, 8, 8),)
223+
224+
190225
class TestConv2d(unittest.TestCase):
191226
def setUp(self):
192227
torch._dynamo.reset()
@@ -244,8 +279,8 @@ def _test(
244279
def _test_dq(
245280
self,
246281
m: torch.nn.Module,
247-
inputs,
248-
dynamic_shapes,
282+
conv_count=1,
283+
dynamic_shapes=None,
249284
):
250285
quant_config = get_symmetric_quantization_config(
251286
is_per_channel=True,
@@ -257,14 +292,16 @@ def _test_dq(
257292
per_op_mode=True,
258293
)
259294

260-
tester = Tester(m, inputs, dynamic_shapes=dynamic_shapes)
295+
tester = Tester(m, m.get_inputs(), dynamic_shapes=dynamic_shapes)
261296
tester.quantize(Quantize(quantization_config=quant_config))
262297
tester.export()
263298
tester.check(["torch.ops.quantized_decomposed.choose_qparams"])
264299
tester.to_edge_transform_and_lower(
265300
ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner])
266301
)
267-
tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
302+
tester.check_count(
303+
{"torch.ops.higher_order.executorch_call_delegate": conv_count}
304+
)
268305
tester.check_not(["executorch_exir_dialects_edge__ops_aten_conv2d_default"])
269306
tester.to_executorch()
270307
tester.serialize()
@@ -748,9 +785,13 @@ def forward(self, x):
748785
)
749786

750787
def test_dq_conv2d(self) -> None:
751-
model = Conv2dDynamicQuant()
752-
self._test_dq(
753-
model,
754-
model.get_inputs(),
755-
dynamic_shapes=None,
756-
)
788+
model = Conv2dDQ()
789+
self._test_dq(model)
790+
791+
def test_dq_conv2d_seq(self) -> None:
792+
model = Conv2dDQSeq()
793+
self._test_dq(model, conv_count=2)
794+
795+
def test_dq_conv2d_parallel(self) -> None:
796+
model = Conv2dDQParallel()
797+
self._test_dq(model, conv_count=2)

0 commit comments

Comments
 (0)