
Commit 6ab8723

Cortex_m backend: Add quantizer + avoid linear decomp (#15459)
Changes to_edge_and_transform to to_edge, which supports the preserve_ops arg of the EdgeCompileConfig to avoid decomposing the linear op. This significantly simplifies lowering the linear operator, since it does not have to be re-fused.

Adds a cortex_m quantizer, with the intention that it is general enough to be used for a generic MCU. It is implemented as a ComposableQuantizer using multiple instances of a new OperatorConfigQuantizer class. This gives a number of abstraction levels for configuration:

- McuQuantizer
- ComposableQuantizer
- OperatorConfig
- QuantizerConfig
- QuantizationSpec

The new quantizer also adds a transform_for_annotation pass pipeline, which makes it possible to fix scalar + tensor operations.

Signed-off-by: Adrian Lundell <[email protected]>
1 parent 135e3b6 commit 6ab8723
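
For context, a minimal sketch of how the new quantizer might slot into the PT2E + to_edge flow described above. The CortexMQuantizer import path, the torchao prepare/convert entry points, and the preserve_ops argument name are assumptions based on this diff and the commit message, not verified against the rest of the tree:

import torch
from torch.export import export

# Assumed import paths; the quantizer module path is not shown in this diff.
from executorch.backends.cortex_m.quantizer import CortexMQuantizer
from executorch.exir import EdgeCompileConfig, to_edge
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

model = torch.nn.Sequential(torch.nn.Linear(32, 16), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 32),)

# PT2E quantization with the new quantizer: annotate, calibrate, convert.
exported = export(model, example_inputs)
quantizer = CortexMQuantizer()
prepared = prepare_pt2e(exported.module(), quantizer)
prepared(*example_inputs)  # calibration pass
quantized = convert_pt2e(prepared)

# Lower with to_edge, keeping aten.linear un-decomposed so the backend's fusion
# passes do not have to re-fuse it (the preserve_ops arg named in the commit message).
edge = to_edge(
    export(quantized, example_inputs),
    compile_config=EdgeCompileConfig(preserve_ops=[torch.ops.aten.linear.default]),
)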

7 files changed, +368 -92 lines changed

backends/cortex_m/passes/cortex_m_pass_manager.py

Lines changed: 12 additions & 0 deletions
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.


+from executorch.backends.arm._passes import ScalarsToAttributePass
 from executorch.backends.cortex_m.passes import (
     QuantizedLinearFusionPass,
     QuantizedOpFusionPass,
@@ -25,5 +26,16 @@ class CortexMPassManager(XNNPACKPassManager):
         QuantizedLinearFusionPass,
     ]

+    pass_list_transform_for_annotation: list[ExportPass] = [
+        ScalarsToAttributePass,
+        ReplaceScalarWithTensorArgPass,
+    ]
+
     def __init__(self, exported_program, passes=None):
         super().__init__(exported_program, passes or self.pass_list)
+
+    def transform_for_annotation(self, model):
+        passes = self.pass_list_transform_for_annotation
+        for p in passes:
+            model = p().call(model).graph_module
+        return model
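
As a rough illustration of what this transform_for_annotation pipeline is for (a conceptual sketch, not the actual pass implementations): a Python scalar operand has no tensor to attach a quantization observer to, so the ScalarsToAttributePass / ReplaceScalarWithTensorArgPass rewrite turns the scalar into a tensor argument before annotation.

import torch

class ScalarAdd(torch.nn.Module):
    def forward(self, x):
        # Before the transform: the add has a Python scalar operand, which
        # cannot carry its own quantization spec.
        return x + 1.0

class TensorAddEquivalent(torch.nn.Module):
    def forward(self, x):
        # Conceptually after the transform: the scalar has become a tensor
        # argument, so the op matches aten.add.Tensor with two annotatable inputs.
        return x + torch.tensor(1.0)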
backends/cortex_m/quantizer/operator_configs.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Operator configs maps a list of operators/operator patterns to a quantization configuration.
+These can be used with the OperatorConfigQuantizer to quantize models based on operator patterns.
+"""
+
+import torch
+
+from executorch.backends.cortex_m.quantizer.quantization_configs import (
+    INT8_PER_TENSOR_CONFIG,
+)
+from torchao.quantization.pt2e.quantizer import OperatorConfig
+
+# ----------------- OPERATOR PATTERN PRESETS -----------------
+BINARY_OP_PATTERNS = [
+    [torch.ops.aten.add.Tensor],
+]
+
+LINEAR_OP_PATTERNS = [
+    [torch.ops.aten.linear.default],
+    [torch.ops.aten.linear.default, torch.ops.aten.relu.default],
+]
+
+# ----------------- OPERATOR CONFIG PRESETS -----------------
+INT8_BINARY_OPS_OPERATOR_CONFIG = OperatorConfig(
+    INT8_PER_TENSOR_CONFIG, BINARY_OP_PATTERNS
+)
+
+INT8_LINEAR_OPERATOR_CONFIG = OperatorConfig(
+    INT8_PER_TENSOR_CONFIG,
+    LINEAR_OP_PATTERNS,
+)
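
The same pattern/config pairing could be extended to other operators. As a hypothetical sketch (not part of this commit), a conv preset could reuse INT8_PER_TENSOR_CONFIG with its own pattern list:

import torch

from executorch.backends.cortex_m.quantizer.quantization_configs import (
    INT8_PER_TENSOR_CONFIG,
)
from torchao.quantization.pt2e.quantizer import OperatorConfig

# Hypothetical extension: quantize conv2d and conv2d + relu chains with the
# same int8 per-tensor configuration used for linear above.
CONV_OP_PATTERNS = [
    [torch.ops.aten.conv2d.default],
    [torch.ops.aten.conv2d.default, torch.ops.aten.relu.default],
]

INT8_CONV_OPERATOR_CONFIG = OperatorConfig(
    INT8_PER_TENSOR_CONFIG,
    CONV_OP_PATTERNS,
)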
backends/cortex_m/quantizer/quantization_configs.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
+from torchao.quantization.pt2e.quantizer import (
+    DerivedQuantizationSpec,
+    QuantizationConfig,
+    QuantizationSpec,
+)
+
+# ----------------- QUANTIZATION SPEC PRESETS -----------------
+INT8_WEIGHT_PER_TENSOR_QSPEC = QuantizationSpec(
+    dtype=torch.int8,
+    observer_or_fake_quant_ctr=MinMaxObserver,
+    qscheme=torch.per_tensor_symmetric,
+)
+
+INT8_WEIGHT_PER_CHANNEL_QSPEC = QuantizationSpec(
+    dtype=torch.int8,
+    observer_or_fake_quant_ctr=MinMaxObserver,
+    qscheme=torch.per_channel_symmetric,
+)
+
+INT8_ACTIVATION_PER_TENSOR_QSPEC = QuantizationSpec(
+    dtype=torch.int8,
+    observer_or_fake_quant_ctr=HistogramObserver,
+    qscheme=torch.per_tensor_affine,
+)
+
+INT8_ACTIVATION_PER_CHANNEL_QSPEC = QuantizationSpec(
+    dtype=torch.int8,
+    observer_or_fake_quant_ctr=HistogramObserver,
+    qscheme=torch.per_channel_affine,
+)
+
+
+def _derive_bias_qparams_fn(
+    obs_or_fqs,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    if len(obs_or_fqs) != 2:
+        raise ValueError(
+            f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
+        )
+    act_obs_or_fq = obs_or_fqs[0]
+    weight_obs_or_fq = obs_or_fqs[1]
+    act_scale, _ = act_obs_or_fq.calculate_qparams()
+    weight_scale, _ = weight_obs_or_fq.calculate_qparams()
+    return act_scale * weight_scale, torch.full_like(
+        weight_scale, fill_value=0, dtype=torch.int32
+    )
+
+
+def _get_int32_bias_qspec(node):
+    return DerivedQuantizationSpec(
+        derived_from=[(node.args[0], node), (node.args[1], node)],  # type: ignore[list-item]
+        derive_qparams_fn=_derive_bias_qparams_fn,
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max - 1,
+        qscheme=torch.per_tensor_symmetric,
+    )
+
+
+# ----------------- QUANTIZATION CONFIG PRESETS -----------------
+INT8_PER_TENSOR_CONFIG = QuantizationConfig(
+    INT8_ACTIVATION_PER_TENSOR_QSPEC,
+    INT8_ACTIVATION_PER_TENSOR_QSPEC,
+    INT8_WEIGHT_PER_TENSOR_QSPEC,
+    _get_int32_bias_qspec,
+)
+
+
+INT8_PER_CHANNEL_CONFIG = QuantizationConfig(
+    INT8_ACTIVATION_PER_CHANNEL_QSPEC,
+    INT8_ACTIVATION_PER_CHANNEL_QSPEC,
+    INT8_WEIGHT_PER_CHANNEL_QSPEC,
+    _get_int32_bias_qspec,
+)
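
To make the derived-bias rule above concrete, a small worked example with illustrative values: the int32 bias scale is the product of the activation and weight scales, and the bias zero point is fixed to 0.

import torch

# Illustrative scales, as calculate_qparams() might return them.
act_scale = torch.tensor([0.02])
weight_scale = torch.tensor([0.005])

# _derive_bias_qparams_fn: bias scale = act_scale * weight_scale, zero point = 0.
bias_scale = act_scale * weight_scale                       # tensor([1.0000e-04])
bias_zero_point = torch.full_like(weight_scale, 0, dtype=torch.int32)

# Quantizing an fp32 bias value with those parameters:
fp_bias = torch.tensor([0.37])
q_bias = torch.round(fp_bias / bias_scale).to(torch.int32)  # tensor([3700])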
Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Callable, List, Optional
+
+import torch
+
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
+
+from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
+from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager
+from executorch.backends.cortex_m.quantizer.operator_configs import (
+    INT8_BINARY_OPS_OPERATOR_CONFIG,
+    INT8_LINEAR_OPERATOR_CONFIG,
+)
+from torch._ops import OpOverload
+from torch.fx import GraphModule, Node
+from torchao.quantization.pt2e.quantizer import (
+    ComposableQuantizer,
+    QuantizationAnnotation,
+    Quantizer,
+)
+from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
+
+
+class CortexMQuantizer(ComposableQuantizer):
+
+    def broadcasting_filter(self, node: Optional[Node]) -> bool:
+        """
+        Filter function to exclude nodes that perform broadcasting.
+        """
+        if node is None:
+            return False
+        if node.target not in [torch.ops.aten.add.Tensor]:
+            return False
+
+        if len(node.all_input_nodes) == 2:
+            t1 = get_first_fake_tensor(node.all_input_nodes[0])
+            t2 = get_first_fake_tensor(node.all_input_nodes[1])
+            return t1.shape != t2.shape
+
+        return False
+
+    def __init__(self) -> None:
+        quantizers: List[OperatorConfigQuantizer] = [
+            OperatorConfigQuantizer(
+                INT8_BINARY_OPS_OPERATOR_CONFIG, filter_fn=self.broadcasting_filter
+            ),
+            OperatorConfigQuantizer(INT8_LINEAR_OPERATOR_CONFIG),
+        ]
+        super().__init__(quantizers)
+
+    def validate(self, model: GraphModule) -> bool:
+        return True
+
+    def transform_for_annotation(self, model: GraphModule) -> GraphModule:
+        pass_manager = CortexMPassManager(None)
+        return pass_manager.transform_for_annotation(model)
+
+
+class OperatorConfigQuantizer(Quantizer):
+    """
+    Quantizes a graph according to an OperatorConfig.
+
+    Args:
+        operator_config (OperatorConfig): The operator config to use for quantization.
+        filter_fn (Callable): Negative filter function. If it returns True on any node in the pattern, the pattern is
+            skipped. Used to match for example particular targets or modules.
+    """
+
+    def __init__(
+        self,
+        operator_config: QuantizationConfig,
+        filter_fn: Callable[[Node], bool] = lambda node: False,
+    ) -> None:
+        self.operator_config = operator_config
+        self.filter_fn = filter_fn
+
+    def check_node(self, node: Optional[Node], target: str) -> bool:
+        """
+        Return true if the node is a valid match for the given target.
+        """
+        if node is None:
+            return False
+        if not node.target == target:
+            return False
+        if node.meta.get("quantizer_matched", False):
+            return False
+        if self.filter_fn(node):
+            return False
+
+        return True
+
+    def check_pattern(
+        self, node: Optional[Node], pattern: List[OpOverload]
+    ) -> Optional[List[Node]]:
+        """
+        Returns the matched nodes if the given node matches the given pattern, otherwise None.
+        """
+        match: List[Node] = []
+        node = list(node.users)[0] if node and len(node.users) > 0 else None
+
+        for pattern_target in pattern:
+            if self.check_node(node, pattern_target):
+                match.append(node)
+                node = list(node.users)[0] if len(node.users) > 0 else None
+            else:
+                return None
+
+        return match
+
+    def match_patterns(
+        self, model: GraphModule, patterns: List[List[str]]
+    ) -> List[List[Node]]:
+        """
+        Match all given patterns in the graph and return list of matches.
+        Each node can only be part of one match, larger patterns are prioritized.
+        Currently only linear patterns (single chain) are supported.
+        """
+        patterns.sort(key=len, reverse=True)
+        matches: List[List[Node]] = []
+        for pattern in patterns:
+            for node in model.graph.nodes:
+                potential_match = self.check_pattern(node, pattern)
+                if potential_match:
+                    matches.append(potential_match)
+                    for node in potential_match:
+                        node.meta["quantizer_matched"] = True
+
+        return matches
+
+    def is_parameter(self, node: Node, model: GraphModule) -> bool:
+        """Returns True if the given node is a parameter of the model."""
+        try:
+            _ = model.get_parameter(node.target)
+            return True
+        except Exception:
+            return False
+
+    def is_weight(self, node: Node, params: List[Node], model: GraphModule) -> bool:
+        """Returns True if node is the first parameter of the given parameters"""
+        return len(params) > 0 and node == params[0]
+
+    def is_bias(self, node: Node, params: List[Node], model: GraphModule) -> bool:
+        """Returns True if node is the second parameter of the given parameters"""
+        return len(params) == 2 and node == params[1]
+
+    def annotate_match(
+        self, match: List[Node], config: QuantizationConfig, model: GraphModule
+    ) -> None:
+        """
+        Annotates a matched pattern according to the given quantization config. The
+        following assumptions are made:
+
+        - All operators have either no parameters, only weights, or weights and biases
+        - Tensors which are the first parameter of an operator are annotated as weights
+        - Tensors which are the second parameter of an operator are annotated as biases
+        - All other tensors going into the matched pattern are annotated as input activations.
+        - All other outputs coming out of the matched pattern are annotated as output activations.
+
+        """
+        for node in match:
+            input_qspec_map = {}
+            output_qspec = None
+
+            params = [n for n in node.all_input_nodes if self.is_parameter(n, model)]
+            # Check that the assumptions on number of parameters hold to avoid silent errors
+            assert (
+                0 <= len(params) <= 2
+            ), f"{self.__class__.__name__} expected 0 params, 1 params (weight) or 2 params (weight, bias), but got {len(params)} for node {node}."
+
+            for input_node in node.all_input_nodes:
+                if self.is_weight(input_node, params, model):
+                    input_qspec_map[input_node] = config.weight if config else None
+                elif self.is_bias(input_node, params, model):
+                    # Bias qspec is derived from input + weight qspecs
+                    input_qspec_map[input_node] = config.bias(node) if config else None
+                elif input_node not in match:
+                    input_qspec_map[input_node] = (
+                        config.input_activation if config else None
+                    )
+
+            if all(node not in match for node in node.users):
+                output_qspec = config.output_activation if config else None
+
+            node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
+                input_qspec_map, output_qspec
+            )
+
+    def annotate(self, model: GraphModule) -> None:
+        matches = self.match_patterns(model, self.operator_config.operators)
+        for match in matches:
+            self.annotate_match(match, self.operator_config.config, model)
+
+    def validate(self, model: GraphModule) -> bool:
+        return True
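
The filter_fn hook is the main extension point of OperatorConfigQuantizer. As a hypothetical sketch (not part of this commit), a filter could keep everything under a submodule named "head" in floating point by rejecting its nodes; this assumes the standard FX nn_module_stack metadata is present on the nodes.

from torch.fx import Node

from executorch.backends.cortex_m.quantizer.operator_configs import (
    INT8_LINEAR_OPERATOR_CONFIG,
)
# OperatorConfigQuantizer is the class added in the diff above; its module path
# is not shown in this commit, so the import is omitted here.

# Hypothetical negative filter: return True (skip) for nodes originating from a
# submodule whose qualified name contains "head".
def skip_head_filter(node: Node) -> bool:
    module_stack = node.meta.get("nn_module_stack", {})
    return any("head" in qualified_name for qualified_name, _ in module_stack.values())

linear_quantizer = OperatorConfigQuantizer(
    INT8_LINEAR_OPERATOR_CONFIG, filter_fn=skip_head_filter
)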
