def change_quantization_config(
    original_qspec: "QuantizationSpec",
    dtype=None,
    quant_min=None,
    quant_max=None,
    qscheme=None,
    ch_axis=None,
    is_dynamic=None,
    observer_or_fake_quant_ctr=None,
) -> "QuantizationSpec":
    """Return a copy of ``original_qspec`` with the given fields overridden.

    Every keyword left as ``None`` keeps the value from ``original_qspec``.
    Any other value -- including falsy ones such as ``ch_axis=0``,
    ``quant_min=0`` or ``is_dynamic=False`` -- replaces the original value.
    (Resolving overrides with ``value or original`` would silently discard
    those falsy overrides.)

    Args:
        original_qspec: The spec to copy; it is not modified.
        dtype: Replacement ``dtype``, or ``None`` to keep the original.
        quant_min: Replacement ``quant_min``, or ``None`` to keep the original.
        quant_max: Replacement ``quant_max``, or ``None`` to keep the original.
        qscheme: Replacement ``qscheme``, or ``None`` to keep the original.
        ch_axis: Replacement ``ch_axis``, or ``None`` to keep the original.
        is_dynamic: Replacement ``is_dynamic``, or ``None`` to keep the
            original.
        observer_or_fake_quant_ctr: Replacement observer / fake-quant
            constructor, or ``None`` to keep the original.

    Returns:
        A new spec of the same type as ``original_qspec``.

    NOTE(review): this relies on ``QuantizationSpec`` being a dataclass, which
    is what ``dataclasses.replace`` requires -- confirm against the
    ``QuantizationSpec`` imported by this module.
    """
    # Local import keeps this helper self-contained; dataclasses is stdlib.
    import dataclasses

    requested = {
        "dtype": dtype,
        "quant_min": quant_min,
        "quant_max": quant_max,
        "qscheme": qscheme,
        "ch_axis": ch_axis,
        "is_dynamic": is_dynamic,
        "observer_or_fake_quant_ctr": observer_or_fake_quant_ctr,
    }
    # Only ``None`` means "not specified"; 0 and False are real overrides.
    overrides = {
        field: value for field, value in requested.items() if value is not None
    }
    return dataclasses.replace(original_qspec, **overrides)
