Skip to content

Commit 6da8b7d

Browse files
committed
Add unit test for dynamic quant conv2d with channels-last permute
1 parent b29030e commit 6da8b7d

File tree

1 file changed

+42
-1
lines changed

1 file changed

+42
-1
lines changed

backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@
1010
from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
1111
ChannelsLastTaggedReshapePass,
1212
)
13+
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
14+
get_symmetric_quantization_config,
15+
)
1316
from executorch.backends.xnnpack.test.test_xnnpack_utils_classes import (
1417
OpSequencesAddConv2d,
1518
)
16-
from executorch.backends.xnnpack.test.tester import RunPasses, Tester
19+
from executorch.backends.xnnpack.test.tester import Quantize, RunPasses, Tester
1720

1821

1922
class TestChannelsLastTaggedReshapePass(unittest.TestCase):
@@ -35,6 +38,10 @@ def setUp(self):
3538
dequant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
3639
conv_name = "executorch_exir_dialects_edge__ops_aten_convolution_default"
3740
relu_name = "executorch_exir_dialects_edge__ops_aten_relu_default"
41+
choose_qparams_name = (
42+
"executorch_exir_dialects_edge__ops_quantized_decomposed_choose_qparams_tensor"
43+
)
44+
dynamic_quant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"
3845

3946
def test_fp32_channels_last_tagged_reshape_pass(self):
4047
for module, num_reshape in self.modules.items():
@@ -179,3 +186,37 @@ def test_fp32_channels_last_tagged_reshape_pass_conv_bn_hardtanh_mean_seq(self):
179186
)
180187
.run_method_and_compare_outputs()
181188
)
189+
190+
class Conv2dDynamicQuant(torch.nn.Module):
191+
def __init__(self):
192+
super().__init__()
193+
self.conv = torch.nn.Conv2d(3, 10, 3)
194+
195+
def forward(self, x):
196+
return self.conv(x)
197+
198+
def test_dq_conv2d_channels_last_tagged_reshape_pass(self) -> None:
199+
(
200+
Tester(self.Conv2dDynamicQuant().eval(), (torch.randn(1, 3, 8, 8),))
201+
.quantize(
202+
Quantize(
203+
quantization_config=get_symmetric_quantization_config(
204+
is_dynamic=True
205+
)
206+
)
207+
)
208+
.export()
209+
.to_edge()
210+
.run_passes(self.PassStage)
211+
.check(
212+
[
213+
self.to_copy_name,
214+
self.choose_qparams_name,
215+
self.dynamic_quant_name,
216+
self.dequant_name,
217+
self.conv_name,
218+
self.to_copy_name,
219+
]
220+
)
221+
.run_method_and_compare_outputs()
222+
)

0 commit comments

Comments
 (0)