101 changes: 79 additions & 22 deletions backends/arm/quantizer/quantization_annotator.py
@@ -10,6 +10,7 @@

 import torch
 import torch.fx
+import torch.nn.functional as F
 from executorch.backends.arm.quantizer import QuantizationConfig
 from executorch.backends.arm.tosa_utils import get_node_debug_info
 from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec
@@ -142,29 +143,33 @@ def _match_pattern(

     Each 'pattern' element is composed of a list of disjunctive node types.
     """
-    assert len(pattern) == 2, "Only two-nodes patterns supported currently"
-
-    if node.target in pattern[0]:
-        assert len(node.users) != 0
-        parent = node
-        child = next(iter(node.users))
-    elif node.target in pattern[1]:
-        assert len(node.args) != 0
-        parent = node.args[0]  # type: ignore[assignment]
-        child = node
-    else:
-        return False
-
-    if len(parent.users) != 1:
-        return False
-
-    if parent.target not in pattern[0] or child.target not in pattern[1]:
-        return False
-
+    assert len(pattern) > 0, "No pattern provided"
     if filter_fn is not None:
-        return filter_fn(parent) and filter_fn(child)
-
-    return True
+        if not filter_fn(node):
+            return False
+    if len(pattern) == 1:
+        # Base case: filter_fn has passed, so simply check if node.target is in pattern.
+        return node.target in pattern[0]
+    if node.target not in [op for sub_pattern in pattern for op in sub_pattern]:
+        # node.target is not in the pattern at all; no need to look at the rest of it.
+        return False
+    # Find the index of this node's target in the pattern.
+    idx = [node.target in sub_pattern for sub_pattern in pattern].index(True)
+    left_pattern = pattern[:idx]
+    # Exclude idx, as it contains node.target, which we have already matched.
+    right_pattern = pattern[idx + 1 :]
+    left_condition = True
+    right_condition = True
+    # Recursively match the rest of the pattern by calling this function on the
+    # node's input and user nodes with the updated sub-patterns.
+    if len(left_pattern) > 0:
+        parent = node.all_input_nodes[0]
+        if len(parent.users) != 1:
+            return False
+        left_condition = _match_pattern(parent, left_pattern, filter_fn)
+    if len(right_pattern) > 0:
+        right_condition = _match_pattern(list(node.users)[0], right_pattern, filter_fn)
+    return left_condition and right_condition
 
 
 _one_to_one = [
@@ -274,6 +279,58 @@ def any_or_hardtanh_min_zero(n: Node):
         return n.target != torch.ops.aten.hardtanh.default or n.args[1] == 0
 
     if _match_pattern(
         node,
         [
             [
+                torch.ops.aten.conv1d.default,
+                torch.ops.aten.conv2d.default,
+                torch.ops.aten.conv2d.padding,
+            ],
+            [torch.ops.aten.batch_norm.default, F.batch_norm],
+            [torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default],
+        ],
+        filter_fn=any_or_hardtanh_min_zero,
+    ):
+        if node.target in (
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.aten.conv2d.padding,
+        ):
+            quant_properties.quant_inputs = [
+                _QuantProperty(0, input_act_qspec),
+                _QuantProperty(1, weight_qspec, mark_annotated=True),
+                _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
+            ]
+        elif node.target in (
+            torch.ops.aten.relu.default,
+            torch.ops.aten.hardtanh.default,
+        ):
+            quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
+
+    elif _match_pattern(
+        node,
+        [
+            [
+                torch.ops.aten.conv1d.default,
+                torch.ops.aten.conv2d.default,
+                torch.ops.aten.conv2d.padding,
+            ],
+            [torch.ops.aten.batch_norm.default, F.batch_norm],
+        ],
+    ):
+        if node.target in (
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.aten.conv2d.padding,
+        ):
+            quant_properties.quant_inputs = [
+                _QuantProperty(0, input_act_qspec),
+                _QuantProperty(1, weight_qspec, mark_annotated=True),
+                _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
+            ]
+        elif node.target in [torch.ops.aten.batch_norm.default, F.batch_norm]:
+            quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
+    elif _match_pattern(
+        node,
+        [
+            [
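For intuition: the rewritten `_match_pattern` anchors on the sub-pattern that contains `node.target`, then recurses towards the node's inputs (the left part of the pattern) and its users (the right part). The following standalone sketch mirrors that logic on plain objects instead of `torch.fx` nodes; `ToyNode` and `match` are illustrative names, not part of this patch:

```python
from typing import Callable, List, Optional


class ToyNode:
    """Hypothetical stand-in for torch.fx.Node: one input, tracked users."""

    def __init__(self, target: str, inp: Optional["ToyNode"] = None) -> None:
        self.target = target
        self.all_input_nodes = [inp] if inp is not None else []
        self.users: List["ToyNode"] = []
        if inp is not None:
            inp.users.append(self)


def match(
    node: ToyNode,
    pattern: List[List[str]],
    filter_fn: Optional[Callable[[ToyNode], bool]] = None,
) -> bool:
    # Mirrors _match_pattern: anchor on the sub-pattern containing node.target,
    # then recurse left (towards inputs) and right (towards users).
    if filter_fn is not None and not filter_fn(node):
        return False
    if len(pattern) == 1:
        return node.target in pattern[0]
    if node.target not in [op for sub in pattern for op in sub]:
        return False
    idx = [node.target in sub for sub in pattern].index(True)
    left, right = pattern[:idx], pattern[idx + 1 :]
    if left:
        parent = node.all_input_nodes[0]
        # A producer feeding more than one consumer breaks the chain.
        if len(parent.users) != 1 or not match(parent, left, filter_fn):
            return False
    if right and not match(node.users[0], right, filter_fn):
        return False
    return True


# conv -> bn -> relu: the full pattern matches starting from any node in the chain.
conv = ToyNode("conv2d")
bn = ToyNode("batch_norm", conv)
relu = ToyNode("relu", bn)
chain = [["conv1d", "conv2d"], ["batch_norm"], ["relu", "hardtanh"]]
assert match(bn, chain)    # anchored in the middle
assert match(conv, chain)  # anchored at the left end
assert not match(relu, [["conv2d"], ["relu"]])  # bn breaks a conv->relu match
```

The `len(parent.users) != 1` guard keeps the original two-node behaviour: a producer with multiple consumers cannot be fused into the chain.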
66 changes: 66 additions & 0 deletions backends/arm/test/misc/test_bn_relu_folding_qat.py
@@ -0,0 +1,66 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+import torch.nn.functional as F
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.test import common, conftest
+from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineBI
+
+from executorch.backends.xnnpack.test.tester.tester import Quantize
+from torch import nn
+
+
+input_t1 = Tuple[torch.Tensor]  # Input x
+
+
+class ConvModule(torch.nn.Module):
+    input_shape = (1, 28, 28)
+    batch_size = 64
+    test_data: input_t1 = (torch.randn(batch_size, *input_shape),)
+
+    def __init__(self, batch_norm: bool = True) -> None:
+        super().__init__()
+        self.conv = torch.nn.Conv2d(1, 16, 3, stride=2)
+        self.bn = nn.BatchNorm2d(num_features=16) if batch_norm else nn.Identity()
+
+    def forward(self, x: torch.Tensor):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = F.relu(x)
+
+        return x
+
+
+models = {
+    "conv_bn_relu": ConvModule(batch_norm=True),
+    "conv_relu": ConvModule(batch_norm=False),
+}
+
+
+@common.parametrize("model", models)
+def test_qat_tosa_BI(model: torch.nn.Module):
+    pipeline = TosaPipelineBI[input_t1](model, model.test_data, [], [], qtol=1)
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "0.80": common.TosaSpecification.create_from_string("TOSA-0.80+BI"),
+        "1.0": common.TosaSpecification.create_from_string("TOSA-1.0+INT"),
+    }
+    tosa_spec = tosa_profiles[tosa_version]
+    quantizer = TOSAQuantizer(tosa_spec)
+    pipeline.change_args(
+        "quantize",
+        Quantize(
+            quantizer=quantizer,
+            quantization_config=get_symmetric_quantization_config(is_qat=True),
+            is_qat=True,
+        ),
+    )
+    pipeline.run()
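Background on why conv + batch_norm (+ relu) must be matched and annotated as one unit: at convert time the BatchNorm folds into the convolution, so the observers need to see values consistent with the folded weights. A standalone sketch of the standard folding arithmetic (textbook formula, not code from this patch):

```python
import torch

conv = torch.nn.Conv2d(1, 16, 3, stride=2, bias=False)
bn = torch.nn.BatchNorm2d(16).eval()  # eval(): use running statistics

# Standard fold: w' = w * gamma / sqrt(var + eps); b' = beta - mean * gamma / sqrt(var + eps)
scale = bn.weight.data / torch.sqrt(bn.running_var + bn.eps)
folded = torch.nn.Conv2d(1, 16, 3, stride=2, bias=True)
folded.weight.data = conv.weight.data * scale.reshape(-1, 1, 1, 1)
folded.bias.data = bn.bias.data - bn.running_mean * scale

x = torch.randn(2, 1, 28, 28)
with torch.no_grad():
    assert torch.allclose(bn(conv(x)), folded(x), atol=1e-5)
```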
18 changes: 15 additions & 3 deletions backends/xnnpack/test/tester/tester.py
@@ -55,7 +55,11 @@
 )
 from executorch.exir.program._program import _transform
 from torch._export.pass_base import PassType
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantize_pt2e import (
+    convert_pt2e,
+    prepare_pt2e,
+    prepare_qat_pt2e,
+)
 from torch.ao.quantization.quantizer.quantizer import Quantizer
 from torch.export import export, ExportedProgram
 from torch.testing import FileCheck
@@ -150,26 +154,34 @@ def __init__(
         quantization_config: Optional[QuantizationConfig] = None,
         calibrate: bool = True,
         calibration_samples: Optional[Sequence[Any]] = None,
+        is_qat: Optional[bool] = False,
     ):
         self.quantizer = quantizer or XNNPACKQuantizer()
         self.quantization_config = (
-            quantization_config or get_symmetric_quantization_config()
+            quantization_config or get_symmetric_quantization_config(is_qat=is_qat)
         )
         self.calibrate = calibrate
         self.calibration_samples = calibration_samples
 
         self.quantizer.set_global(self.quantization_config)
 
         self.converted_graph = None
+        self.is_qat = is_qat
 
     def run(
         self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
     ) -> None:
         assert inputs is not None
+        if self.is_qat:
+            artifact.train()
         captured_graph = export_for_training(artifact, inputs, strict=True).module()
 
         assert isinstance(captured_graph, torch.fx.GraphModule)
-        prepared = prepare_pt2e(captured_graph, self.quantizer)
+
+        if self.is_qat:
+            prepared = prepare_qat_pt2e(captured_graph, self.quantizer)
+        else:
+            prepared = prepare_pt2e(captured_graph, self.quantizer)
 
         if self.calibrate:
             # Calibrate prepared model to provide data to quantization observers.
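End to end, the `Quantize` stage with `is_qat=True` follows the standard PT2E quantization-aware-training flow. A minimal standalone sketch of that flow, assuming a recent PyTorch; the XNNPACK quantizer and the toy model here are for illustration only, not code from this PR:

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.export import export_for_training

model = torch.nn.Sequential(torch.nn.Conv2d(1, 8, 3), torch.nn.ReLU())
example_inputs = (torch.randn(1, 1, 28, 28),)

model.train()  # QAT capture happens in train mode, as in Quantize.run above
captured = export_for_training(model, example_inputs, strict=True).module()

quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_qat=True)
)
prepared = prepare_qat_pt2e(captured, quantizer)

# Fine-tune (or at least run forward passes) so fake-quant observers learn ranges.
prepared(*example_inputs)

converted = convert_pt2e(prepared)
```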