not conv2d) - if weight_shape is not None and len(weight_shape) != 4: - continue - + skip = skip or (weight_shape is not None and len(weight_shape) != 4) # Skip if depthwise (default to groups=1 since it's not an arg) - if is_depthwise_conv(weight_shape, 1, is_conv_transpose): - continue + skip = skip or ( + not is_conv_transpose and is_depthwise_conv(weight_shape, 1, False) + ) # adding weight node to the partition as well partition = [conv_node, conv_node.args[1]] @@ -277,7 +303,7 @@ def _do_annotate_conv( input_qspec_map[bias] = get_bias_qspec(quantization_config) partition.append(bias) - if _is_annotated(partition): + if _is_annotated(partition) or skip: continue if filter_fn and any(not filter_fn(n) for n in partition): @@ -324,17 +350,10 @@ def _do_annotate_conv_relu( weight = conv_node.args[1] assert isinstance(weight, Node) weight_qspec = get_weight_qspec(quantization_config) + groups = get_groups_from_conv(conv_node) if is_conv_transpose: # transposed convs per output channel quantization - weight_qspec = QuantizationSpec( - dtype=weight_qspec.dtype, - quant_min=weight_qspec.quant_min, - quant_max=weight_qspec.quant_max, - qscheme=weight_qspec.qscheme, - ch_axis=1, - is_dynamic=False, - observer_or_fake_quant_ctr=weight_qspec.observer_or_fake_quant_ctr, - ) + weight_qspec = change_quantization_config(weight_qspec, ch_axis=1) input_qspec_map[weight] = weight_qspec # adding weight node to the partition as well @@ -347,6 +366,9 @@ def _do_annotate_conv_relu( if _is_annotated(partition): continue + if is_conv_transpose and groups != 1: + continue + if filter_fn and any(not filter_fn(n) for n in partition): continue diff --git a/backends/xnnpack/test/ops/test_conv2d.py b/backends/xnnpack/test/ops/test_conv2d.py index d838ef0ffe9..2a0a82d99b6 100644 --- a/backends/xnnpack/test/ops/test_conv2d.py +++ b/backends/xnnpack/test/ops/test_conv2d.py @@ -174,14 +174,11 @@ def get_inputs(self): class Conv2dDQSeq(torch.nn.Module): - def __init__(self): + def __init__(self, 
transpose=False): super().__init__() - self.first = torch.nn.Conv2d( - in_channels=3, out_channels=8, kernel_size=3, padding=1 - ) - self.second = torch.nn.Conv2d( - in_channels=8, out_channels=10, kernel_size=3, padding=1 - ) + op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d + self.first = op(in_channels=3, out_channels=8, kernel_size=3, padding=1) + self.second = op(in_channels=8, out_channels=10, kernel_size=3, padding=1) def forward(self, x): y = self.first(x) @@ -192,14 +189,11 @@ def get_inputs(self): class Conv2dDQParallel(torch.nn.Module): - def __init__(self): + def __init__(self, transpose=False): super().__init__() - self.first = torch.nn.Conv2d( - in_channels=3, out_channels=8, kernel_size=3, padding=1 - ) - self.second = torch.nn.Conv2d( - in_channels=3, out_channels=8, kernel_size=3, padding=1 - ) + op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d + self.first = op(in_channels=3, out_channels=8, kernel_size=3, padding=1) + self.second = op(in_channels=3, out_channels=10, kernel_size=3, padding=1) def forward(self, x): first = self.first(x) @@ -266,8 +260,7 @@ def _test_dq( ) DynamicallyQuantizedPartitioner = XnnpackPartitioner( - config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, - per_op_mode=True, + config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, per_op_mode=True ) tester = Tester(m, m.get_inputs(), dynamic_shapes=dynamic_shapes) @@ -349,11 +342,10 @@ def test_fp32_conv2d_depthwise(self): ) def test_qs8_conv2d_depthwise(self): - for transpose in (True, False): - self._test( - Conv2d(groups=2, in_channels=2, out_channels=6, transpose=transpose), - quant_config=get_symmetric_quantization_config(), - ) + self._test( + Conv2d(groups=2, in_channels=2, out_channels=6), + quant_config=get_symmetric_quantization_config(), + ) def test_fp32_conv2d_bn(self): class Conv2dBatchNorm(torch.nn.Module): @@ -515,17 +507,14 @@ def forward(self, x): def get_inputs(self): return (torch.randn(batches, in_channels, height, width) * 
def get_groups_from_conv(conv_node: torch.fx.Node) -> int:
    """Deduce the ``groups`` value of a conv or transposed-conv node.

    ``groups`` isn't given to us in the training graph, so it is recovered
    from tensor shapes stored in each node's ``meta["val"]``:

    * conv: weight is ``(C_out, C_in/groups, *kernel_size)`` and the input
      activation carries ``C_in``, so ``groups = C_in // weight.shape[1]``.
    * conv_transpose: weight is ``(C_in, C_out/groups, *kernel_size)`` and the
      node's own output carries ``C_out``, so
      ``groups = C_out // weight.shape[1]``.

    The channel axis of the activation is indexed from the end
    (``1 - weight_rank``). For a batched ``(N, C, ...)`` tensor this is the
    same as index 1; it also selects the channel dimension when the
    activation has no leading batch dimension.

    Args:
        conv_node: A node accepted by ``_is_conv_node`` or
            ``_is_conv_transpose_node``.

    Returns:
        The number of groups of the convolution.

    Raises:
        RuntimeError: If ``conv_node`` is neither a conv nor a conv_transpose
            node.
    """
    is_conv = _is_conv_node(conv_node)
    if not is_conv and not _is_conv_transpose_node(conv_node):
        raise RuntimeError(f"expected {conv_node} to be a conv or conv_transpose node")

    weight_node = cast(torch.fx.Node, conv_node.args[1])
    weight_shape = weight_node.meta["val"].shape
    # Dim 1 of the weight is C_in/groups for conv, C_out/groups for transposed.
    channels_per_group = weight_shape[1]

    # The weight has one dim per spatial axis plus two channel dims, so the
    # activation's channel axis sits (weight_rank - 1) positions from the end.
    channel_axis = 1 - len(weight_shape)

    if is_conv:
        # The full C_in is read from the input activation.
        activation = cast(torch.fx.Node, conv_node.args[0])
    else:
        # The full C_out is read from the transposed conv's own output.
        activation = conv_node
    total_channels = activation.meta["val"].shape[channel_axis]

    return total_channels // channels_per_group