From aa92ce6ec002b2eb2752f21d34c79a93ecc37470 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Mon, 16 Jun 2025 11:07:45 -0700 Subject: [PATCH] [XNNPACK Quantizer] Select between TConvs and Convs Allow selection of Difference between transposed convs and regular convs. Previously, we grouped all conv targets together (transposed and regular convs), but now we enable better per-operator selection Differential Revision: [D76641838](https://our.internmc.facebook.com/intern/diff/D76641838/) [ghstack-poisoned] --- .../xnnpack/quantizer/xnnpack_quantizer.py | 17 ++- .../test/quantizer/test_xnnpack_quantizer.py | 110 ++++++++++++++++++ 2 files changed, 123 insertions(+), 4 deletions(-) diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer.py b/backends/xnnpack/quantizer/xnnpack_quantizer.py index c07d27e4231..3c82a65ad71 100644 --- a/backends/xnnpack/quantizer/xnnpack_quantizer.py +++ b/backends/xnnpack/quantizer/xnnpack_quantizer.py @@ -251,6 +251,15 @@ class QuantPattern: torch.ops.aten.convolution.default, } +CONV_TRANSPOSE_TARGETS = { + torch.ops.aten.conv_transpose1d, + torch.ops.aten.conv_transpose1d.default, + torch.ops.aten.conv_transpose2d, + torch.ops.aten.conv_transpose2d.input, + torch.ops.aten.conv_transpose3d, + torch.ops.aten.conv_transpose3d.input, +} + LINEAR_TARGETS = { torch.ops.aten.linear.default, } @@ -269,14 +278,14 @@ class XNNPACKQuantizer(Quantizer): SUPPORTED_PATTERNS = [ QuantPattern("conv_bn_relu", False, True, CONV_TARGETS), QuantPattern("conv_bn", False, True, CONV_TARGETS), - QuantPattern("conv_transpose_bn_relu", False, True, CONV_TARGETS), - QuantPattern("conv_transpose_bn", False, True, CONV_TARGETS), + QuantPattern("conv_transpose_bn_relu", False, True, CONV_TRANSPOSE_TARGETS), + QuantPattern("conv_transpose_bn", False, True, CONV_TRANSPOSE_TARGETS), QuantPattern("linear_relu", False, False, LINEAR_TARGETS), QuantPattern("linear", True, False, LINEAR_TARGETS), QuantPattern("conv", True, False, CONV_TARGETS), - QuantPattern("conv_transpose", True, False, CONV_TARGETS), + QuantPattern("conv_transpose", True, False, CONV_TRANSPOSE_TARGETS), QuantPattern("conv_relu", False, False, CONV_TARGETS), - QuantPattern("conv_transpose_relu", False, False, CONV_TARGETS), + QuantPattern("conv_transpose_relu", False, False, CONV_TRANSPOSE_TARGETS), QuantPattern("adaptive_avg_pool2d", False, False, ADAPTIVE_AVG_POOL2D_TARGETS), QuantPattern("add_relu", False, False, ADD_TARGETS), QuantPattern("add", False, False, ADD_TARGETS), diff --git a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py index 0a317ad8822..84b1a932a5b 100644 --- a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py +++ b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py @@ -120,6 +120,116 @@ def test_conv1d_with_conv2d(self): node_list, ) + def test_q_tconv_and_conv2d(self): + class TConv2dConv2d(torch.nn.Module): + def __init__(self): + super().__init__() + self.first = torch.nn.ConvTranspose2d( + in_channels=1, + out_channels=3, + kernel_size=(3, 3), + padding=1, + bias=False, + ) + self.second = torch.nn.Conv2d( + in_channels=3, + out_channels=2, + kernel_size=(3, 3), + padding=1, + bias=False, + ) + + def forward(self, x): + y = self.first(x) + return self.second(y) + + def example_inputs(self): + return (torch.randn(1, 1, 3, 3),) + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_operator_type( + torch.ops.aten.conv_transpose2d.input, quantization_config + ) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv_transpose2d.input, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + ] + m = TConv2dConv2d() + self._test_quantizer( + m, + m.example_inputs(), + quantizer, + node_occurrence, + node_list, + is_debug_mode=True, + ) + + def test_q_conv2_and_tconv2d(self): + class TConv2dConv2d(torch.nn.Module): + def __init__(self): + super().__init__() + self.first = torch.nn.ConvTranspose2d( + in_channels=1, + out_channels=3, + kernel_size=(3, 3), + padding=1, + bias=False, + ) + self.second = torch.nn.Conv2d( + in_channels=3, + out_channels=2, + kernel_size=(3, 3), + padding=1, + bias=False, + ) + + def forward(self, x): + y = self.first(x) + return self.second(y) + + def example_inputs(self): + return (torch.randn(1, 1, 3, 3),) + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_operator_type(torch.ops.aten.conv2d.default, quantization_config) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.aten.conv_transpose2d.input, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + m = TConv2dConv2d() + self._test_quantizer( + m, + m.example_inputs(), + quantizer, + node_occurrence, + node_list, + is_debug_mode=True, + ) + def test_linear(self): quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config(is_per_channel=True)