Skip to content

Commit ff7c5e5

Browse files
authored
NXP Backend: Add infrastructure for pre processing passes in edge dialect (#13183)
### Summary Add infrastructure for running pre-processing passes on edge dialect programs. Add pre-processing pass to move `view_copy` nodes into their own QDQ clusters. ### Test plan Unit test provided.
1 parent 33bd456 commit ff7c5e5

File tree

7 files changed

+487
-13
lines changed

7 files changed

+487
-13
lines changed
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
8+
from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass
9+
from executorch.backends.nxp.neutron_partitioner import QDQClusterRecognizer
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
from torch.fx import Node
12+
from torch.fx.passes.infra.pass_base import PassResult
13+
14+
15+
def insert_qdq_pair_after_node(
16+
graph: torch.fx.Graph, anchor: torch.fx.Node, q_params: tuple
17+
):
18+
# Insert a Quantize node.
19+
with graph.inserting_after(anchor):
20+
quantize_op = graph.create_node(
21+
op="call_function",
22+
target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
23+
args=(), # Will be added later.
24+
)
25+
quantize_op.meta = anchor.meta
26+
27+
# Insert a Dequantize node.
28+
with graph.inserting_after(quantize_op):
29+
dequantize_op = graph.create_node(
30+
op="call_function",
31+
target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
32+
args=(quantize_op,) + q_params,
33+
)
34+
dequantize_op.meta = quantize_op.meta
35+
anchor.replace_all_uses_with(dequantize_op)
36+
37+
# Add this at the end, so the `anchor.replace_all_uses_with(dequantize_op)` does not replace the first use of the
38+
# `quantize_op`.
39+
quantize_op.args = (anchor,) + q_params
40+
41+
42+
def _is_dequantize(node_: Node) -> bool:
43+
return (
44+
node_.op == "call_function"
45+
and node_.target
46+
== exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
47+
)
48+
49+
50+
def _is_quantize(node_: Node) -> bool:
51+
return (
52+
node_.op == "call_function"
53+
and node_.target
54+
== exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
55+
)
56+
57+
58+
class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
59+
"""
60+
61+
┌─────▼──────┐
62+
│ │ dequantize │
63+
┌─────▼──────┐ └─────┬──────┘
64+
│ dequantize │ ┌─────▼──────┐
65+
└─────┬──────┘ │ <aux_node> │
66+
┌─────▼──────┐ └─────┬──────┘
67+
│ <aux_node> │ ┌────▼─────┐ ┐
68+
└─────┬──────┘ │ quantize │ │
69+
┌──────────▼──────────┐ replaced with └────┬─────┘ │
70+
⋯┤ <main_cluster_node> ├⋯ ──────────────► │ │ newly added nodes
71+
└──────────┬──────────┘ ┌─────▼──────┐ │
72+
▼ │ dequantize │ │
73+
⋮ └─────┬──────┘ ┘
74+
┌────▼─────┐ ┌──────────▼──────────┐
75+
│ quantize │ ⋯┤ <main_cluster_node> ├⋯
76+
└────┬─────┘ └──────────┬──────────┘
77+
▼ ▼
78+
79+
┌────▼─────┐
80+
│ quantize │
81+
└────┬─────┘
82+
83+
"""
84+
85+
allowed_auxiliary_nodes = [exir_ops.edge.aten.view_copy.default]
86+
87+
# List of approved nodes to which the <aux_node> can be connected in order for the pass to make the modification.
88+
allowed_main_cluster_nodes = [
89+
exir_ops.edge.aten.addmm.default,
90+
exir_ops.edge.aten.mm.default,
91+
]
92+
93+
def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
94+
for aux_node in graph_module.graph.nodes:
95+
if (
96+
aux_node.op != "call_function"
97+
or aux_node.target not in self.allowed_auxiliary_nodes
98+
):
99+
continue
100+
101+
dequantize_node = aux_node.args[0]
102+
if not _is_dequantize(dequantize_node):
103+
# Not the intended use case.
104+
continue
105+
106+
users = list(aux_node.users.keys())
107+
if len(users) != 1:
108+
# Not the intended use case.
109+
continue
110+
111+
main_cluster_node = users[0]
112+
if (
113+
main_cluster_node.op != "call_function"
114+
or main_cluster_node.target not in self.allowed_main_cluster_nodes
115+
):
116+
# Unsupported `main_cluster_node`.
117+
continue
118+
119+
# Make sure the nodes are part of the same QDQ cluster.
120+
cluster = QDQClusterRecognizer().get_qdq_cluster(main_cluster_node)
121+
if any(
122+
node_ not in cluster
123+
for node_ in [dequantize_node, aux_node, main_cluster_node]
124+
):
125+
continue
126+
127+
# ---- The nodes follow the pattern described in the header. ----
128+
129+
q_params = dequantize_node.args[1:]
130+
insert_qdq_pair_after_node(graph_module.graph, aux_node, q_params)
131+
132+
# The graph has now changed, and we shouldn't keep iterating through it. Return the new graph and the parent
133+
# class will call this pass again.
134+
return PassResult(graph_module, True)
135+
136+
# Nothing was changed.
137+
return PassResult(graph_module, False)
138+
139+
140+
class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
141+
"""
142+
143+
┌─────▼──────┐
144+
│ │ dequantize │
145+
┌─────▼──────┐ └─────┬──────┘
146+
│ dequantize │ ⋮
147+
└─────┬──────┘ ┌──────────▼──────────┐
148+
▼ ⋯┤ <main_cluster_node> ├⋯
149+
⋮ └──────────┬──────────┘
150+
┌──────────▼──────────┐ replaced with ┌────▼─────┐ ┐
151+
⋯┤ <main_cluster_node> ├⋯ ──────────────► │ quantize │ │
152+
└──────────┬──────────┘ └────┬─────┘ │
153+
┌─────▼──────┐ │ │ newly added nodes
154+
│ <aux_node> │ ┌─────▼──────┐ │
155+
└─────┬──────┘ │ dequantize │ │
156+
┌────▼─────┐ └─────┬──────┘ ┘
157+
│ quantize │ ┌─────▼──────┐
158+
└────┬─────┘ │ <aux_node> │
159+
▼ └─────┬──────┘
160+
┌────▼─────┐
161+
│ quantize │
162+
└────┬─────┘
163+
164+
"""
165+
166+
allowed_auxiliary_nodes = [exir_ops.edge.aten.view_copy.default]
167+
168+
# List of approved nodes to which the `<aux_node>` can be connected in order for the pass to make the modification.
169+
allowed_main_cluster_nodes = [
170+
exir_ops.edge.aten.addmm.default,
171+
exir_ops.edge.aten.mm.default,
172+
]
173+
174+
def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
175+
176+
for aux_node in graph_module.graph.nodes:
177+
if (
178+
aux_node.op != "call_function"
179+
or aux_node.target not in self.allowed_auxiliary_nodes
180+
):
181+
continue
182+
183+
main_cluster_node = aux_node.args[0]
184+
if (
185+
main_cluster_node.op != "call_function"
186+
or main_cluster_node.target not in self.allowed_main_cluster_nodes
187+
):
188+
# Unsupported `main_cluster_node`.
189+
continue
190+
191+
users = list(aux_node.users.keys())
192+
if len(users) != 1:
193+
# Not the intended use case.
194+
continue
195+
196+
quantize_node = users[0]
197+
if not _is_quantize(quantize_node):
198+
# Not the intended use case.
199+
continue
200+
201+
# Make sure the nodes are part of the same QDQ cluster.
202+
cluster = QDQClusterRecognizer().get_qdq_cluster(main_cluster_node)
203+
if any(
204+
node_ not in cluster
205+
for node_ in [quantize_node, aux_node, main_cluster_node]
206+
):
207+
continue
208+
209+
# ---- The nodes follow the pattern described in the header. ----
210+
211+
q_params = quantize_node.args[1:]
212+
insert_qdq_pair_after_node(graph_module.graph, main_cluster_node, q_params)
213+
214+
# The graph has now changed, and we shouldn't keep iterating through it. Return the new graph and the parent
215+
# class will call this pass again.
216+
return PassResult(graph_module, True)
217+
218+
# Nothing was changed.
219+
return PassResult(graph_module, False)
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import logging
7+
from abc import abstractmethod
8+
9+
import torch
10+
11+
from executorch.exir.pass_base import ExportPass
12+
from torch.fx.passes.infra.pass_base import PassResult
13+
14+
15+
class NeutronEdgePass(ExportPass):
16+
"""Abstract parent class for pre-processing passes on the edge dialect level."""
17+
18+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
19+
"""Call `self.run()` as long as changes are being made. After a pass modifies the graph, it cannot keep on
20+
iterating through its nodes, and must return. This method allows the pass to go through the whole model.
21+
"""
22+
23+
# Every pass will return once it makes a change to the graph, to avoid traversing and modifying a graph at the
24+
# same time. Therefore, it must be called multiple times (at most `iteration_limit` times).
25+
iteration_limit = len(graph_module.graph.nodes)
26+
modified = False
27+
for _ in range(iteration_limit):
28+
res = self.run(graph_module)
29+
if res.modified:
30+
modified = True
31+
graph_module = res.graph_module
32+
33+
else:
34+
# No more changes have been made.
35+
graph_module = self.recompile_module(graph_module)
36+
return PassResult(graph_module, modified)
37+
38+
# Iteration limit was reached.
39+
logging.warning(
40+
f"The NeutronEdgePass `{self.__class__.__name__}` reached the iteration limit."
41+
)
42+
graph_module = self.recompile_module(graph_module)
43+
return PassResult(graph_module, modified)
44+
45+
@abstractmethod
46+
def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
47+
"""Child classes should implement their graph modification here."""
48+
pass
49+
50+
def recompile_module(
51+
self, graph_module: torch.fx.GraphModule
52+
) -> torch.fx.GraphModule:
53+
"""Recompile the graph and re-trace the metadata. This should ensure that the datatypes and shapes are correct."""
54+
graph_module.recompile()
55+
return super().call(graph_module).graph_module
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import copy
7+
8+
from executorch.backends.nxp.edge_passes.move_auxiliary_operator_into_separate_qdq_cluster_pass import (
9+
MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass,
10+
MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass,
11+
)
12+
from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass
13+
from executorch.exir import EdgeProgramManager
14+
from executorch.exir.program._program import (
15+
_get_updated_graph_signature,
16+
_get_updated_range_constraints,
17+
)
18+
19+
from torch import nn
20+
from torch.export import ExportedProgram
21+
from torch.fx.passes.infra.pass_base import PassResult
22+
from torch.fx.passes.infra.pass_manager import PassManager
23+
24+
25+
class NeutronEdgePassManager(PassManager):
26+
27+
def __init__(self, passes: list[NeutronEdgePass] = None):
28+
passes: list[NeutronEdgePass] = passes or [
29+
MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(),
30+
MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(),
31+
]
32+
33+
super().__init__(
34+
passes,
35+
steps=10, # Empirical value. At most 10 cycles of passes will be run.
36+
)
37+
38+
def _transform_graph_module(self, module: nn.Module) -> PassResult:
39+
"""Apply the passes to a single graph module."""
40+
pass_result: PassResult = super().__call__(module)
41+
42+
graph_module = pass_result.graph_module
43+
graph_module.graph.eliminate_dead_code()
44+
graph_module.recompile()
45+
46+
return pass_result
47+
48+
def __call__(self, epm: EdgeProgramManager) -> EdgeProgramManager:
49+
"""Apply the passes to all graph modules in the edge program."""
50+
new_programs: dict[str, ExportedProgram] = {}
51+
52+
for name, program in epm._edge_programs.items():
53+
pass_result = self._transform_graph_module(program.graph_module)
54+
55+
if pass_result.modified:
56+
# Create a new exported program.
57+
new_program = ExportedProgram(
58+
root=pass_result.graph_module,
59+
graph=pass_result.graph_module.graph,
60+
graph_signature=_get_updated_graph_signature(
61+
program.graph_signature, pass_result.graph_module
62+
),
63+
state_dict=program.state_dict,
64+
range_constraints=_get_updated_range_constraints(
65+
pass_result.graph_module
66+
),
67+
module_call_graph=copy.deepcopy(program._module_call_graph),
68+
example_inputs=program.example_inputs,
69+
constants=program.constants,
70+
verifiers=[program.verifier],
71+
)
72+
new_program.graph_module.meta.update(program.graph_module.meta)
73+
new_program.graph_module.meta.update(pass_result.graph_module.meta)
74+
75+
else:
76+
# Keep the old exported program.
77+
new_program = program
78+
79+
new_programs[name] = new_program
80+
81+
if len(new_programs) == 0:
82+
# No passes were run, return the old EdgeProgramManager.
83+
return epm
84+
85+
else:
86+
# Return a new EdgeProgramManager with the updated programs.
87+
return EdgeProgramManager(
88+
new_programs, copy.deepcopy(epm._config_methods), epm.compile_config
89+
)

0 commit comments

Comments
 (0)