From 808ff19e3cf8c2afeb8383e7584f8a839b487c80 Mon Sep 17 00:00:00 2001
From: Max Ren
Date: Wed, 30 Apr 2025 18:30:02 -0700
Subject: [PATCH] Enable Operator Selection in Quantizer (#10569)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/10569

This fixes some tests that broke after we enabled dq convs, and enables
operator selection for quantization.

Reviewed By: derekxu, digantdesai

Differential Revision: D73898719
---
 .../xnnpack/quantizer/xnnpack_quantizer.py    | 151 ++++++++----------
 .../test/quantizer/test_pt2e_quantization.py  |  16 +-
 .../test/quantizer/test_xnnpack_quantizer.py  |   2 +-
 3 files changed, 84 insertions(+), 85 deletions(-)

diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer.py b/backends/xnnpack/quantizer/xnnpack_quantizer.py
index fdabd0383e6..cbd98bdb4c3 100644
--- a/backends/xnnpack/quantizer/xnnpack_quantizer.py
+++ b/backends/xnnpack/quantizer/xnnpack_quantizer.py
@@ -3,7 +3,8 @@
 import copy
 import functools
-from typing import Any, Callable, Optional, TYPE_CHECKING
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Set, TYPE_CHECKING
 
 import torch
 import torch._dynamo as torchdynamo
@@ -235,37 +236,52 @@ def not_module_type_or_name_filter(n: Node) -> bool:
     return not_module_type_or_name_filter
 
 
-class XNNPACKQuantizer(Quantizer):
-    supported_config_and_operators = _get_supported_config_and_operators()
-    STATIC_QAT_ONLY_OPS = [
-        "conv_bn_relu",
-        "conv_bn",
-        "conv_transpose_bn_relu",
-        "conv_transpose_bn",
-    ]
+@dataclass
+class QuantPattern:
+    name: str
+    is_dynamic: bool
+    is_qat: bool
+    op_overloads: Set[torch._ops.OpOverloadPacket]
+
+
+CONV_TARGETS = {
+    torch.ops.aten.conv2d.default,
+    torch.ops.aten.conv1d.default,
+    torch.ops.aten.convolution.default,
+}
+
+LINEAR_TARGETS = {
+    torch.ops.aten.linear.default,
+}
+
+ADAPTIVE_AVG_POOL2D_TARGETS = {torch.ops.aten.adaptive_avg_pool2d.default}
+
+ADD_TARGETS = {torch.ops.aten.add.Tensor}
+
+MUL_TARGETS = {torch.ops.aten.mul.Tensor}
+
+CAT_TARGETS = {torch.ops.aten.cat.default}
 
-    # static quantization ops (both PTQ and QAT)
-    # Preserve the order that fusions come before singular ops
-    STATIC_OPS = [
-        "linear_relu",
-        "linear",
-        "conv",
-        "conv_transpose",
-        "conv_relu",
-        "conv_transpose_relu",
-        "adaptive_avg_pool2d",
-        # TODO: move this to BoltNNQuantizer?
-        "gru_io_only",
-        "add_relu",
-        "add",
-        "mul_relu",
-        "mul",
-        "cat",
-    ]
-    DYNAMIC_OPS = [
-        "linear",
-        "conv",
+class XNNPACKQuantizer(Quantizer):
+    supported_config_and_operators = _get_supported_config_and_operators()
+    SUPPORTED_PATTERNS = [
+        QuantPattern("conv_bn_relu", False, True, CONV_TARGETS),
+        QuantPattern("conv_bn", False, True, CONV_TARGETS),
+        QuantPattern("conv_transpose_bn_relu", False, True, CONV_TARGETS),
+        QuantPattern("conv_transpose_bn", False, True, CONV_TARGETS),
+        QuantPattern("linear_relu", False, False, LINEAR_TARGETS),
+        QuantPattern("linear", True, False, LINEAR_TARGETS),
+        QuantPattern("conv", True, False, CONV_TARGETS),
+        QuantPattern("conv_transpose", False, False, CONV_TARGETS),
+        QuantPattern("conv_relu", False, False, CONV_TARGETS),
+        QuantPattern("conv_transpose_relu", False, False, CONV_TARGETS),
+        QuantPattern("adaptive_avg_pool2d", False, False, ADAPTIVE_AVG_POOL2D_TARGETS),
+        QuantPattern("add_relu", False, False, ADD_TARGETS),
+        QuantPattern("add", False, False, ADD_TARGETS),
+        QuantPattern("mul_relu", False, False, MUL_TARGETS),
+        QuantPattern("mul", False, False, MUL_TARGETS),
+        QuantPattern("cat", False, False, CAT_TARGETS),
     ]
 
     def __init__(self) -> None:
@@ -347,83 +363,58 @@ def transform_for_annotation(
     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
         """just handling global spec for now"""
-        # hacked for handling dynamic linear quant. will fix later.
-        if self.global_config and self.global_config.input_activation.is_dynamic:  # type: ignore[union-attr]
-            model = self._annotate_for_dynamic_quantization_config(model)
-        else:
-            model = self._annotate_for_static_quantization_config(model)
+        model = self._annotate_for_quantization_config(model)
         propagate_annotation(model)
         return model
 
-    def _annotate_all_static_patterns(
+    def _annotate_all_patterns(
         self,
         model: torch.fx.GraphModule,
         quantization_config: Optional[QuantizationConfig],
         filter_fn: Optional[Callable[[Node], bool]] = None,
-    ) -> torch.fx.GraphModule:
+        operator_target: Optional[torch._ops.OpOverloadPacket] = None,
+    ):
         # TODO: implement the support for None to be canceling out previous annotations
         if quantization_config is None:
            return model
 
-        if quantization_config.is_qat:
-            for op in self.STATIC_QAT_ONLY_OPS:
-                OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
-        for op in self.STATIC_OPS:
-            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
-        return model
+        for pattern in self.SUPPORTED_PATTERNS:
+            if operator_target and operator_target not in pattern.op_overloads:
+                # if operator_target is specified, skip patterns that aren't
+                # associated with that target
+                continue
+            if quantization_config.input_activation.is_dynamic and pattern.is_dynamic:
+                OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
+            elif quantization_config.is_qat and pattern.is_qat:
+                OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
+            elif not quantization_config.input_activation.is_dynamic:
+                OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
 
-    def _annotate_all_dynamic_patterns(
-        self,
-        model: torch.fx.GraphModule,
-        quantization_config: Optional[QuantizationConfig],
-        filter_fn: Optional[Callable[[Node], bool]] = None,
-    ) -> torch.fx.GraphModule:
-        # TODO: implement the support for None to be canceling out previous annotations
-        if quantization_config is None:
-            return model
-
-        for op in self.DYNAMIC_OPS:
-            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
         return model
 
-    def _annotate_for_static_quantization_config(
+    def _annotate_for_quantization_config(
         self, model: torch.fx.GraphModule
     ) -> torch.fx.GraphModule:
         module_name_list = list(self.module_name_config.keys())
         for module_name, config in self.module_name_config.items():
-            self._annotate_all_static_patterns(
+            self._annotate_all_patterns(
                 model, config, _get_module_name_filter(module_name)
             )
 
         tp_list = list(self.module_type_config.keys())
         for module_type, config in self.module_type_config.items():
-            self._annotate_all_static_patterns(
+            self._annotate_all_patterns(
                 model, config, _get_module_type_filter(module_type)
             )
 
-        self._annotate_all_static_patterns(
-            model,
-            self.global_config,
-            _get_not_module_type_or_name_filter(tp_list, module_name_list),
-        )
-        return model
-
-    def _annotate_for_dynamic_quantization_config(
-        self, model: torch.fx.GraphModule
-    ) -> torch.fx.GraphModule:
-        module_name_list = list(self.module_name_config.keys())
-        for module_name, config in self.module_name_config.items():
-            self._annotate_all_dynamic_patterns(
-                model, config, _get_module_name_filter(module_name)
-            )
-
-        tp_list = list(self.module_type_config.keys())
-        for module_type, config in self.module_type_config.items():
-            self._annotate_all_dynamic_patterns(
-                model, config, _get_module_type_filter(module_type)
+        for op, config in self.operator_type_config.items():
+            self._annotate_all_patterns(
+                model,
+                config,
+                _get_not_module_type_or_name_filter(tp_list, module_name_list),
+                op,
             )
-
-        self._annotate_all_dynamic_patterns(
+        self._annotate_all_patterns(
             model,
             self.global_config,
             _get_not_module_type_or_name_filter(tp_list, module_name_list),
diff --git a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py
index 34b6f745044..4243441118e 100644
--- a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py
+++ b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py
@@ -172,10 +172,14 @@ def test_composable_quantizer_linear_conv(self) -> None:
         quantization_config_dynamic = get_symmetric_quantization_config(
             is_per_channel=False, is_dynamic=True
         )
-        dynamic_quantizer.set_global(quantization_config_dynamic)
+        dynamic_quantizer.set_operator_type(
+            torch.ops.aten.linear.default, quantization_config_dynamic
+        )
         static_quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config(is_per_channel=True)
-        static_quantizer.set_global(quantization_config)
+        static_quantizer.set_operator_type(
+            torch.ops.aten.conv2d.default, quantization_config
+        )
         # Note that dynamic quantization must be applied first here.
         # this is because static quantizer also quantizes linear with static qspec
         # and if we apply static_quantizer first then dynamic_quantizer cannot be applied
@@ -271,10 +275,14 @@ def test_embedding_conv_linear_quantization(self) -> None:
         quantization_config_dynamic = get_symmetric_quantization_config(
             is_per_channel=True, is_dynamic=True
         )
-        dynamic_quantizer.set_global(quantization_config_dynamic)
+        dynamic_quantizer.set_operator_type(
+            torch.ops.aten.linear.default, quantization_config_dynamic
+        )
         static_quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config(is_per_channel=True)
-        static_quantizer.set_global(quantization_config)
+        static_quantizer.set_operator_type(
+            torch.ops.aten.conv2d.default, quantization_config
+        )
         composed_quantizer = ComposableQuantizer(
             [embedding_quantizer, dynamic_quantizer, static_quantizer]
         )
diff --git a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py
index 856030755af..f2a94325b92 100644
--- a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py
+++ b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py
@@ -665,7 +665,7 @@ def test_dynamic_linear_with_conv(self):
         quantization_config = get_symmetric_quantization_config(
             is_per_channel=False, is_dynamic=True
         )
-        quantizer.set_global(quantization_config)
+        quantizer.set_operator_type(torch.ops.aten.linear.default, quantization_config)
         m_eager = TestHelperModules.ConvLinearWPermute().eval()
 
         node_occurrence = {
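Usage sketch (illustrative only, not part of the diff above): with operator selection enabled, a single XNNPACKQuantizer can carry different configs for different operator types via set_operator_type, e.g. dynamic linear plus static conv, mirroring the updated tests. The import paths, the toy module, and the export/prepare/convert flow below are assumptions based on the usual PT2E workflow, not something introduced by this patch.

    import torch
    from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (  # assumed import path
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e  # assumed PT2E flow
    from torch.export import export_for_training


    class DemoModule(torch.nn.Module):  # hypothetical toy model, not from the patch
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, 3)
            self.linear = torch.nn.Linear(8, 4)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            y = self.conv(x)
            return self.linear(y.mean(dim=(2, 3)))


    model = DemoModule().eval()
    example_inputs = (torch.randn(1, 3, 16, 16),)

    # Per-operator selection: dynamically quantize linear, statically quantize conv.
    quantizer = XNNPACKQuantizer()
    quantizer.set_operator_type(
        torch.ops.aten.linear.default,
        get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True),
    )
    quantizer.set_operator_type(
        torch.ops.aten.conv2d.default,
        get_symmetric_quantization_config(is_per_channel=True),
    )

    # Standard PT2E flow (assumed): export, annotate/observe, calibrate, convert.
    exported = export_for_training(model, example_inputs).module()
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*example_inputs)  # calibration pass
    quantized = convert_pt2e(prepared)

During annotation, patterns whose op_overloads set does not contain the requested operator are skipped, so each set_operator_type call only affects its own operator; module-name, module-type, and global configs still flow through the same _annotate_all_patterns path.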