Skip to content

Commit 5ff9157

Browse files
committed
NXP backend: Add infrastructure for running pre-processing passes on edge dialect programs.
1 parent 70d9e94 commit 5ff9157

File tree

3 files changed

+148
-10
lines changed

3 files changed

+148
-10
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import logging
7+
from abc import abstractmethod
8+
9+
import torch
10+
from torch.fx.passes.infra.pass_base import PassResult
11+
12+
from executorch.exir.pass_base import ExportPass
13+
14+
15+
class NeutronEdgePass(ExportPass):
16+
""" Abstract parent class for pre-processing passes on the edge dialect level. """
17+
18+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
19+
""" Call `self.run()` as long as changes are being made. After a pass modifies the graph, it cannot keep on
20+
iterating through its nodes, and must return. This method allows the pass to go through the whole model.
21+
"""
22+
23+
# Every pass will return once it makes a change to the graph, to avoid traversing and modifying a graph at the
24+
# same time. Therefore, it must be called multiple times (at most `iteration_limit` times).
25+
iteration_limit = len(graph_module.graph.nodes)
26+
modified = False
27+
for _ in range(iteration_limit):
28+
res = self.run(graph_module)
29+
if res.modified:
30+
modified = True
31+
graph_module = res.graph_module
32+
33+
else:
34+
# No more changes have been made.
35+
graph_module = self.recompile_module(graph_module)
36+
return PassResult(graph_module, modified)
37+
38+
# Iteration limit was reached.
39+
logging.warning(f'The NeutronEdgePass `{self.__class__.__name__}` reached the iteration limit.')
40+
graph_module = self.recompile_module(graph_module)
41+
return PassResult(graph_module, modified)
42+
43+
@abstractmethod
44+
def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
45+
""" Child classes should implement their graph modification here. """
46+
pass
47+
48+
def recompile_module(
49+
self, graph_module: torch.fx.GraphModule
50+
) -> torch.fx.GraphModule:
51+
""" Recompile the graph and re-trace the metadata. This should ensure that the datatypes and shapes are correct.
52+
"""
53+
graph_module.recompile()
54+
return super().call(graph_module).graph_module
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import copy
7+
8+
from torch import nn
9+
from torch.export import ExportedProgram
10+
from torch.fx.passes.infra.pass_base import PassResult
11+
from torch.fx.passes.infra.pass_manager import PassManager
12+
13+
from executorch.backends.nxp.edge_passes.move_auxiliary_operator_into_separate_qdq_cluster_pass import (
14+
MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass,
15+
MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass,
16+
)
17+
from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass
18+
from executorch.exir import EdgeProgramManager
19+
from executorch.exir.program._program import _get_updated_graph_signature, _get_updated_range_constraints
20+
21+
22+
class NeutronEdgePassManager(PassManager):
23+
24+
def __init__(self, passes: list[NeutronEdgePass] = None):
25+
passes: list[NeutronEdgePass] = passes or [
26+
]
27+
28+
super().__init__(
29+
passes,
30+
steps=10 # Empirical value. At most 10 cycles of passes will be run.
31+
)
32+
33+
def _transform_graph_module(self, module: nn.Module) -> PassResult:
34+
""" Apply the passes to a single graph module. """
35+
pass_result: PassResult = super().__call__(module)
36+
37+
graph_module = pass_result.graph_module
38+
graph_module.graph.eliminate_dead_code()
39+
graph_module.recompile()
40+
41+
return pass_result
42+
43+
def __call__(self, epm: EdgeProgramManager) -> EdgeProgramManager:
44+
""" Apply the passes to all graph modules in the edge program. """
45+
new_programs: dict[str, ExportedProgram] = {}
46+
47+
for name, program in epm._edge_programs.items():
48+
pass_result = self._transform_graph_module(program.graph_module)
49+
50+
if pass_result.modified:
51+
# Create a new exported program.
52+
new_program = ExportedProgram(
53+
root=pass_result.graph_module,
54+
graph=pass_result.graph_module.graph,
55+
graph_signature=_get_updated_graph_signature(
56+
program.graph_signature, pass_result.graph_module
57+
),
58+
state_dict=program.state_dict,
59+
range_constraints=_get_updated_range_constraints(pass_result.graph_module),
60+
module_call_graph=copy.deepcopy(program._module_call_graph),
61+
example_inputs=program.example_inputs,
62+
constants=program.constants,
63+
verifiers=[program.verifier],
64+
)
65+
new_program.graph_module.meta.update(program.graph_module.meta)
66+
new_program.graph_module.meta.update(pass_result.graph_module.meta)
67+
68+
else:
69+
# Keep the old exported program.
70+
new_program = program
71+
72+
new_programs[name] = new_program
73+
74+
if len(new_programs) == 0:
75+
# No passes were run, return the old EdgeProgramManager.
76+
return epm
77+
78+
else:
79+
# Return a new EdgeProgramManager with the updated programs.
80+
return EdgeProgramManager(new_programs, copy.deepcopy(epm._config_methods), epm.compile_config)

backends/nxp/tests/executorch_pipeline.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import torch
7+
from torch import nn
8+
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
79

810
from executorch import exir
11+
from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import NeutronEdgePassManager
912
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
1013
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
1114
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
@@ -14,10 +17,8 @@
1417
EdgeProgramManager,
1518
ExecutorchBackendConfig,
1619
ExecutorchProgramManager,
17-
to_edge_transform_and_lower,
1820
)
19-
from torch import nn
20-
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
21+
from executorch.extension.export_util.utils import export_to_edge
2122

2223

2324
def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):
@@ -49,19 +50,22 @@ def to_quantized_edge_program(
4950
exir_program_aten.module(), calibration_inputs
5051
)
5152

53+
edge_compile_config = EdgeCompileConfig(_check_ir_validity=False)
54+
edge_program_manager = export_to_edge(
55+
exir_program_aten__module_quant,
56+
example_input,
57+
edge_compile_config=edge_compile_config,
58+
)
59+
60+
edge_program_manager = NeutronEdgePassManager()(edge_program_manager)
61+
5262
compile_spec = generate_neutron_compile_spec(
5363
target,
5464
operators_not_to_delegate=operators_not_to_delegate,
5565
neutron_converter_flavor=neutron_converter_flavor,
5666
)
5767
partitioner = NeutronPartitioner(compile_spec)
58-
edge_program_manager = to_edge_transform_and_lower(
59-
torch.export.export(
60-
exir_program_aten__module_quant, example_input, strict=True
61-
),
62-
partitioner=[partitioner],
63-
compile_config=EdgeCompileConfig(_check_ir_validity=False),
64-
)
68+
edge_program_manager = edge_program_manager.to_backend(partitioner)
6569

6670
return edge_program_manager
6771

0 commit comments

Comments
 (0)