Commit c9ae6d4

Github Executorch committed
Summary: MV2 CortexM PassManager changes for Alif E8

Test Plan:

python3 -m examples.arm.aot_arm_compiler -m mv2 --target=cortex-m --quantize --enable_qdq_fusion_pass --intermediates=./mv2_intermediates --output=./mv2_cortex_m.pte
cat ./mv2_intermediates/delegation_info.txt

Delegation info:
Total delegated subgraphs: 0
Number of delegated nodes: 0
Number of non-delegated nodes: 72

Delegation table:
╒════╤═════════════════════════════════════════════╤═══════════════════════════════════╤═══════════════════════════════════════╕
│    │ op_type                                     │ occurrences_in_delegated_graphs   │ occurrences_in_non_delegated_graphs   │
╞════╪═════════════════════════════════════════════╪═══════════════════════════════════╪═══════════════════════════════════════╡
│  0 │ aten_as_strided_copy_default                │ 0                                 │ 1                                     │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  1 │ aten_mean_dim                               │ 0                                 │ 1                                     │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  2 │ aten_view_copy_default                      │ 0                                 │ 1                                     │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  3 │ cortex_m_dequantize_per_tensor_default      │ 0                                 │ 2                                     │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  4 │ cortex_m_quantize_per_tensor_default        │ 0                                 │ 2                                     │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  5 │ cortex_m_quantized_add_default              │ 0                                 │ 10                                    │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  6 │ cortex_m_quantized_conv2d_default           │ 0                                 │ 35                                    │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  7 │ cortex_m_quantized_depthwise_conv2d_default │ 0                                 │ 17                                    │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  8 │ cortex_m_quantized_linear_default           │ 0                                 │ 1                                     │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  9 │ cortex_m_transpose_default                  │ 0                                 │ 1                                     │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 10 │ dim_order_ops__clone_dim_order_default      │ 0                                 │ 1                                     │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 11 │ Total                                       │ 0                                 │ 72                                    │
╘════╧═════════════════════════════════════════════╧═══════════════════════════════════╧═══════════════════════════════════════╛

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent ec4c462 commit c9ae6d4

5 files changed: +261 -14 lines changed

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 65 additions & 0 deletions

@@ -33,6 +33,15 @@
 from torch.fx import GraphModule, Node


+# Passthrough ops that preserve quantization parameters from input to output.
+# These ops should be foldable even without explicit annotation metadata.
+PASSTHROUGH_OPS = {
+    exir_ops.edge.aten.hardtanh.default,
+    exir_ops.edge.aten.relu.default,
+    exir_ops.edge.aten.clamp.default,
+}
+
+
 def _get_special_dtype(qspec: QuantArgs) -> TosaSpecialDtype | None:
     if qspec.dtype == torch.int8:
         if qspec.qmax == 7 and qspec.qmin == -7:
@@ -248,6 +257,26 @@ def _handle_control_flow_node(self, node: Node, graph_module: GraphModule):
             submodule.graph.erase_node(node_to_remove)
         return

+    @staticmethod
+    def _has_dq_input_and_q_output(node: Node) -> bool:
+        """
+        Check if a node has dequantize input(s) and quantize output(s).
+        This indicates the node is part of a quantized computation path.
+        """
+        # Check if any input is from a dequantize op
+        has_dq_input = any(
+            isinstance(arg, Node) and arg.target in DQ_OPS
+            for arg in node.args
+            if isinstance(arg, Node)
+        )
+
+        # Check if any output goes to a quantize op
+        has_q_output = any(
+            user.target in Q_OPS
+            for user in node.users
+        )
+        return has_dq_input and has_q_output
+
     @staticmethod
     def is_foldable(node: Node) -> bool:
         if node.op != "call_function":
@@ -263,6 +292,13 @@ def is_foldable(node: Node) -> bool:
         ):
             return True

+        # Passthrough ops (hardtanh, relu, clamp) that have dq inputs and q outputs
+        # should be foldable even without explicit annotation. These ops preserve
+        # quantization parameters and are common in quantized models like MobileNetV2.
+        if node.target in PASSTHROUGH_OPS:
+            if FoldAndAnnotateQParamsPass._has_dq_input_and_q_output(node):
+                return True
+
         # We should not fold q-dq nodes into non-quantized nodes.
         if not (
             ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {})
@@ -335,6 +371,35 @@ def call(self, graph_module: GraphModule) -> PassResult:  # noqa: C901
         ):
             self._handle_control_flow_node(n, graph_module)

+        # Second pass: propagate qparams through passthrough ops.
+        # For ops like hardtanh that share qparams with their input, we need to:
+        # 1. Copy output_qparams from the passthrough op to its input node
+        # 2. Set input_qparams on the passthrough op
+        for n in graph_module.graph.nodes:
+            n = cast(Node, n)
+            if n.target not in PASSTHROUGH_OPS:
+                continue
+
+            # Check if this passthrough op has output_qparams but missing input_qparams
+            has_output = "output_qparams" in n.meta and len(n.meta.get("output_qparams", {})) > 0
+            has_input = "input_qparams" in n.meta and len(n.meta.get("input_qparams", {})) > 0
+
+            if not has_output or has_input:
+                continue
+
+            # Get the input node
+            if len(n.args) == 0 or not isinstance(n.args[0], Node):
+                continue
+
+            input_node = n.args[0]
+
+            # Propagate: for passthrough ops, output qparams equal input qparams
+            if "output_qparams" not in input_node.meta:
+                input_node.meta["output_qparams"] = n.meta["output_qparams"]
+
+            # Set input_qparams from output_qparams (same for passthrough ops)
+            n.meta["input_qparams"] = {0: n.meta["output_qparams"][0]}
+
         # retrace the graph to update the fake tensor types
         graph_module = super().call(graph_module).graph_module
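
For reference, a standalone sketch of what this second pass does, written against plain torch.fx. The QuantArgs namedtuple and the qparam values below are illustrative stand-ins, not the Arm backend's classes, and the sketch omits the actual q/dq folding:

from collections import namedtuple

import torch
import torch.fx

# Illustrative stand-in for the backend's QuantArgs; values are made up.
QuantArgs = namedtuple("QuantArgs", ["scale", "zp", "qmin", "qmax", "dtype"])

PASSTHROUGH = {torch.relu}

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1.0)

gm = torch.fx.symbolic_trace(M())

# Pretend annotation left the relu node with output qparams only.
for n in gm.graph.nodes:
    if n.target in PASSTHROUGH:
        n.meta["output_qparams"] = {0: QuantArgs(0.02, 0, -128, 127, torch.int8)}

# Mirror of the second pass: push the qparams onto the producer node and
# record them as the passthrough op's own input qparams.
for n in gm.graph.nodes:
    if n.target not in PASSTHROUGH or not n.meta.get("output_qparams"):
        continue
    if n.meta.get("input_qparams") or not isinstance(n.args[0], torch.fx.Node):
        continue
    producer = n.args[0]
    producer.meta.setdefault("output_qparams", n.meta["output_qparams"])
    n.meta["input_qparams"] = {0: n.meta["output_qparams"][0]}

relu = next(n for n in gm.graph.nodes if n.target is torch.relu)
print(relu.meta["input_qparams"], relu.args[0].meta["output_qparams"])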

backends/cortex_m/passes/convert_to_cortex_m_pass.py

Lines changed: 44 additions & 0 deletions

@@ -69,6 +69,45 @@ def _get_batch_size_from_conv(self, conv_node: torch.fx.Node):
             pass
         return None

+    def _get_addmm_replacement(self, node):
+        """
+        Handle aten.addmm which has signature: addmm(bias, input, weight)
+        This is the decomposed form of aten.linear.
+        """
+        bias = node.args[0]
+        input_tensor = node.args[1]
+        weights = node.args[2]
+
+        if "input_qparams" not in node.meta or "output_qparams" not in node.meta:
+            return None
+
+        input_qp = node.meta["input_qparams"].get(0)
+        weight_qp = node.meta["input_qparams"].get(1)
+        output_qp = node.meta["output_qparams"].get(0)
+
+        if not input_qp or not weight_qp or not output_qp:
+            return None
+
+        quantized_multiplier, quantized_shift = quantize_multiplier_aot(
+            (input_qp.scale * weight_qp.scale) / output_qp.scale
+        )
+
+        args = (
+            input_tensor,
+            weights,
+            bias,
+            None,
+            input_qp.zp,
+            weight_qp.zp,
+            output_qp.zp,
+            [quantized_multiplier],
+            [quantized_shift],
+            input_qp.qmax,
+            input_qp.qmin,
+        )
+
+        return exir_ops.edge.cortex_m.quantized_linear.default, args
+
     def _get_linear_replacement(self, node):
         """
         Let
@@ -386,6 +425,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
             match node.target:
                 case exir_ops.edge.aten.linear.default:
                     op, args = self._get_linear_replacement(node)
+                case exir_ops.edge.aten.addmm.default:
+                    result = self._get_addmm_replacement(node)
+                    if result is None:
+                        continue
+                    op, args = result
                 case exir_ops.edge.aten.convolution.default:
                     # Check if it's transposed convolution (arg index 6)
                     transposed = node.args[6] if len(node.args) > 6 else False
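
The requantization scale (input_scale * weight_scale / output_scale) is folded ahead of time into an integer multiplier and shift, which is what quantize_multiplier_aot produces. Below is a hedged sketch of the standard CMSIS-NN-style decomposition such a helper performs, not necessarily this backend's exact implementation:

import math

def quantize_multiplier(real_multiplier: float):
    """Split real_multiplier into (multiplier, shift) such that
    real_multiplier ~= multiplier * 2**(shift - 31), with multiplier int32."""
    if real_multiplier == 0.0:
        return 0, 0
    mantissa, exponent = math.frexp(real_multiplier)  # mantissa in [0.5, 1)
    multiplier = round(mantissa * (1 << 31))
    if multiplier == (1 << 31):  # rounding pushed the mantissa to 1.0
        multiplier //= 2
        exponent += 1
    return multiplier, exponent

# Example: input scale 0.05, weight scale 0.01, output scale 0.1.
m, s = quantize_multiplier(0.05 * 0.01 / 0.1)
print(m, s)  # 1374389535 -7, since 1374389535 * 2**(-7 - 31) ~= 0.005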

backends/cortex_m/passes/cortex_m_pass_manager.py

Lines changed: 16 additions & 2 deletions

@@ -13,6 +13,9 @@
 from executorch.backends.transforms.replace_scalar_with_tensor import (
     ReplaceScalarWithTensorArgPass,
 )
+from executorch.backends.arm._passes.decompose_adaptive_avg_pool2d_pass import (
+    DecomposeAdaptiveAvgPool2dPass,
+)
 from executorch.exir.pass_base import ExportPass
 from executorch.exir.pass_manager import PassManager
 from executorch.exir.program._program import _transform
@@ -33,6 +36,7 @@ class CortexMPassManager(PassManager):
         ReplaceScalarWithTensorArgPass,
         ReplaceQuantNodesPass,
         ActivationFusionPass,
+        DecomposeAdaptiveAvgPool2dPass,
         DecomposeHardswishPass,
         QuantizedOpFusionPass,
         ConvertToCortexMPass,
@@ -44,12 +48,22 @@ class CortexMPassManager(PassManager):
         ClampHardswishPass,
     ]

-    def __init__(self, exported_program, passes=None):
+    def __init__(self, exported_program, passes=None, skip_passes=None):
+        """
+        Initialize CortexMPassManager.
+
+        Args:
+            exported_program: The ExportedProgram to transform.
+            passes: Optional custom pass list. Uses default pass_list if None.
+            skip_passes: Optional list of pass classes to skip.
+        """
         self.exported_program = exported_program
         if passes is not None:
             self.passes = passes
         else:
-            self.passes = self.pass_list
+            self.passes = list(self.pass_list)
+        if skip_passes:
+            self.passes = [p for p in self.passes if p not in skip_passes]

     def transform_for_annotation(self, model):
         passes = self.pass_list_transform_for_annotation
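
A minimal usage sketch of the new skip_passes hook. Here ep stands in for an already-exported program; how the manager is subsequently run depends on the surrounding flow. Note that copying via list(self.pass_list) means filtering mutates only the instance's list, never the shared class attribute:

from executorch.backends.arm._passes.decompose_adaptive_avg_pool2d_pass import (
    DecomposeAdaptiveAvgPool2dPass,
)
from executorch.backends.cortex_m.passes.cortex_m_pass_manager import (
    CortexMPassManager,
)

# ep: an ExportedProgram produced earlier (e.g. via torch.export / to_edge).
pm = CortexMPassManager(ep, skip_passes=[DecomposeAdaptiveAvgPool2dPass])

assert DecomposeAdaptiveAvgPool2dPass not in pm.passes  # filtered instance copy
assert DecomposeAdaptiveAvgPool2dPass in CortexMPassManager.pass_list  # class default intact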

backends/cortex_m/quantizer/quantizer.py

Lines changed: 5 additions & 0 deletions

@@ -448,6 +448,11 @@ class SharedQspecQuantizer(Quantizer):
         torch.ops.aten._unsafe_view.default,
         torch.ops.aten.unflatten.int,
         torch.ops.aten.flatten.using_ints,
+        # Additional passthrough ops for MobileNetV2 and similar architectures
+        torch.ops.aten.hardtanh.default,
+        torch.ops.aten.hardtanh_.default,
+        torch.ops.aten.max_pool2d.default,
+        torch.ops.aten.dropout.default,
     ]

     def __init__(self, targets: Optional[List[OpOverload]] = None) -> None:
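
hardtanh can share its input's quantization spec because clamping commutes with affine quantization when the clamp bounds land on representable values; hardtanh(0, 6) is the ReLU6 used throughout MobileNetV2. A small numeric check of that claim (the scale and zero point here are made-up illustration values):

import torch

scale, zp = 6.0 / 255.0, -128  # int8 qparams covering [0, 6]

x = torch.linspace(-2.0, 8.0, steps=11)

# Path A: hardtanh in float, then quantize.
a = torch.clamp(
    torch.round(torch.nn.functional.hardtanh(x, 0.0, 6.0) / scale) + zp, -128, 127
)

# Path B: quantize with the same qparams, then clamp in the integer domain.
q = torch.clamp(torch.round(x / scale) + zp, -128, 127)
b = torch.clamp(q, zp, zp + round(6.0 / scale))

print(torch.equal(a, b))  # True: identical integer values, so qparams pass through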
