
Commit d7d85fb

Author: Github Executorch
Summary: MobileNetV2 Fully Lowered to CMSIS-NN
Cortex-M: Enable full MobileNetV2 lowering to CMSIS-NN backend

This PR enables end-to-end export of MobileNetV2 to the CMSIS-NN backend for Cortex-M targets. All quantized operations (conv2d, depthwise conv2d, linear/addmm, activations) are now properly lowered to cortex_m::quantized_* operators, enabling efficient inference on resource-constrained microcontrollers.

Test Plan:

python3 -m examples.arm.aot_arm_compiler -m mv2 --target=cortex-m --quantize --intermediates=./mv2_intermediates --output=./mv2_cortex_m.pte
cat ./mv2_intermediates/delegation_info.txt

Delegation info:
Total delegated subgraphs: 0
Number of delegated nodes: 0
Number of non-delegated nodes: 72

Delegation table:

      op_type                                        occurrences_in_delegated_graphs   occurrences_in_non_delegated_graphs
  0   aten_as_strided_copy_default                                                 0                                     1
  1   aten_mean_dim                                                                0                                     1
  2   aten_view_copy_default                                                       0                                     1
  3   cortex_m_dequantize_per_tensor_default                                       0                                     2
  4   cortex_m_quantize_per_tensor_default                                         0                                     2
  5   cortex_m_quantized_add_default                                               0                                    10
  6   cortex_m_quantized_conv2d_default                                            0                                    35
  7   cortex_m_quantized_depthwise_conv2d_default                                  0                                    17
  8   cortex_m_quantized_linear_default                                            0                                     1
  9   dim_order_ops__clone_dim_order_default                                       0                                     1
 10   Total                                                                        0                                    71

Note: E2E Inference tested on Alif E8 Board

Reviewers:
Subscribers:
Tasks:
Tags:
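The delegation table above is tool output. For readers who want to reproduce the per-operator counts themselves, here is a minimal sketch (not part of this commit) that tallies call_function targets on a torch.fx GraphModule, assuming the exported edge-dialect graph is available in Python; the edge_gm name below is hypothetical.

# Minimal sketch (not from this commit): count operators in an edge-dialect graph,
# mirroring the "op_type" column of the delegation table above.
from collections import Counter

import torch


def count_ops(graph_module: torch.fx.GraphModule) -> Counter:
    # Tally call_function nodes by their target name.
    counts: Counter = Counter()
    for node in graph_module.graph.nodes:
        if node.op == "call_function":
            counts[str(node.target)] += 1
    return counts


# Example (hypothetical variable): for MobileNetV2, `edge_gm` would show entries such
# as the quantized conv2d operator with count 35.
# print(count_ops(edge_gm).most_common())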
1 parent f48a600 commit d7d85fb

File tree

6 files changed: +490 −29 lines


backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 64 additions & 0 deletions
Lines changed: 64 additions & 0 deletions

@@ -33,6 +33,14 @@
 from torch.fx import GraphModule, Node


+# Passthrough ops that preserve quantization parameters from input to output.
+# These ops should be foldable even without explicit annotation metadata.
+PASSTHROUGH_OPS = {
+    exir_ops.edge.aten.hardtanh.default,
+    exir_ops.edge.aten.relu.default,
+    exir_ops.edge.aten.clamp.default,
+}
+
 def _get_special_dtype(qspec: QuantArgs) -> TosaSpecialDtype | None:
     if qspec.dtype == torch.int8:
         if qspec.qmax == 7 and qspec.qmin == -7:
@@ -248,6 +256,26 @@ def _handle_control_flow_node(self, node: Node, graph_module: GraphModule):
                 submodule.graph.erase_node(node_to_remove)
         return

+    @staticmethod
+    def _has_dq_input_and_q_output(node: Node) -> bool:
+        """
+        Check if a node has dequantize input(s) and quantize output(s).
+        This indicates the node is part of a quantized computation path.
+        """
+        # Check if any input is from a dequantize op
+        has_dq_input = any(
+            isinstance(arg, Node) and arg.target in DQ_OPS
+            for arg in node.args
+            if isinstance(arg, Node)
+        )
+
+        # Check if any output goes to a quantize op
+        has_q_output = any(
+            user.target in Q_OPS
+            for user in node.users
+        )
+        return has_dq_input and has_q_output
+
     @staticmethod
     def is_foldable(node: Node) -> bool:
         if node.op != "call_function":
@@ -263,6 +291,13 @@ def is_foldable(node: Node) -> bool:
         ):
             return True

+        # Passthrough ops (hardtanh, relu, clamp) that have dq inputs and q outputs
+        # should be foldable even without explicit annotation. These ops preserve
+        # quantization parameters and are common in quantized models like MobileNetV2.
+        if node.target in PASSTHROUGH_OPS:
+            if FoldAndAnnotateQParamsPass._has_dq_input_and_q_output(node):
+                return True
+
         # We should not fold q-dq nodes into non-quantized nodes.
         if not (
             ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {})
@@ -335,6 +370,35 @@ def call(self, graph_module: GraphModule) -> PassResult:  # noqa: C901
             ):
                 self._handle_control_flow_node(n, graph_module)

+        # Second pass: Propagate qparams through passthrough ops.
+        # For ops like hardtanh that share qparams with their input, we need to:
+        # 1. Copy output_qparams from the passthrough op to its input node
+        # 2. Set input_qparams on the passthrough op
+        for n in graph_module.graph.nodes:
+            n = cast(Node, n)
+            if n.target not in PASSTHROUGH_OPS:
+                continue
+
+            # Check if this passthrough op has output_qparams but missing input_qparams
+            has_output = "output_qparams" in n.meta and len(n.meta.get("output_qparams", {})) > 0
+            has_input = "input_qparams" in n.meta and len(n.meta.get("input_qparams", {})) > 0
+
+            if not has_output or has_input:
+                continue
+
+            # Get the input node
+            if len(n.args) == 0 or not isinstance(n.args[0], Node):
+                continue
+
+            input_node = n.args[0]
+
+            # Propagate: For passthrough ops, output qparams equal input qparams
+            if "output_qparams" not in input_node.meta:
+                input_node.meta["output_qparams"] = n.meta["output_qparams"]
+
+            # Set input_qparams from output_qparams (same for passthrough ops)
+            n.meta["input_qparams"] = {0: n.meta["output_qparams"][0]}
+
         # retrace the graph to update the fake tensor types
         graph_module = super().call(graph_module).graph_module
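The passthrough handling above rests on the fact that relu, hardtanh, and clamp commute with per-tensor quantization: applying the op directly to the int8 values with the same scale and zero-point gives the same result as applying it in float and requantizing. A minimal self-contained illustration of that property for relu (illustrative only, not part of this diff):

# Illustration (not from this commit): relu applied in the quantized domain, i.e.
# clamping at the zero-point, matches quantize(relu(x)) with identical qparams.
import torch

scale, zero_point = 0.05, -3
x = torch.randn(64)


def quantize(t: torch.Tensor) -> torch.Tensor:
    return torch.clamp(torch.round(t / scale) + zero_point, -128, 127)


relu_then_quantize = quantize(torch.relu(x))
quantize_then_relu = torch.clamp(quantize(x), min=zero_point)  # relu on int8 values

assert torch.equal(relu_then_quantize, quantize_then_relu)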

backends/cortex_m/passes/convert_to_cortex_m_pass.py

Lines changed: 114 additions & 1 deletion
@@ -69,9 +69,117 @@ def _get_batch_size_from_conv(self, conv_node: torch.fx.Node):
                 pass
         return None

+    def _get_addmm_replacement(self, node):
+        """
+        Handle aten.addmm (decomposed linear):
+        addmm(bias, input, weight.T) = input @ weight.T + bias
+
+        input_qparams indices for addmm:
+            [0] = bias (int32)
+            [1] = input activation (int8)
+            [2] = weight (int8)
+
+        The weight qparams at index [2] are guaranteed to be present because
+        CortexMQuantizer marks weight nodes as annotated, allowing
+        FoldAndAnnotateQParamsPass to properly fold Q/DQ nodes and populate qparams.
+        """
+        # Validate addmm node structure: addmm(bias, input, weight)
+        if len(node.args) < 3:
+            return None
+
+        bias_node = node.args[0]
+        input_node = node.args[1]
+        weights_node = node.args[2]
+
+        # Validate qparams are present with helpful error messages
+        input_qparams = node.meta.get("input_qparams", {})
+        if 1 not in input_qparams:
+            raise RuntimeError(
+                f"Missing input activation qparams at index 1 for addmm node '{node.name}'. "
+                f"Available input_qparams keys: {list(input_qparams.keys())}. "
+                "Ensure the model is properly quantized and FoldAndAnnotateQParamsPass ran."
+            )
+        if 2 not in input_qparams:
+            raise RuntimeError(
+                f"Missing weight qparams at index 2 for addmm node '{node.name}'. "
+                f"Available input_qparams keys: {list(input_qparams.keys())}. "
+                "Ensure CortexMQuantizer marked weight nodes and PropagateQParamsPass "
+                "propagated qparams through any transpose/permute ops."
+            )
+
+        # Get input activation qparams (index 1, not 0 which is bias!)
+        input_scale = input_qparams[1].scale
+        input_zp = input_qparams[1].zp
+
+        # Get weight qparams (index 2)
+        weight_scale = input_qparams[2].scale
+        weight_zp = input_qparams[2].zp
+
+        # Get output qparams
+        output_scale = node.meta["output_qparams"][0].scale
+        output_zp = node.meta["output_qparams"][0].zp
+        output_min = node.meta["output_qparams"][0].qmin
+        output_max = node.meta["output_qparams"][0].qmax
+
+        # Calculate quantization multiplier and shift
+        quantized_multiplier, quantized_shift = quantize_multiplier_aot(
+            (input_scale * weight_scale) / output_scale
+        )
+
+        # Get the original weight tensor
+        # Trace back through transpose/permute to find the placeholder
+        if weights_node.op == "call_function" and len(weights_node.args) > 0:
+            original_weight_node = weights_node.args[0]
+        else:
+            original_weight_node = weights_node
+
+        weights_tensor = get_param_tensor(self.exported_program, original_weight_node)
+        final_weights = weights_tensor.contiguous()
+
+        # Compute kernel_sum WITHOUT bias (pass None)
+        # Bias is passed separately to the C++ operator
+        kernel_sum_tensor = self._compute_kernel_sum(
+            final_weights, None, -input_zp, -weight_zp
+        )
+
+        # Create placeholders for weights and kernel_sum
+        with node.graph.inserting_after(original_weight_node):
+            weights_placeholder = create_constant_placeholder(
+                self.exported_program,
+                node.graph,
+                node.name + "_weights",
+                InputKind.PARAMETER,
+                final_weights,
+            )
+
+            kernel_sum = create_constant_placeholder(
+                self.exported_program,
+                node.graph,
+                node.name + "_kernel_sum",
+                InputKind.PARAMETER,
+                kernel_sum_tensor,
+            )
+
+        # Build args for cortex_m.quantized_linear
+        args = (
+            input_node,
+            weights_placeholder,
+            bias_node,
+            kernel_sum,
+            -input_zp,
+            -weight_zp,
+            output_zp,
+            [quantized_multiplier],
+            [quantized_shift],
+            output_max,
+            output_min,
+        )
+
+        return exir_ops.edge.cortex_m.quantized_linear.default, args
+
     def _get_linear_replacement(self, node):
         """
-        Let 
+        Let
         - yi be the output activations (y1, ... yn)
         - xj be the input activations (x1, ... xm)
         - wij be the weights (w11, ... wnm)
@@ -386,6 +494,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                 match node.target:
                     case exir_ops.edge.aten.linear.default:
                         op, args = self._get_linear_replacement(node)
+                    case exir_ops.edge.aten.addmm.default:
+                        result = self._get_addmm_replacement(node)
+                        if result is None:
+                            continue
+                        op, args = result
                     case exir_ops.edge.aten.convolution.default:
                         # Check if it's transposed convolution (arg index 6)
                         transposed = node.args[6] if len(node.args) > 6 else False
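For context on the quantize_multiplier_aot((input_scale * weight_scale) / output_scale) call above: CMSIS-NN style kernels represent the effective requantization scale as a fixed-point multiplier plus a power-of-two shift. The sketch below shows one common way to derive that pair; it is an assumption about the general technique, not the actual helper in this repo, whose rounding and sign conventions may differ.

# Hedged sketch (not the repo's quantize_multiplier_aot): split a float requantization
# scale into a Q31 multiplier and a shift, as is typical for CMSIS-NN/TFLite kernels.
import math


def decompose_requant_scale(real_scale: float) -> tuple[int, int]:
    if real_scale == 0.0:
        return 0, 0
    # real_scale = mantissa * 2**exponent, with mantissa in [0.5, 1)
    mantissa, exponent = math.frexp(real_scale)
    multiplier = round(mantissa * (1 << 31))  # Q31 fixed-point mantissa
    if multiplier == (1 << 31):  # rounding pushed the mantissa up to exactly 1.0
        multiplier //= 2
        exponent += 1
    return multiplier, exponent  # positive exponent = left shift, negative = right shift


# Example: input_scale * weight_scale / output_scale = 0.02 * 0.005 / 0.04 = 0.0025,
# which decomposes as roughly 0.64 * 2**31 with a shift of -8.
multiplier, shift = decompose_requant_scale(0.02 * 0.005 / 0.04)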

backends/cortex_m/passes/cortex_m_pass_manager.py

Lines changed: 20 additions & 2 deletions
@@ -14,6 +14,12 @@
 from executorch.backends.transforms.replace_scalar_with_tensor import (
     ReplaceScalarWithTensorArgPass,
 )
+from executorch.backends.arm._passes.decompose_adaptive_avg_pool2d_pass import (
+    DecomposeAdaptiveAvgPool2dPass,
+)
+from executorch.backends.cortex_m.passes.propagate_qparams_pass import (
+    PropagateQParamsPass,
+)
 from executorch.exir.pass_base import ExportPass
 from executorch.exir.pass_manager import PassManager
 from executorch.exir.program._program import _transform
@@ -34,9 +40,11 @@ class CortexMPassManager(PassManager):
         # Run before folding so qparams attach to max_pool2d values, not tuple + getitem.
         RemoveGetItemPass,
         FoldAndAnnotateQParamsPass,
+        PropagateQParamsPass,
         ReplaceScalarWithTensorArgPass,
         ReplaceQuantNodesPass,
         ActivationFusionPass,
+        DecomposeAdaptiveAvgPool2dPass,
         DecomposeHardswishPass,
         QuantizedOpFusionPass,
         ConvertToCortexMPass,
@@ -49,12 +57,22 @@ class CortexMPassManager(PassManager):
         DecomposeMeanPass,
     ]

-    def __init__(self, exported_program, passes=None):
+    def __init__(self, exported_program, passes=None, skip_passes=None):
+        """
+        Initialize CortexMPassManager.
+
+        Args:
+            exported_program: The ExportedProgram to transform.
+            passes: Optional custom pass list. Uses default pass_list if None.
+            skip_passes: Optional list of pass classes to skip.
+        """
         self.exported_program = exported_program
         if passes is not None:
             self.passes = passes
         else:
-            self.passes = self.pass_list
+            self.passes = list(self.pass_list)
+            if skip_passes:
+                self.passes = [p for p in self.passes if p not in skip_passes]

     def transform_for_annotation(self, model):
         passes = self.pass_list_transform_for_annotation
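Illustrative usage of the new skip_passes argument; only the constructor signature, class name, and import paths are taken from the diff above, and the exported_program placeholder is assumed to come from an earlier export/quantization step.

# Sketch: run the default Cortex-M pipeline but skip one pass.
from executorch.backends.arm._passes.decompose_adaptive_avg_pool2d_pass import (
    DecomposeAdaptiveAvgPool2dPass,
)
from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager

# `exported_program` is assumed to be an already exported (and quantized) program.
pass_manager = CortexMPassManager(
    exported_program,
    skip_passes=[DecomposeAdaptiveAvgPool2dPass],
)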
