Skip to content

Commit 6b44c4b

Browse files
committed
Refactor dq conv2d test
1 parent 7150872 commit 6b44c4b

File tree

1 file changed

+21
-34
lines changed

1 file changed

+21
-34
lines changed

backends/xnnpack/test/ops/test_conv2d.py

Lines changed: 21 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,20 @@ def get_inputs(self):
176176
return (torch.randn(2, 2, 4, 4),)
177177

178178

179+
class DQConv2d(torch.nn.Module):
180+
def __init__(self):
181+
super().__init__()
182+
self.conv = torch.nn.Conv2d(3, 10, 3)
183+
self.conv.weight.requires_grad = False
184+
self.conv.bias.requires_grad = False
185+
186+
def forward(self, x):
187+
return self.conv(x)
188+
189+
def get_inputs(self):
190+
return (torch.randn(1, 3, 8, 8),)
191+
192+
179193
class TestConv2d(unittest.TestCase):
180194
def setUp(self):
181195
torch._dynamo.reset()
@@ -230,12 +244,11 @@ def _test(
230244
.run_method_and_compare_outputs(qtol=1)
231245
)
232246

233-
def _test_dq_conv2d(
247+
def _test_dq(
234248
self,
235249
m: torch.nn.Module,
236250
inputs,
237251
dynamic_shapes,
238-
atol=5e-02,
239252
):
240253
quant_config = get_symmetric_quantization_config(
241254
is_per_channel=True,
@@ -250,21 +263,15 @@ def _test_dq_conv2d(
250263
tester = Tester(m, inputs, dynamic_shapes=dynamic_shapes)
251264
tester.quantize(Quantize(quantization_config=quant_config))
252265
tester.export()
253-
254266
tester.check(["torch.ops.quantized_decomposed.choose_qparams"])
255-
256267
tester.to_edge_transform_and_lower(
257268
ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner])
258269
)
259-
260270
tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
261271
tester.check_not(["executorch_exir_dialects_edge__ops_aten_conv2d_default"])
262-
263272
tester.to_executorch()
264-
# tester.serialize()
265-
tester.serialize().dump_artifact("conv2d.pte")
266-
267-
tester.run_method_and_compare_outputs(atol=atol)
273+
tester.serialize()
274+
tester.run_method_and_compare_outputs(qtol=1)
268275

269276
def test_fp16_conv2d(self) -> None:
270277
for transpose in (True, False):
@@ -743,30 +750,10 @@ def forward(self, x):
743750
.run_method_and_compare_outputs(qtol=1)
744751
)
745752

746-
def test_dq_conv2d(self) -> None:
747-
class SimpleConv2d(torch.nn.Module):
748-
def __init__(self):
749-
super().__init__()
750-
self.conv = torch.nn.Conv2d(
751-
3,
752-
10,
753-
3,
754-
)
755-
self.conv.weight.requires_grad = False
756-
self.conv.bias.requires_grad = False
757-
758-
def forward(self, x):
759-
return self.conv(x)
760-
761-
def get_inputs(self):
762-
return (torch.randn(1, 3, 8, 8),)
763-
764-
model = SimpleConv2d()
765-
inputs = model.get_inputs()
766-
767-
self._test_dq_conv2d(
753+
def test_qs8_dq_conv2d(self) -> None:
754+
model = DQConv2d()
755+
self._test_dq(
768756
model,
769-
inputs,
757+
model.get_inputs(),
770758
dynamic_shapes=None,
771-
atol=3.0,
772759
)

0 commit comments

Comments
 (0)