diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index 55c2ca21e1b..62cbfd4602a 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -10,6 +10,7 @@
 import torch
 import torch.fx
+import torch.nn.functional as F
 from executorch.backends.arm.quantizer import QuantizationConfig
 from executorch.backends.arm.tosa_utils import get_node_debug_info
 from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec
@@ -142,29 +143,33 @@ def _match_pattern(
 
     Each 'pattern' element is composed of a list of disjunctive node types.
     """
-    assert len(pattern) == 2, "Only two-nodes patterns supported currently"
-
-    if node.target in pattern[0]:
-        assert len(node.users) != 0
-        parent = node
-        child = next(iter(node.users))
-    elif node.target in pattern[1]:
-        assert len(node.args) != 0
-        parent = node.args[0]  # type: ignore[assignment]
-        child = node
-    else:
-        return False
-
-    if len(parent.users) != 1:
-        return False
-
-    if parent.target not in pattern[0] or child.target not in pattern[1]:
-        return False
-
+    assert len(pattern) > 0, "No pattern provided"
     if filter_fn is not None:
-        return filter_fn(parent) and filter_fn(child)
-
-    return True
+        if not filter_fn(node):
+            return False
+    if len(pattern) == 1:
+        # Base case: the node passed filter_fn, so simply check if node.target is in pattern.
+        return node.target in pattern[0]
+    if node.target not in [op for sub_pattern in pattern for op in sub_pattern]:
+        # node.target is not in the pattern; no need to look at the rest of it.
+        return False
+    # Find the index of this node's target in pattern
+    idx = [node.target in sub_pattern for sub_pattern in pattern].index(True)
+    left_pattern = pattern[:idx]
+    # Exclude idx, as it contains node.target, which we have already matched.
+    right_pattern = pattern[idx + 1 :]
+    left_condition = True
+    right_condition = True
+    # Recursively match the rest of the pattern by calling this function on
+    # the node's input and user nodes with the updated sub-patterns.
+    if len(left_pattern) > 0:
+        parent = node.all_input_nodes[0]
+        if len(parent.users) != 1:
+            return False
+        left_condition = _match_pattern(parent, left_pattern, filter_fn)
+    if len(right_pattern) > 0:
+        right_condition = _match_pattern(list(node.users)[0], right_pattern, filter_fn)
+    return left_condition and right_condition
 
 
 _one_to_one = [
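The refactored `_match_pattern` now accepts chains of arbitrary length: it locates the sub-pattern containing `node.target`, then recurses left through the node's producer (which must have exactly one user) and right through the node's user. As a rough illustration of the same divide-and-recurse idea outside the quantizer, here is a minimal standalone sketch that matches module-type chains on a `torch.fx` graph; `matches_chain`, the toy module `M`, and the type-based pattern are hypothetical stand-ins for the aten-target patterns used in the annotator, not code from this patch.

```python
from typing import List, Tuple, Type

import torch
import torch.fx
from torch import nn


def matches_chain(
    gm: torch.fx.GraphModule, node: torch.fx.Node, pattern: List[Tuple[Type, ...]]
) -> bool:
    """Return True if `node` lies on a single-use chain of call_module nodes
    whose module types match `pattern`, mirroring the recursion above."""

    def node_type(n: torch.fx.Node):
        return type(gm.get_submodule(n.target)) if n.op == "call_module" else None

    # Split the pattern at the sub-pattern containing this node's type.
    idx = next((i for i, alts in enumerate(pattern) if node_type(node) in alts), None)
    if idx is None:
        return False
    ok = True
    if idx > 0:  # match the left of the pattern against the producer
        parent = node.all_input_nodes[0]
        ok = len(parent.users) == 1 and matches_chain(gm, parent, pattern[:idx])
    if ok and idx + 1 < len(pattern):  # match the right against the user
        users = list(node.users)
        ok = len(users) == 1 and matches_chain(gm, users[0], pattern[idx + 1 :])
    return ok


class M(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 8, 3)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.bn(self.conv(x)))


gm = torch.fx.symbolic_trace(M())
pattern = [(nn.Conv2d,), (nn.BatchNorm2d,), (nn.ReLU,)]
for n in gm.graph.nodes:
    if n.op == "call_module":
        print(n.target, matches_chain(gm, n, pattern))  # all three print True
```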
@@ -274,6 +279,58 @@ def any_or_hardtanh_min_zero(n: Node):
         return n.target != torch.ops.aten.hardtanh.default or n.args[1] == 0
 
     if _match_pattern(
+        node,
+        [
+            [
+                torch.ops.aten.conv1d.default,
+                torch.ops.aten.conv2d.default,
+                torch.ops.aten.conv2d.padding,
+            ],
+            [torch.ops.aten.batch_norm.default, F.batch_norm],
+            [torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default],
+        ],
+        filter_fn=any_or_hardtanh_min_zero,
+    ):
+        if node.target in (
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.aten.conv2d.padding,
+        ):
+            quant_properties.quant_inputs = [
+                _QuantProperty(0, input_act_qspec),
+                _QuantProperty(1, weight_qspec, mark_annotated=True),
+                _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
+            ]
+        elif node.target in (
+            torch.ops.aten.relu.default,
+            torch.ops.aten.hardtanh.default,
+        ):
+            quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
+
+    elif _match_pattern(
+        node,
+        [
+            [
+                torch.ops.aten.conv1d.default,
+                torch.ops.aten.conv2d.default,
+                torch.ops.aten.conv2d.padding,
+            ],
+            [torch.ops.aten.batch_norm.default, F.batch_norm],
+        ],
+    ):
+        if node.target in (
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.aten.conv2d.padding,
+        ):
+            quant_properties.quant_inputs = [
+                _QuantProperty(0, input_act_qspec),
+                _QuantProperty(1, weight_qspec, mark_annotated=True),
+                _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
+            ]
+        elif node.target in [torch.ops.aten.batch_norm.default, F.batch_norm]:
+            quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
+    elif _match_pattern(
         node,
         [
             [
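The two new branches above annotate conv → batch_norm (→ relu/hardtanh) chains so that quantization parameters land on the conv inputs and on the output at the end of the chain, which is what lets batch norm fold away cleanly after QAT. As a rough, hypothetical check of which aten targets such a chain lowers to (the module and shapes below are illustrative, not part of the patch):

```python
import torch
from torch import nn
from torch.export import export_for_training


class ConvBnReLU(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.bn(self.conv(x)))


# Export in train mode so batch norm is kept as its own aten op.
gm = export_for_training(ConvBnReLU().train(), (torch.randn(1, 3, 16, 16),)).module()
for node in gm.graph.nodes:
    if node.op == "call_function":
        print(node.target)
# The printed targets should include aten.conv2d.default,
# aten.batch_norm.default and aten.relu.default, i.e. the three
# disjunctive groups matched by the first pattern above.
```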
diff --git a/backends/arm/test/misc/test_bn_relu_folding_qat.py b/backends/arm/test/misc/test_bn_relu_folding_qat.py
new file mode 100644
index 00000000000..782783f8205
--- /dev/null
+++ b/backends/arm/test/misc/test_bn_relu_folding_qat.py
@@ -0,0 +1,66 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+import torch.nn.functional as F
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.test import common, conftest
+from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineBI
+
+from executorch.backends.xnnpack.test.tester.tester import Quantize
+from torch import nn
+
+
+input_t1 = Tuple[torch.Tensor]  # Input x
+
+
+class ConvModule(torch.nn.Module):
+    input_shape = (1, 28, 28)
+    batch_size = 64
+    test_data: input_t1 = (torch.randn(batch_size, *input_shape),)
+
+    def __init__(self, batch_norm: bool = True) -> None:
+        super().__init__()
+        self.conv = torch.nn.Conv2d(1, 16, 3, stride=2)
+        self.bn = nn.BatchNorm2d(num_features=16) if batch_norm else nn.Identity()
+
+    def forward(self, x: torch.Tensor):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = F.relu(x)
+
+        return x
+
+
+models = {
+    "conv_bn_relu": ConvModule(batch_norm=True),
+    "conv_relu": ConvModule(batch_norm=False),
+}
+
+
+@common.parametrize("model", models)
+def test_qat_tosa_BI(model: torch.nn.Module):
+    pipeline = TosaPipelineBI[input_t1](model, model.test_data, [], [], qtol=1)
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "0.80": common.TosaSpecification.create_from_string("TOSA-0.80+BI"),
+        "1.0": common.TosaSpecification.create_from_string("TOSA-1.0+INT"),
+    }
+    tosa_spec = tosa_profiles[tosa_version]
+    quantizer = TOSAQuantizer(tosa_spec)
+    pipeline.change_args(
+        "quantize",
+        Quantize(
+            quantizer=quantizer,
+            quantization_config=get_symmetric_quantization_config(is_qat=True),
+            is_qat=True,
+        ),
+    )
+    pipeline.run()
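The test drives QAT through the tester pipeline, which only calibrates the prepared model. For context, here is a minimal sketch of the full eager QAT flow that the new stage wraps, with a placeholder fine-tuning loop between prepare and convert; the optimizer, loss, and step count are illustrative assumptions, and `ConvModule` is the test module defined above.

```python
import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
    get_symmetric_quantization_config,
    TOSAQuantizer,
)
from executorch.backends.arm.test import common
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
from torch.export import export_for_training

model = ConvModule(batch_norm=True).train()
example_inputs = ConvModule.test_data

tosa_spec = common.TosaSpecification.create_from_string("TOSA-1.0+INT")
quantizer = TOSAQuantizer(tosa_spec)
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))

# Insert fake-quant observers suitable for training.
graph = export_for_training(model, example_inputs, strict=True).module()
prepared = prepare_qat_pt2e(graph, quantizer)

# Placeholder fine-tuning loop; a real run would use task data and a task loss.
optimizer = torch.optim.SGD(prepared.parameters(), lr=1e-3)
for _ in range(3):
    out = prepared(*example_inputs)
    loss = out.abs().mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Convert folds batch norm and bakes the learned quantization parameters in.
quantized = convert_pt2e(prepared)
```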
diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py
index cbce817cf4b..dcdafebd6fd 100644
--- a/backends/xnnpack/test/tester/tester.py
+++ b/backends/xnnpack/test/tester/tester.py
@@ -55,7 +55,11 @@
 )
 from executorch.exir.program._program import _transform
 from torch._export.pass_base import PassType
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantize_pt2e import (
+    convert_pt2e,
+    prepare_pt2e,
+    prepare_qat_pt2e,
+)
 from torch.ao.quantization.quantizer.quantizer import Quantizer
 from torch.export import export, ExportedProgram
 from torch.testing import FileCheck
@@ -150,10 +154,11 @@ def __init__(
         quantization_config: Optional[QuantizationConfig] = None,
         calibrate: bool = True,
         calibration_samples: Optional[Sequence[Any]] = None,
+        is_qat: Optional[bool] = False,
     ):
         self.quantizer = quantizer or XNNPACKQuantizer()
         self.quantization_config = (
-            quantization_config or get_symmetric_quantization_config()
+            quantization_config or get_symmetric_quantization_config(is_qat=is_qat)
         )
         self.calibrate = calibrate
         self.calibration_samples = calibration_samples
@@ -161,15 +166,22 @@ def __init__(
         self.quantizer.set_global(self.quantization_config)
         self.converted_graph = None
+        self.is_qat = is_qat
 
     def run(
         self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
     ) -> None:
         assert inputs is not None
+        if self.is_qat:
+            artifact.train()
         captured_graph = export_for_training(artifact, inputs, strict=True).module()
         assert isinstance(captured_graph, torch.fx.GraphModule)
-        prepared = prepare_pt2e(captured_graph, self.quantizer)
+
+        if self.is_qat:
+            prepared = prepare_qat_pt2e(captured_graph, self.quantizer)
+        else:
+            prepared = prepare_pt2e(captured_graph, self.quantizer)
 
         if self.calibrate:
             # Calibrate prepared model to provide data to quantization observers.
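With `is_qat` left at its default of `False`, the stage behaves exactly as before; opting in puts the module in train mode before export and swaps `prepare_pt2e` for `prepare_qat_pt2e`. A hypothetical usage sketch, where `model` and `example_inputs` are placeholders for any eager module and matching sample inputs:

```python
from executorch.backends.xnnpack.test.tester import Tester
from executorch.backends.xnnpack.test.tester.tester import Quantize

tester = (
    Tester(model, example_inputs)
    .quantize(Quantize(is_qat=True))  # uses prepare_qat_pt2e under the hood
    .export()
    .to_edge()
)
```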