Commit 396d722
NXP backend: Add pre-processing pass to replace nodes with known outputs with their data. (#13937)

### Summary

Sometimes nodes will always produce the same data at runtime, regardless of the model inputs. Add a pre-processing aten dialect pass which identifies these cases and replaces the nodes with their static data. This allows for more advanced graph optimizations down the line.

### Test plan

Unit tests provided.
1 parent 3c533aa commit 396d722
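For context, here is a minimal sketch (not part of the commit) of the kind of node this pass targets: `torch.zeros` takes no runtime inputs, so its node produces identical data on every call and can be folded into a static parameter.

```python
import torch
from torch import nn


class ConstantPad(nn.Module):
    """Hypothetical model: the `zeros` node below has a known, constant output."""

    def forward(self, x):
        pad = torch.zeros(2, 4)  # independent of `x`; always the same data
        return torch.cat([x, pad])


module = torch.export.export(ConstantPad(), (torch.ones(3, 4),)).module()
print(module.graph)  # contains an `aten.zeros` node eligible for folding
```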

File tree

4 files changed: +326 −1 lines

backends/nxp/aten_passes/neutron_aten_pass_manager.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -13,6 +13,9 @@
 from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
     FuseBatchNormWithLinearPass,
 )
+from executorch.backends.nxp.aten_passes.remove_nodes_with_known_outputs import (
+    RemoveNodesWithKnownOutputs,
+)
 from executorch.backends.nxp.aten_passes.split_group_convolution import (
     SplitGroupConvolution,
 )
@@ -34,6 +37,7 @@ def __init__(self, passes: list[PassType] = None):
             FuseBatchNormWithLinearPass(),
             SplitGroupConvolution(),
             SplitGRUBasedOnNumLayers(),
+            RemoveNodesWithKnownOutputs(),
         ]

         super().__init__(passes)
```
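With this registration, the new pass runs as part of the default Neutron aten pass pipeline. It can also be applied in isolation, which is how the unit tests below exercise it. A minimal sketch of that invocation:

```python
import torch

from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
    NeutronAtenPassManager,
)
from executorch.backends.nxp.aten_passes.remove_nodes_with_known_outputs import (
    RemoveNodesWithKnownOutputs,
)

# Any nn.Module works; the pass operates on the exported aten-dialect graph.
model = torch.nn.GRU(8, 8)
exir_program_aten = torch.export.export(model, (torch.ones(8, 1, 8),)).module()

# Run only the new pass, mirroring the unit tests further down.
NeutronAtenPassManager([RemoveNodesWithKnownOutputs()])(exir_program_aten)
```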
backends/nxp/aten_passes/remove_nodes_with_known_outputs.py

Lines changed: 196 additions & 0 deletions

```diff
@@ -0,0 +1,196 @@
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Collection
+
+import torch
+
+from executorch.backends.nxp.backend.edge_helper import (
+    try_get_tensor_constant_from_node,
+)
+from torch._subclasses import FakeTensor, FakeTensorMode
+from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
+from torch.export.unflatten import _assign_attr, _AttrKind
+from torch.fx import GraphModule, Node
+from torch.fx.passes.infra.pass_base import PassBase, PassResult
+from torch.nn import Parameter
+
+
+class RemoveNodesWithKnownOutputs(PassBase):
+    """In some situations, a node will always produce the same output data at runtime. If these cases are
+    identified, the nodes can simply be removed and replaced by a static parameter node, which holds the data
+    the original node would produce.
+    This pass identifies some of these cases and performs the replacement.
+    """
+
+    # Nodes which don't have the `.meta['val']` attribute. The datatype and shape of their inferred output data
+    # will therefore not be checked against the expected values in the `.meta['val']`.
+    nodes_without_val_meta = [
+        torch.ops.aten.empty.memory_format,
+    ]
+
+    module: GraphModule
+
+    def replace_nodes_in_list_with_their_data(self, list_of_args: list) -> list | None:
+        """Replace the nodes in `list_of_args` by their static data. If not all data is available, return `None`.
+
+        :param list_of_args: List of arguments of an aten operator. Can include nodes, generic arguments, lists...
+        :return: `list_of_args` but with tensors replaced by their static data, or `None` if not all data is
+                  available.
+        """
+        args_with_data = []
+        for arg in list_of_args:
+            match arg:
+                case Node():
+                    # `arg` is either another operator, a model input, or a static parameter.
+                    data = try_get_tensor_constant_from_node(self.module, arg)
+                    if data is None:
+                        # No static data is available.
+                        return None
+
+                    args_with_data.append(data)
+
+                case list():
+                    nested = self.replace_nodes_in_list_with_their_data(arg)
+                    if nested is None:
+                        return None
+                    args_with_data.append(nested)
+
+                case _:
+                    # Generic argument. Not an input from a previous node.
+                    args_with_data.append(arg)
+
+        return args_with_data
+
+    @staticmethod
+    def node_is_followed_only_by_getitem_nodes(node: Node) -> bool:
+        def _is_getitem(node_: Node) -> bool:
+            return node_.op == "call_function" and node_.target.__name__ == "getitem"
+
+        users = list(node.users.keys())
+        return all(_is_getitem(user) for user in users)
+
+    def replace_node_with_static_data(self, node: Node, static_data: Parameter):
+        """Remove the given `node` from the graph and replace it with a parameter node containing the
+        `static_data`."""
+        # Generate a unique name for the new static parameter.
+        new_name = get_new_attr_name_with_prefix(node.name)(self.module)
+
+        # Create the node for the parameter.
+        param = torch.nn.Parameter(static_data, False)
+        _assign_attr(param, self.module, str(new_name), _AttrKind.PARAMETER)
+        with self.module.graph.inserting_before(node):
+            static_parameter_node = self.module.graph.get_attr(new_name)
+
+        with FakeTensorMode() as mode:
+            # Assign the parameter node its shape and data type.
+            static_parameter_node.meta["val"] = FakeTensor.from_tensor(
+                torch.empty(static_data.shape, dtype=static_data.dtype), mode
+            )
+
+        # Replace the old node with the new static parameter.
+        node.replace_all_uses_with(static_parameter_node)
+        self.module.graph.erase_node(node)
+
+    def replace_following_getitem_nodes_with_static_data(
+        self, root_node: Node, static_data: Collection[Parameter]
+    ) -> bool:
+        """Remove the `root_node` and all `GetItem` nodes that consume its output from the graph, and replace
+        their uses with parameter nodes containing the provided `static_data`.
+        If anything other than `GetItem` nodes follows the `root_node`, nothing is done and `False` is returned.
+
+        :param root_node: The main compute node which is followed only by `GetItem` nodes.
+        :param static_data: A tuple of static tensors with the data that will be used to replace the `GetItem`
+                             nodes after the `root_node`.
+        :return: `True` if the replacement was successfully executed. `False` otherwise.
+        """
+
+        if not self.node_is_followed_only_by_getitem_nodes(root_node):
+            return False  # Unexpected case.
+
+        users = list(root_node.users.keys())
+        if len(users) != len(static_data):
+            return False  # Unexpected mismatch.
+
+        # Replace the individual `GetItem` nodes.
+        for get_item_node in users:
+            idx = get_item_node.args[1]
+            self.replace_node_with_static_data(get_item_node, static_data[idx])
+
+        # Finally, remove the root node from the graph.
+        self.module.graph.erase_node(root_node)
+
+        return True
+
+    def data_matches_node_meta(self, node: Node, data: Parameter) -> bool:
+        """Verify that the provided `data` tensor has the same shape and datatype as the `node`."""
+        if node.target not in self.nodes_without_val_meta:
+            if node.meta["val"].shape != data.shape:
+                return False  # The inferred data has a different shape than expected.
+
+            if node.meta["val"].dtype != data.dtype:
+                return False  # The inferred data has a different data type than expected.
+
+        return True
+
+    def data_matches_meta_of_following_getitem_nodes(
+        self, root_node: Node, data: Collection[Parameter]
+    ) -> bool:
+        """Verify that the provided `data` tensors have the same shapes and datatypes as the `GetItem` nodes
+        which consume the output of the `root_node`.
+        """
+        if not self.node_is_followed_only_by_getitem_nodes(root_node):
+            return False  # Unexpected case.
+
+        users = list(root_node.users.keys())
+        return all(
+            self.data_matches_node_meta(get_item, data[get_item.args[1]])
+            for get_item in users
+        )
+
+    def call(self, module: GraphModule) -> PassResult:
+        self.module = module
+        made_changes = False
+
+        for node in module.graph.nodes:
+            if node.op != "call_function":
+                continue  # Not a compute operator.
+
+            # Try to access the static data for the inputs of the node.
+            args_with_data = self.replace_nodes_in_list_with_their_data(node.args)
+
+            if args_with_data is None:
+                # Output data inference is not possible.
+                continue
+
+            # All input data is static. Run the operator to compute the output it would produce at runtime.
+            # noinspection PyBroadException
+            try:
+                output = node.target(*args_with_data, **node.kwargs)
+
+                if isinstance(output, (tuple, list)):
+                    if not self.data_matches_meta_of_following_getitem_nodes(
+                        node, output
+                    ):
+                        continue  # The inferred data does not have the expected type/shape.
+                else:
+                    if not self.data_matches_node_meta(node, output):
+                        continue  # The inferred data does not have the expected type/shape.
+
+            except Exception:
+                continue  # Failed to infer the data. Continue with the other nodes.
+
+            # The output data appears to have been correctly inferred. Create a static parameter node for it.
+            if isinstance(output, (tuple, list)):
+                # The node produces multiple outputs (e.g. `split`). If the node is followed only by `GetItem`
+                # nodes which extract the individual outputs, replace them with the static data.
+                if self.replace_following_getitem_nodes_with_static_data(node, output):
+                    made_changes = True
+
+            else:
+                self.replace_node_with_static_data(node, output)
+                made_changes = True  # Indicate that changes were made.
+
+        return PassResult(module, made_changes)
```
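The multi-output branch above hinges on the fx convention that tuple results are unpacked through `operator.getitem` nodes. A small standalone sketch (illustrative names, not from the commit) of the pattern `node_is_followed_only_by_getitem_nodes` detects:

```python
import operator

import torch


class SplitModel(torch.nn.Module):
    def forward(self, x):
        parts = torch.split(x, 2)  # multi-output op, like `split` in the pass
        return parts[0] + 1, parts[1] + 2


gm = torch.fx.symbolic_trace(SplitModel())
split_node = next(
    n
    for n in gm.graph.nodes
    if n.op == "call_function" and getattr(n.target, "__name__", "") == "split"
)
# Each tuple element is extracted by a `getitem` node, so the check passes.
print(all(user.target is operator.getitem for user in split_node.users))  # True
```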

backends/nxp/backend/edge_helper.py

Lines changed: 17 additions & 1 deletion

```diff
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.

 import torch
-from torch.fx import Node
+from torch.fx import GraphModule, Node
 from torch.nn import Parameter


@@ -71,3 +71,19 @@ def _is_dequantize(node_: Node) -> bool:
     return _is_dequantize(node) and node_is_static_tensor(
         node.args[0], parameters_mapping
     )
+
+
+def try_get_tensor_constant_from_node(
+    graph_module: GraphModule, node: Node
+) -> Parameter | None:
+    """Get the static data from a given node. If it doesn't have any data, return `None`."""
+    if node is None or node.op != "get_attr":
+        return None
+
+    target_atoms = node.target.split(".")
+    attr_itr = graph_module
+    for atom in target_atoms:
+        if not hasattr(attr_itr, atom):
+            return None
+        attr_itr = getattr(attr_itr, atom)
+    return attr_itr
```
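The new helper resolves a dotted `get_attr` target (e.g. `"gru.weight_ih_l0"`) by walking attributes from the root module, returning `None` on a miss instead of raising. An equivalent standalone sketch of that walk:

```python
import torch

root = torch.nn.Module()
root.gru = torch.nn.GRU(8, 8)

# Mirror of the getattr loop in `try_get_tensor_constant_from_node`.
attr_itr = root
for atom in "gru.weight_ih_l0".split("."):
    if not hasattr(attr_itr, atom):
        attr_itr = None
        break
    attr_itr = getattr(attr_itr, atom)

print(type(attr_itr))  # <class 'torch.nn.parameter.Parameter'>
```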
Lines changed: 109 additions & 0 deletions

```diff
@@ -0,0 +1,109 @@
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import numpy as np
+import torch
+
+from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
+    NeutronAtenPassManager,
+)
+from executorch.backends.nxp.aten_passes.remove_nodes_with_known_outputs import (
+    RemoveNodesWithKnownOutputs,
+)
+from executorch.backends.nxp.aten_passes.split_gru_based_on_num_layers import (
+    SplitGRUBasedOnNumLayers,
+)
+from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops
+from parameterized import parameterized
+from torch import nn
+
+
+class GRUModel(nn.Module):
+    def __init__(self, num_layers=1):
+        super().__init__()
+        self.gru = torch.nn.GRU(8, 8, num_layers=num_layers)
+
+    def forward(self, input_):
+        # `input_` has shape [sequence_length, batch_size, input_size] ([8, 1, 8]).
+        return self.gru(
+            input_, None
+        )  # The initial hidden state is `None`, which will result in a `zeros` node being added.
+
+
+class TestRemovingNodesWithKnownOutputs(unittest.TestCase):
+    __test__ = False  # Prevent interfering with PyTest tests.
+
+    @classmethod
+    def setUpClass(cls):
+        torch.manual_seed(23)
+        np.random.seed(42)
+
+    def test_removing_nodes__zeros(self):
+        model = GRUModel()
+
+        input_shape = (8, 1, 8)
+        example_input = (torch.ones(input_shape),)
+
+        exir_program_aten = torch.export.export(model, example_input).module()
+
+        # Make sure the `aten.zeros` is in the model.
+        assert graph_contains_any_of_ops(
+            exir_program_aten.graph, [torch.ops.aten.zeros.default]
+        )
+        outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)]
+
+        # Apply the optimization.
+        NeutronAtenPassManager([RemoveNodesWithKnownOutputs()])(exir_program_aten)
+
+        # Make sure the `aten.zeros` is no longer in the model.
+        assert not graph_contains_any_of_ops(
+            exir_program_aten.graph, [torch.ops.aten.zeros.default]
+        )
+        outputs_after = [o.detach().numpy() for o in exir_program_aten(*example_input)]
+
+        # Make sure the model still produces the exact same output.
+        assert np.allclose(outputs_before[0], outputs_after[0])
+        assert np.allclose(outputs_before[1], outputs_after[1])
+
+    @parameterized.expand([2, 8])
+    def test_removing_nodes__split(self, num_layers):
+        # `num_layers > 1` will result in a `split` operator being added. Its input will be a `zeros` operator,
+        # which provides the static 0s input data.
+        model = GRUModel(num_layers).eval()
+
+        input_shape = (8, 1, 8)
+        example_input = (torch.ones(input_shape),)
+
+        exir_program_aten = torch.export.export(model, example_input).module()
+
+        # Apply the pass to split the `aten.gru.input` into multiple instances, and add a `split` node.
+        NeutronAtenPassManager([SplitGRUBasedOnNumLayers()])(exir_program_aten)
+
+        # Make sure the `aten.zeros` and `torch.split` are in the model.
+        assert graph_contains_any_of_ops(
+            exir_program_aten.graph, [torch.ops.aten.zeros.default]
+        )
+        assert graph_contains_any_of_ops(
+            exir_program_aten.graph, [torch.ops.aten.split.default]
+        )
+        outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)]
+
+        # Apply the optimization.
+        NeutronAtenPassManager([RemoveNodesWithKnownOutputs()])(exir_program_aten)
+
+        # Make sure the `aten.zeros` and `torch.split` are no longer in the model.
+        assert not graph_contains_any_of_ops(
+            exir_program_aten.graph, [torch.ops.aten.zeros.default]
+        )
+        assert not graph_contains_any_of_ops(
+            exir_program_aten.graph, [torch.ops.aten.split.default]
+        )
+        outputs_after = [o.detach().numpy() for o in exir_program_aten(*example_input)]
+
+        # Make sure the model still produces the exact same output.
+        assert np.allclose(outputs_before[0], outputs_after[0])
+        assert np.allclose(outputs_before[1], outputs_after[1])
```
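As a quick manual check mirroring the first test (a sketch, reusing `GRUModel` from the test file above): export the model, run the pass, and inspect the graph before and after.

```python
import torch

from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
    NeutronAtenPassManager,
)
from executorch.backends.nxp.aten_passes.remove_nodes_with_known_outputs import (
    RemoveNodesWithKnownOutputs,
)

module = torch.export.export(GRUModel(), (torch.ones(8, 1, 8),)).module()
print(module.graph)  # contains `aten.zeros.default`

NeutronAtenPassManager([RemoveNodesWithKnownOutputs()])(module)
print(module.graph)  # the `zeros` node is replaced by a `get_attr` parameter
```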
