
Commit cf63bbc

Add Move activation before concat pass, Concat cluster quantization
1 parent 72333b0 commit cf63bbc

10 files changed: +1283 −46 lines
Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch

from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec

from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult


class MoveActivationBeforeConcat(PassBase):
    """Move some operators around in the following pattern.

    This is a common pattern that emerges from the conversion of separable convolutions.
    The optimization only pays off together with joint quantization of the compute nodes
    and activations; on its own it is not beneficial.

           │                   │                     │                   │
    ┌──────▼──────┐     ┌──────▼──────┐       ┌──────▼──────┐     ┌──────▼──────┐
    │ aten.conv2d │ ... │ aten.conv2d │       │ aten.conv2d │ ... │ aten.conv2d │
    └──────┬──────┘     └──────┬──────┘       └──────┬──────┘     └──────┬──────┘
           └───────┐     ┌─────┘                     │                   │
                ┌──▼─────▼─┐   replace with    ┌─────▼─────┐       ┌─────▼─────┐
                │ aten.cat │  ──────────────►  │ aten.relu │  ...  │ aten.relu │
                └────┬─────┘                   └─────┬─────┘       └─────┬─────┘
                     │                               └───────┐     ┌─────┘
               ┌─────▼─────┐                              ┌──▼─────▼─┐
               │ aten.relu │                              │ aten.cat │
               └─────┬─────┘                              └────┬─────┘
                     │                                         │
    """

    def __init__(self, neutron_target_spec: NeutronTargetSpec):
        self.neutron_target_spec = neutron_target_spec

    def call(self, module: GraphModule) -> PassResult:
        def _is_concat(node_: Node) -> bool:
            return (
                node_.op == "call_function"
                and node_.target == torch.ops.aten.cat.default
            )

        made_changes = False

        for node in module.graph.nodes:
            if not _is_concat(node):
                continue  # Not a cat node.

            cat_node = node
            activation = next(iter(cat_node.users))

            # Check that all cat input nodes are Conv 2D or Linear, and that the cat is their only user.
            if not all(
                self.neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten(
                    input_node
                )
                and len(input_node.users) == 1
                for input_node in cat_node.all_input_nodes
            ):
                continue

            # Check that the following activation is supported on Neutron as a fused activation.
            if not (
                len(cat_node.users) == 1
                and self.neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten(
                    activation
                )
            ):
                continue

            # Loop over all cat input nodes and insert a new activation after each of them.
            for input_node in cat_node.all_input_nodes:
                with module.graph.inserting_after(input_node):
                    new_activation = module.graph.call_function(
                        activation.target,
                        args=(*activation.args[1:],),
                        kwargs=activation.kwargs,
                    )

                new_activation.meta["source_fn_stack"] = [
                    (
                        new_activation.name,
                        activation.meta["source_fn_stack"][-1][-1],
                    )
                ]
                new_activation.meta["val"] = input_node.meta["val"]

                # Replace the uses of the input node with the new activation node.
                input_node.replace_all_uses_with(new_activation)
                new_activation.args = (input_node, *new_activation.args)

            # Replace the uses of the original activation node with the cat node.
            activation.replace_all_uses_with(cat_node)

            module.graph.erase_node(activation)

            made_changes = True

        return PassResult(module, made_changes)
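For illustration, a minimal sketch of a module that produces the conv2d → cat → activation pattern this pass rewrites; the module name, channel counts, and shapes below are arbitrary examples and are not taken from this commit:

import torch
from torch import nn


class BranchedConvCat(nn.Module):
    """Two parallel convolutions whose outputs are concatenated and only then activated."""

    def __init__(self):
        super().__init__()
        self.conv_a = nn.Conv2d(8, 8, kernel_size=3, padding=1)
        self.conv_b = nn.Conv2d(8, 8, kernel_size=3, padding=1)

    def forward(self, x):
        # Before the pass: conv2d -> cat -> relu. After the pass, the relu is duplicated onto
        # each branch, so every conv2d + relu pair can be quantized (and later fused) jointly.
        return torch.relu(torch.cat([self.conv_a(x), self.conv_b(x)], dim=1))


example_input = (torch.randn(1, 8, 16, 16),)
exported = torch.export.export(BranchedConvCat().eval(), example_input, strict=True)
print(exported.graph_module.graph)  # shows aten.conv2d -> aten.cat -> aten.relu before the pass runs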

backends/nxp/aten_passes/neutron_aten_pass_manager.py

Lines changed: 8 additions & 1 deletion
@@ -16,6 +16,9 @@
 from executorch.backends.nxp.aten_passes.fuse_linear_and_add_pass import (
     FuseLinearAndAddPass,
 )
+from executorch.backends.nxp.aten_passes.move_activation_before_concat import (
+    MoveActivationBeforeConcat,
+)
 from executorch.backends.nxp.aten_passes.remove_nodes_with_known_outputs import (
     RemoveNodesWithKnownOutputs,
 )
@@ -25,6 +28,7 @@
 from executorch.backends.nxp.aten_passes.split_gru_based_on_num_layers import (
     SplitGRUBasedOnNumLayers,
 )
+from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
 from executorch.exir.pass_manager import PassManager
 from torch import nn
 from torch.fx.passes.infra.pass_base import PassResult
@@ -34,14 +38,17 @@

 class NeutronAtenPassManager(PassManager):

-    def __init__(self, passes: list[PassType] = None):
+    def __init__(
+        self, neutron_target_spec: NeutronTargetSpec, passes: list[PassType] = None
+    ):
         passes: list[PassType] = passes or [
             FuseBatchNormWithConvPass(),
             FuseBatchNormWithLinearPass(),
             SplitGroupConvolution(),
             SplitGRUBasedOnNumLayers(),
             RemoveNodesWithKnownOutputs(),
             FuseLinearAndAddPass(),
+            MoveActivationBeforeConcat(neutron_target_spec),
         ]

         super().__init__(passes)
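A minimal usage sketch for the updated constructor, mirroring the test changes further down in this commit; `neutron_target_spec` stands for an already-constructed NeutronTargetSpec instance, whose creation is target-specific and not shown in this diff:

from copy import deepcopy

import torch
from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
    NeutronAtenPassManager,
)

# `module` and `example_input` are placeholders for any eager model and its sample inputs.
program = torch.export.export(module, example_input, strict=True)

# The target spec is now a required argument so that MoveActivationBeforeConcat can query
# which activations and compute nodes Neutron is able to fuse.
pm = NeutronAtenPassManager(neutron_target_spec)
graph_module_out = pm(deepcopy(program.module())).graph_module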

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 9 additions & 1 deletion
@@ -12,6 +12,7 @@
 from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
 from executorch.backends.nxp.quantizer.patterns import (
     AbsPattern,
+    ActivationsConcatClusterPattern,
     AdaptiveAvgPoolPattern,
     AddmmPattern,
     AddTensorPattern,
@@ -225,13 +226,16 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec):
         self.op_to_applied_quantizer = {
             pt: False for q in self.quantizers for pt in q.pattern.partition_types()
         }
+        self.cluster_quantizers = [
+            NeutronAtenQuantizer(ActivationsConcatClusterPattern(self), static_qconfig)
+        ]

     def transform_for_annotation(
         self, model: torch.fx.GraphModule
     ) -> torch.fx.GraphModule:
         model.graph.eliminate_dead_code()  # Remove dead code to simplify the graph for the passes.

-        model = NeutronAtenPassManager()(model).graph_module
+        model = NeutronAtenPassManager(self.neutron_target_spec)(model).graph_module

         model.graph.eliminate_dead_code()  # Remove dead code again, in case it was created by the passes.

@@ -240,6 +244,10 @@ def transform_for_annotation(
     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
         self._annotate_inputs(model)

+        # Annotate node clusters in the model.
+        for cluster_quantizer in self.cluster_quantizers:
+            cluster_quantizer.annotate(model)
+
         nodes = list(model.graph.nodes)
         for node in nodes:
             if (
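For context, a hedged sketch of the PT2E flow in which the new cluster annotation runs; the prepare/convert entry points depend on the installed torchao version, and `model`, `example_input`, and `neutron_target_spec` are placeholders rather than values defined by this commit:

import torch
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer

quantizer = NeutronQuantizer(neutron_target_spec)

exported = torch.export.export(model.eval(), example_input, strict=True)
# prepare_pt2e() invokes transform_for_annotation() (which now passes the target spec to the
# pass manager) and annotate() (which now annotates activation/concat clusters first).
prepared = prepare_pt2e(exported.module(), quantizer)
prepared(*example_input)  # calibration run
quantized = convert_pt2e(prepared)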

backends/nxp/quantizer/patterns.py

Lines changed: 146 additions & 1 deletion
@@ -13,13 +13,15 @@
 from executorch.backends.nxp.quantizer.utils import get_bias_qparams
 from torch import fx
 from torch._ops import OpOverload
+from torch.fx import Node
 from torchao.quantization.pt2e import PerChannelMinMaxObserver
 from torchao.quantization.pt2e.quantizer import (
     DerivedQuantizationSpec,
     FixedQParamsQuantizationSpec,
     QuantizationSpec,
     SharedQuantizationSpec,
 )
+
 from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY


@@ -199,7 +201,6 @@ def partition_types(self) -> list[OpOverload]:
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
-        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         addmm_node = fused_partition[0].nodes[-1]

         bias_qspec = DerivedQuantizationSpec(
@@ -745,3 +746,147 @@ def get_anchors(
         return get_anchors_for_fixed_quant_specs(
             fused_partition, scale=1.0 / 128.0, zero_point=0
         )
+
+
+class ActivationsConcatClusterPattern(QuantizationPattern):
+    """
+    Quantizer for the activations/concat cluster pattern.
+
+    The quantizer matches a pattern where a concat node is preceded by activation nodes, which are in turn
+    preceded by Conv 2D or Linear nodes. All activation nodes must share the same quantization parameters,
+    and only activations that Neutron can fuse into the preceding compute node are allowed. This cluster is
+    usually produced by the MoveActivationBeforeConcat pass. Cluster schema:
+
+           │                   │
+    ┌──────▼──────┐     ┌──────▼──────┐
+    │ aten.conv2d │ ... │ aten.conv2d │
+    └──────┬──────┘     └──────┬──────┘
+           │                   │
+     ┌─────▼─────┐       ┌─────▼─────┐
+     │ aten.relu │  ...  │ aten.relu │
+     └─────┬─────┘       └─────┬─────┘
+           └───────┐     ┌─────┘
+                ┌──▼─────▼─┐
+                │ aten.cat │
+                └────┬─────┘
+
+    """
+
+    def __init__(self, neutron_quantizer):
+        self.neutron_quantizer = neutron_quantizer
+        self.neutron_target_info = (
+            self.neutron_quantizer.neutron_target_spec.neutron_target_info
+        )
+
+    @staticmethod
+    def _all_activations_are_equal(activations: list[Node]) -> bool:
+        first_input_node = activations[0]
+        hardtanh_t = [
+            torch.ops.aten.hardtanh.default,
+            torch.ops.aten.hardtanh_.default,
+        ]
+        relu_t = [
+            torch.ops.aten.relu.default,
+            torch.ops.aten.relu_.default,
+        ]
+        tanh_t = [
+            torch.ops.aten.tanh.default,
+            torch.ops.aten.tanh_.default,
+        ]
+
+        def _activations_are_equal(activation1: Node, activation2: Node) -> bool:
+            if (  # Targets are equal, treating in-place variants as equivalent.
+                activation1.target in hardtanh_t
+                and activation2.target in hardtanh_t
+                or activation1.target in relu_t
+                and activation2.target in relu_t
+                or activation1.target in tanh_t
+                and activation2.target in tanh_t
+                or activation1.target == torch.ops.aten.sigmoid.default
+                and activation2.target == torch.ops.aten.sigmoid.default
+            ):
+                return True
+            elif (  # Hardtanh with min_val 0 and max_val 'inf' is equal to Relu.
+                activation1.target in hardtanh_t
+                and activation1.args[1:] == (0.0, float("inf"))
+                and activation2.target in relu_t
+                or activation1.target in relu_t
+                and activation2.target in hardtanh_t
+                and activation2.args[1:] == (0.0, float("inf"))
+            ):
+                return True
+            else:
+                return False
+
+        return all(
+            _activations_are_equal(activation, first_input_node)
+            for activation in activations
+        )
+
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.cat.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
+    ) -> PartitionAnchors | None:
+        cat_node = fused_partition[0].nodes[-1]
+
+        # Check that all cat inputs are supported activations.
+        if not all(
+            self.neutron_target_info.is_supported_fused_activation__aten(input_node)
+            for input_node in cat_node.all_input_nodes
+        ):
+            return None
+
+        # Check that all cat inputs are equal activations.
+        if not self._all_activations_are_equal(cat_node.all_input_nodes):
+            return None
+
+        # Check that the compute nodes are Conv 2D or Linear.
+        if not all(
+            self.neutron_target_info.is_fusable_conv_or_linear__aten(compute_node)
+            for input_node in cat_node.all_input_nodes
+            for compute_node in input_node.all_input_nodes
+        ):
+            return None
+
+        # Annotate the compute nodes.
+        for input_node in cat_node.all_input_nodes:
+            for compute_node in input_node.all_input_nodes:
+                if compute_node.target not in self.neutron_quantizer.op_to_quantizer:
+                    return None
+                compute_node_quantizer = self.neutron_quantizer.op_to_quantizer[
+                    compute_node.target
+                ]
+                compute_node_quantizer.annotate(gm)
+                del compute_node.meta["quantization_annotation"].output_qspec
+
+        # Annotate the activations.
+        for input_node in cat_node.all_input_nodes:
+            if input_node.target not in self.neutron_quantizer.op_to_quantizer:
+                return None
+            activation_quantizer = self.neutron_quantizer.op_to_quantizer[
+                input_node.target
+            ]
+            activation_quantizer.annotate(gm)
+            input_node.meta["quantization_annotation"].input_qspec_map = {}
+
+        # Annotate the cat node.
+        inputs = []
+        first_input_node = cat_node.all_input_nodes[0]
+        for idx in range(len(cat_node.all_input_nodes)):
+            inputs.append(
+                (
+                    cat_node,
+                    NodeArgsIdx(0, idx),
+                    SharedQuantizationSpec(first_input_node),
+                )
+            )
+        outputs = [(cat_node, SharedQuantizationSpec(first_input_node))]
+
+        return PartitionAnchors(
+            inputs=inputs,
+            weights=[],
+            biases=[],
+            output=outputs,
+        )
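The Hardtanh/ReLU equivalence that `_activations_are_equal` relies on (hardtanh with min_val 0 and max_val infinity behaves exactly like relu) can be checked directly in eager mode; a small illustrative snippet with arbitrary input values:

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
# Clamping to [0, +inf) keeps every non-negative value and zeroes the rest, which is exactly ReLU.
assert torch.equal(F.hardtanh(x, min_val=0.0, max_val=float("inf")), torch.relu(x))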

backends/nxp/tests/test_batch_norm_fusion.py

Lines changed: 6 additions & 3 deletions
@@ -18,7 +18,10 @@
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.view_copy_converter import (
     ViewCopyConverter,
 )
-from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
+from executorch.backends.nxp.tests.executorch_pipeline import (
+    neutron_target_spec,
+    to_quantized_edge_program,
+)
 from executorch.backends.nxp.tests.executors import OverrideTargetSupportCheck
 from torch import nn

@@ -98,7 +101,7 @@ def test_batch_norm_conv_fusing(bias: bool, input_shape: list[int]):
     program = torch.export.export(module, example_input, strict=True)
     og_module = program.module()

-    pm = NeutronAtenPassManager()
+    pm = NeutronAtenPassManager(neutron_target_spec)
     graph_module_out = pm(deepcopy(program.module())).graph_module

     # Make sure the fusion worked.
@@ -133,7 +136,7 @@ def test_batch_norm_linear_fusing(bias: bool):
     program = torch.export.export(module, example_input, strict=True)
     og_module = program.module()

-    pm = NeutronAtenPassManager()
+    pm = NeutronAtenPassManager(neutron_target_spec)
     graph_module_out = pm(deepcopy(program.module())).graph_module

     # Make sure the fusion worked.
