
Commit 10ce754

MartinPavella authored and StrycekSimon committed
NXP backend: Add pre-processing pass to fuse Linear + Add (pytorch#14112)
### Summary
Add a pre-processing aten-dialect pass that fuses Linear nodes with the Add nodes that follow them. This pass replaces the existing Neutron IR optimization.

### Test plan
Unit tests provided.

cc @robert-kalmar
1 parent c695c9c commit 10ce754

File tree: 6 files changed (+853, -88 lines)

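For context, here is a minimal sketch (not part of the commit; the module name and shapes are illustrative only) of the pattern the new pass targets. The fusion rests on the identity `linear(x, W, b) + alpha * c == linear(x, W, b + alpha * c)`, which holds whenever `c` is a static 1-D tensor of shape `[out_features]`:

```python
import torch


class LinearThenAdd(torch.nn.Module):
    """Toy model whose exported aten graph contains `aten.linear` followed by `aten.add.Tensor`."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)
        # Static 1-D tensor with shape [out_features], so it qualifies as a fusable bias.
        self.extra_bias = torch.nn.Parameter(torch.randn(4))

    def forward(self, x):
        return self.linear(x) + self.extra_bias


exported = torch.export.export(LinearThenAdd().eval(), (torch.randn(2, 8),))
# In recent PyTorch versions the pre-decomposition export keeps `aten.linear.default`,
# so the printed graph shows the `linear` -> `add` pair the pass looks for.
print(exported.graph_module.graph)
```

After the pass runs, the `add` node disappears and the `linear` carries the combined bias `b + alpha * c`.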
backends/nxp/aten_passes/fuse_linear_and_add_pass.py

Lines changed: 204 additions & 0 deletions
@@ -0,0 +1,204 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch

from executorch.backends.nxp.backend.edge_helper import (
    try_get_tensor_constant_from_node,
)
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult


class FuseLinearAndAddPass(PassBase):
    """Replace a sequence of `linear` and `add` nodes in the following pattern by a single `linear` node when possible.

        ┌──────▼──────┐
        │ aten.linear │
        └──────┬──────┘                                 │
               │           replace with          ┌──────▼──────┐
          ┌─────▼────┐     ───────────►          │ aten.linear │
          │ aten.add │                           └──────┬──────┘
          └─────┬────┘
    """

    def _fuse_with_existing_bias(
        self,
        linear_node: Node,
        other_add_input: Node,
        graph_module: GraphModule,
        alpha: float,
    ) -> bool:
        """Fuse the `linear` and `add` nodes provided the `linear` already has a bias.
        The fusion can only be done if both the "biases" have static data, which can be added together to get a
        single bias.

        :return: True, if the nodes were successfully merged. False, otherwise.
        """

        linear_bias = linear_node.args[2]
        if other_add_input.meta["val"].shape != linear_bias.meta["val"].shape:
            # The biases cannot be added together due to their different shapes.
            # Shape broadcasting is not applicable, as the only allowed `linear` bias shape is 1D ([output_features]).
            return False

        bias_data = [
            try_get_tensor_constant_from_node(graph_module, linear_bias),
            try_get_tensor_constant_from_node(graph_module, other_add_input),
        ]
        if any(data is None for data in bias_data):
            return (
                False  # Fusion is not possible because at least 1 bias is not static.
            )

        # Add the bias data together, to obtain the combined bias. Take the `alpha` attribute into account.
        combined_bias = bias_data[0] + bias_data[1] * alpha

        # Create a new node containing the combined bias data.
        combined_bias_name = get_new_attr_name_with_prefix(
            linear_bias.name + "combined"
        )(graph_module)
        _assign_attr(
            torch.nn.Parameter(combined_bias),
            graph_module,
            combined_bias_name,
            _AttrKind.PARAMETER,
        )
        with graph_module.graph.inserting_before(linear_node):
            new_bias_node = graph_module.graph.get_attr(combined_bias_name)

        # Use the combined bias as the new bias for the `Linear`.
        linear_node.args = (
            linear_node.args[:2] + (new_bias_node,) + linear_node.args[3:]
        )
        return True

    def _fuse_without_existing_bias(
        self,
        linear_node: Node,
        other_add_input: Node,
        graph_module: GraphModule,
        alpha: float,
    ) -> bool:
        """Fuse the `linear` and `add` provided the `linear` does not already have a bias.

        :return: True, if the nodes were successfully merged. False, otherwise.
        """

        # The weights have shape (out_features, in_features).
        output_features = linear_node.args[1].meta["val"].shape[0]
        new_bias_shape = other_add_input.meta["val"].shape
        if list(new_bias_shape) != [output_features]:
            return False  # The `Add` is adding a tensor with shape that is not supported for the `Linear` bias.

        bias_data = try_get_tensor_constant_from_node(graph_module, other_add_input)

        if bias_data is None:
            return False  # Neutron doesn't support a dynamic bias, so fusion would be counterproductive.

        # It is possible that the `linear` comes before the `other_add_input` in the graph, so it cannot use it as an
        # input directly. If the nodes are ordered as [linear, ..., other_add_input, ... add] (which is valid), using
        # `other_add_input` directly as an input to `Linear` would not follow topological order.
        # Rearranging the nodes is not trivial, as the graph could be complex (ultimately, the
        # `other_add_input` could even originate from the `Linear` node...).
        # Since the `other_add_input` has static data, we can create a new node with the data just before the `Linear`
        # to ensure topological order.
        # Regardless of the node ordering, the `add.Tensor` attribute `alpha` multiplies the second `add` input. If
        # `alpha != 1`, we would have to insert a `mul` operator if we wanted to keep the original parameter node.
        # Therefore, it is better to create a new static parameter node for the multiplied data in this case as well.
        nodes = list(graph_module.graph.nodes)
        if nodes.index(linear_node) < nodes.index(other_add_input) or alpha != 1.0:
            # Problematic order, or required multiplication.

            # Handle the `aten.add.Tensor` attribute `alpha`.
            bias_data *= alpha

            # Create a unique name.
            new_bias_name = get_new_attr_name_with_prefix(linear_node.name + "_bias")(
                graph_module
            )
            _assign_attr(bias_data, graph_module, new_bias_name, _AttrKind.PARAMETER)
            with graph_module.graph.inserting_before(linear_node):
                new_bias_node = graph_module.graph.get_attr(new_bias_name)

            # Use the added tensor as the new `Linear` bias.
            linear_node.args = (
                linear_node.args[:2] + (new_bias_node,) + linear_node.args[2:]
            )
            return True

        else:
            # Use the `other_add_input` directly as the new bias.
            linear_node.args = (
                linear_node.args[:2] + (other_add_input,) + linear_node.args[2:]
            )
            return True

    def call(self, graph_module: GraphModule) -> Optional[PassResult]:
        def _is_applicable_linear_node(node_: Node):
            is_linear = (
                node_.op == "call_function"
                and node_.target == torch.ops.aten.linear.default
            )
            has_single_user = len(node_.users) == 1

            return is_linear and has_single_user

        def _is_add(node_: Node):
            return (
                node_.op == "call_function"
                and node_.target == torch.ops.aten.add.Tensor
            )

        made_changes = False
        for node in graph_module.graph.nodes:
            if not _is_applicable_linear_node(
                linear_node := node
            ):  # Also ensures a single user.
                continue

            if not _is_add(add_node := list(linear_node.users.keys())[0]):
                continue  # Not the `Linear` -> `Add` case.

            if len(add_node.args) != 2:
                continue  # Unexpected case.

            # The `aten.add.Tensor` carries out the expression `out = input[0] + alpha × input[1]`.
            # https://docs.pytorch.org/docs/stable/generated/torch.add.html
            alpha = add_node.kwargs.get("alpha", 1.0)
            if add_node.args[0] == linear_node:
                other_add_input = add_node.args[1]

            else:
                # The fusion is not implemented. The `other_add_input` would have to be divided by `alpha` before the
                # fusion, and a `mul` operator would have to be added after the `linear` to multiply its output by
                # `alpha`.
                continue

            if len(linear_node.args) > 2:
                if not self._fuse_with_existing_bias(
                    linear_node, other_add_input, graph_module, alpha
                ):
                    continue  # The nodes could not be fused.

            else:
                # The `Linear` doesn't have a bias yet.
                if not self._fuse_without_existing_bias(
                    linear_node, other_add_input, graph_module, alpha
                ):
                    continue  # The nodes could not be fused.

            # Use the output of the `Linear` instead of the `Add`, and remove the now unused `Add` node.
            add_node.replace_all_uses_with(linear_node)
            graph_module.graph.erase_node(add_node)

            made_changes = True

        return PassResult(graph_module, made_changes)
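A minimal usage sketch of the pass above (assumed usage, not taken from the commit): it runs on a graph module whose parameters are attached as attributes, such as the one returned by `torch.export.export(...).module()`, so that the pass can read the static bias data and register the combined bias with `_assign_attr`. In the NXP flow the pass is scheduled by `NeutronAtenPassManager` instead (see the next file in this diff). `LinearThenAdd` is the toy module from the sketch near the top of this page.

```python
import torch

from executorch.backends.nxp.aten_passes.fuse_linear_and_add_pass import (
    FuseLinearAndAddPass,
)

# Re-attach parameters as module attributes (assumption: this is the kind of
# graph module the NXP aten-dialect passes operate on).
graph_module = torch.export.export(
    LinearThenAdd().eval(), (torch.randn(2, 8),)
).module()

result = FuseLinearAndAddPass().call(graph_module)
if result is not None and result.modified:
    result.graph_module.recompile()
    print(result.graph_module.graph)  # `aten.add.Tensor` should no longer appear
```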

backends/nxp/aten_passes/neutron_aten_pass_manager.py

Lines changed: 4 additions & 0 deletions
@@ -13,6 +13,9 @@
 from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
     FuseBatchNormWithLinearPass,
 )
+from executorch.backends.nxp.aten_passes.fuse_linear_and_add_pass import (
+    FuseLinearAndAddPass,
+)
 from executorch.backends.nxp.aten_passes.remove_nodes_with_known_outputs import (
     RemoveNodesWithKnownOutputs,
 )
@@ -38,6 +41,7 @@ def __init__(self, passes: list[PassType] = None):
             SplitGroupConvolution(),
             SplitGRUBasedOnNumLayers(),
             RemoveNodesWithKnownOutputs(),
+            FuseLinearAndAddPass(),
         ]

         super().__init__(passes)
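With the registration above, the fusion runs as part of the backend's default aten pass pipeline. A minimal sketch, assuming `NeutronAtenPassManager` follows the `torch.fx` pass-manager convention of being called directly on a `GraphModule` (the exact entry point in the NXP ahead-of-time flow is not shown in this diff):

```python
from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
    NeutronAtenPassManager,
)

# `graph_module` as in the previous sketch; the default pass list now ends with
# FuseLinearAndAddPass(), so Linear + Add pairs are fused before the rest of the
# NXP lowering flow runs.
pass_result = NeutronAtenPassManager()(graph_module)
fused_graph_module = pass_result.graph_module
```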

backends/nxp/backend/edge_helper.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2024 NXP
+# Copyright 2024-2025 NXP
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

backends/nxp/backend/ir/tflite_optimizer/optimizations/fuse_fully_connected_and_add_operators.py

Lines changed: 0 additions & 80 deletions
This file was deleted.

backends/nxp/backend/ir/tflite_optimizer/optimizer.py

Lines changed: 0 additions & 7 deletions
@@ -17,9 +17,6 @@
 from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.fuse_activation_functions import (
     FuseActivationFunctions,
 )
-from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.fuse_fully_connected_and_add_operators import (
-    FuseFullyConnectedAndAddOperators,
-)
 from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.move_relu_before_concat import (
     MoveActivationBeforeConcatenation,
 )
@@ -34,7 +31,6 @@

 class Optimization(Enum):
     FUSE_ACTIVATION_FUNCTIONS = 1
-    FUSE_FULLY_CONNECTED_AND_ADD = 2

     FUSE_TRANSPOSE_OPERATORS = 5
     REMOVE_IDENTITY_TRANSPOSE_OPERATORS = 6
@@ -75,9 +71,6 @@ def __init__(
             Optimization.FUSE_ACTIVATION_FUNCTIONS: FuseActivationFunctions(
                 builder, conversion_config
             ),
-            Optimization.FUSE_FULLY_CONNECTED_AND_ADD: FuseFullyConnectedAndAddOperators(
-                builder, conversion_config
-            ),
             Optimization.FUSE_TRANSPOSE_OPERATORS: FuseTransposeOperators(
                 builder, conversion_config
             ),

0 commit comments
