-
Notifications
You must be signed in to change notification settings - Fork 752
NXP Backend: Add infrastructure for pre processing passes in edge dialect #13183
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,219 @@ | ||
| # Copyright 2025 NXP | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import torch | ||
|
|
||
| from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass | ||
| from executorch.backends.nxp.neutron_partitioner import QDQClusterRecognizer | ||
| from executorch.exir.dialects._ops import ops as exir_ops | ||
| from torch.fx import Node | ||
| from torch.fx.passes.infra.pass_base import PassResult | ||
|
|
||
|
|
||
| def insert_qdq_pair_after_node( | ||
| graph: torch.fx.Graph, anchor: torch.fx.Node, q_params: tuple | ||
| ): | ||
| # Insert a Quantize node. | ||
| with graph.inserting_after(anchor): | ||
| quantize_op = graph.create_node( | ||
| op="call_function", | ||
| target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, | ||
| args=(), # Will be added later. | ||
| ) | ||
| quantize_op.meta = anchor.meta | ||
|
|
||
| # Insert a Dequantize node. | ||
| with graph.inserting_after(quantize_op): | ||
| dequantize_op = graph.create_node( | ||
| op="call_function", | ||
| target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, | ||
| args=(quantize_op,) + q_params, | ||
| ) | ||
| dequantize_op.meta = quantize_op.meta | ||
| anchor.replace_all_uses_with(dequantize_op) | ||
|
|
||
| # Add this at the end, so the `anchor.replace_all_uses_with(dequantize_op)` does not replace the first use of the | ||
| # `quantize_op`. | ||
| quantize_op.args = (anchor,) + q_params | ||
|
|
||
|
|
||
| def _is_dequantize(node_: Node) -> bool: | ||
| return ( | ||
| node_.op == "call_function" | ||
| and node_.target | ||
| == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default | ||
| ) | ||
|
|
||
|
|
||
| def _is_quantize(node_: Node) -> bool: | ||
| return ( | ||
| node_.op == "call_function" | ||
| and node_.target | ||
| == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default | ||
| ) | ||
|
|
||
|
|
||
| class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass): | ||
| """ | ||
| │ | ||
| ┌─────▼──────┐ | ||
| │ │ dequantize │ | ||
| ┌─────▼──────┐ └─────┬──────┘ | ||
| │ dequantize │ ┌─────▼──────┐ | ||
| └─────┬──────┘ │ <aux_node> │ | ||
| ┌─────▼──────┐ └─────┬──────┘ | ||
| │ <aux_node> │ ┌────▼─────┐ ┐ | ||
| └─────┬──────┘ │ quantize │ │ | ||
| ┌──────────▼──────────┐ replaced with └────┬─────┘ │ | ||
| ⋯┤ <main_cluster_node> ├⋯ ──────────────► │ │ newly added nodes | ||
| └──────────┬──────────┘ ┌─────▼──────┐ │ | ||
| ▼ │ dequantize │ │ | ||
| ⋮ └─────┬──────┘ ┘ | ||
| ┌────▼─────┐ ┌──────────▼──────────┐ | ||
| │ quantize │ ⋯┤ <main_cluster_node> ├⋯ | ||
| └────┬─────┘ └──────────┬──────────┘ | ||
| ▼ ▼ | ||
| ⋮ | ||
| ┌────▼─────┐ | ||
| │ quantize │ | ||
| └────┬─────┘ | ||
| ▼ | ||
| """ | ||
|
|
||
| allowed_auxiliary_nodes = [exir_ops.edge.aten.view_copy.default] | ||
|
|
||
| # List of approved nodes to which the <aux_node> can be connected in order for the pass to make the modification. | ||
| allowed_main_cluster_nodes = [ | ||
| exir_ops.edge.aten.addmm.default, | ||
| exir_ops.edge.aten.mm.default, | ||
| ] | ||
|
|
||
| def run(self, graph_module: torch.fx.GraphModule) -> PassResult: | ||
| for aux_node in graph_module.graph.nodes: | ||
| if ( | ||
| aux_node.op != "call_function" | ||
| or aux_node.target not in self.allowed_auxiliary_nodes | ||
| ): | ||
| continue | ||
|
|
||
| dequantize_node = aux_node.args[0] | ||
| if not _is_dequantize(dequantize_node): | ||
| # Not the intended use case. | ||
| continue | ||
|
|
||
| users = list(aux_node.users.keys()) | ||
| if len(users) != 1: | ||
| # Not the intended use case. | ||
| continue | ||
|
|
||
| main_cluster_node = users[0] | ||
| if ( | ||
| main_cluster_node.op != "call_function" | ||
| or main_cluster_node.target not in self.allowed_main_cluster_nodes | ||
| ): | ||
| # Unsupported `main_cluster_node`. | ||
| continue | ||
|
|
||
| # Make sure the nodes are part of the same QDQ cluster. | ||
| cluster = QDQClusterRecognizer().get_qdq_cluster(main_cluster_node) | ||
| if any( | ||
| node_ not in cluster | ||
| for node_ in [dequantize_node, aux_node, main_cluster_node] | ||
| ): | ||
| continue | ||
|
|
||
| # ---- The nodes follow the pattern described in the header. ---- | ||
|
|
||
| q_params = dequantize_node.args[1:] | ||
| insert_qdq_pair_after_node(graph_module.graph, aux_node, q_params) | ||
|
|
||
| # The graph has now changed, and we shouldn't keep iterating through it. Return the new graph and the parent | ||
| # class will call this pass again. | ||
| return PassResult(graph_module, True) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Pop-Korn, in principal this does not differ from the initial draft. You return everytime, you make a modification and the caller @digantdesai, @Pop-Korn noticed in some of the passes the code iterates over a changing graph. E.g. https://github.com/pytorch/executorch/blob/main/backends/xnnpack/_passes/fuse_batch_norm.py#L41 the
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Responding to the first paragraph: The second paragraph:
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This pattern is used in a lot of places across the code base. This suggests it should be OK - https://github.com/pytorch/pytorch/blob/6f0f4e0c3eacd479864319127915f869f64e1935/torch/fx/graph.py#L1076-L1088 |
||
|
|
||
| # Nothing was changed. | ||
| return PassResult(graph_module, False) | ||
|
|
||
|
|
||
| class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass): | ||
| """ | ||
| │ | ||
| ┌─────▼──────┐ | ||
| │ │ dequantize │ | ||
| ┌─────▼──────┐ └─────┬──────┘ | ||
| │ dequantize │ ⋮ | ||
| └─────┬──────┘ ┌──────────▼──────────┐ | ||
| ▼ ⋯┤ <main_cluster_node> ├⋯ | ||
| ⋮ └──────────┬──────────┘ | ||
| ┌──────────▼──────────┐ replaced with ┌────▼─────┐ ┐ | ||
| ⋯┤ <main_cluster_node> ├⋯ ──────────────► │ quantize │ │ | ||
| └──────────┬──────────┘ └────┬─────┘ │ | ||
| ┌─────▼──────┐ │ │ newly added nodes | ||
| │ <aux_node> │ ┌─────▼──────┐ │ | ||
| └─────┬──────┘ │ dequantize │ │ | ||
| ┌────▼─────┐ └─────┬──────┘ ┘ | ||
| │ quantize │ ┌─────▼──────┐ | ||
| └────┬─────┘ │ <aux_node> │ | ||
| ▼ └─────┬──────┘ | ||
| ┌────▼─────┐ | ||
| │ quantize │ | ||
| └────┬─────┘ | ||
| ▼ | ||
| """ | ||
|
|
||
| allowed_auxiliary_nodes = [exir_ops.edge.aten.view_copy.default] | ||
|
|
||
| # List of approved nodes to which the `<aux_node>` can be connected in order for the pass to make the modification. | ||
| allowed_main_cluster_nodes = [ | ||
| exir_ops.edge.aten.addmm.default, | ||
| exir_ops.edge.aten.mm.default, | ||
| ] | ||
|
|
||
| def run(self, graph_module: torch.fx.GraphModule) -> PassResult: | ||
|
|
||
| for aux_node in graph_module.graph.nodes: | ||
| if ( | ||
| aux_node.op != "call_function" | ||
| or aux_node.target not in self.allowed_auxiliary_nodes | ||
| ): | ||
| continue | ||
|
|
||
| main_cluster_node = aux_node.args[0] | ||
| if ( | ||
| main_cluster_node.op != "call_function" | ||
| or main_cluster_node.target not in self.allowed_main_cluster_nodes | ||
| ): | ||
| # Unsupported `main_cluster_node`. | ||
| continue | ||
|
|
||
| users = list(aux_node.users.keys()) | ||
| if len(users) != 1: | ||
| # Not the intended use case. | ||
| continue | ||
|
|
||
| quantize_node = users[0] | ||
| if not _is_quantize(quantize_node): | ||
| # Not the intended use case. | ||
| continue | ||
|
|
||
| # Make sure the nodes are part of the same QDQ cluster. | ||
| cluster = QDQClusterRecognizer().get_qdq_cluster(main_cluster_node) | ||
| if any( | ||
| node_ not in cluster | ||
| for node_ in [quantize_node, aux_node, main_cluster_node] | ||
| ): | ||
| continue | ||
|
|
||
| # ---- The nodes follow the pattern described in the header. ---- | ||
|
|
||
| q_params = quantize_node.args[1:] | ||
| insert_qdq_pair_after_node(graph_module.graph, main_cluster_node, q_params) | ||
|
|
||
| # The graph has now changed, and we shouldn't keep iterating through it. Return the new graph and the parent | ||
| # class will call this pass again. | ||
| return PassResult(graph_module, True) | ||
|
|
||
| # Nothing was changed. | ||
| return PassResult(graph_module, False) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| # Copyright 2025 NXP | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import logging | ||
| from abc import abstractmethod | ||
|
|
||
| import torch | ||
|
|
||
| from executorch.exir.pass_base import ExportPass | ||
| from torch.fx.passes.infra.pass_base import PassResult | ||
|
|
||
|
|
||
| class NeutronEdgePass(ExportPass): | ||
| """Abstract parent class for pre-processing passes on the edge dialect level.""" | ||
|
|
||
| def call(self, graph_module: torch.fx.GraphModule) -> PassResult: | ||
| """Call `self.run()` as long as changes are being made. After a pass modifies the graph, it cannot keep on | ||
| iterating through its nodes, and must return. This method allows the pass to go through the whole model. | ||
| """ | ||
|
|
||
| # Every pass will return once it makes a change to the graph, to avoid traversing and modifying a graph at the | ||
| # same time. Therefore, it must be called multiple times (at most `iteration_limit` times). | ||
| iteration_limit = len(graph_module.graph.nodes) | ||
| modified = False | ||
| for _ in range(iteration_limit): | ||
| res = self.run(graph_module) | ||
| if res.modified: | ||
| modified = True | ||
| graph_module = res.graph_module | ||
|
|
||
| else: | ||
| # No more changes have been made. | ||
| graph_module = self.recompile_module(graph_module) | ||
| return PassResult(graph_module, modified) | ||
|
|
||
| # Iteration limit was reached. | ||
| logging.warning( | ||
| f"The NeutronEdgePass `{self.__class__.__name__}` reached the iteration limit." | ||
| ) | ||
| graph_module = self.recompile_module(graph_module) | ||
| return PassResult(graph_module, modified) | ||
|
|
||
| @abstractmethod | ||
| def run(self, graph_module: torch.fx.GraphModule) -> PassResult: | ||
| """Child classes should implement their graph modification here.""" | ||
| pass | ||
|
|
||
| def recompile_module( | ||
| self, graph_module: torch.fx.GraphModule | ||
| ) -> torch.fx.GraphModule: | ||
| """Recompile the graph and re-trace the metadata. This should ensure that the datatypes and shapes are correct.""" | ||
| graph_module.recompile() | ||
| return super().call(graph_module).graph_module |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,89 @@ | ||
| # Copyright 2025 NXP | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import copy | ||
|
|
||
| from executorch.backends.nxp.edge_passes.move_auxiliary_operator_into_separate_qdq_cluster_pass import ( | ||
| MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass, | ||
| MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass, | ||
| ) | ||
| from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass | ||
| from executorch.exir import EdgeProgramManager | ||
| from executorch.exir.program._program import ( | ||
| _get_updated_graph_signature, | ||
| _get_updated_range_constraints, | ||
| ) | ||
|
|
||
| from torch import nn | ||
| from torch.export import ExportedProgram | ||
| from torch.fx.passes.infra.pass_base import PassResult | ||
| from torch.fx.passes.infra.pass_manager import PassManager | ||
|
|
||
|
|
||
| class NeutronEdgePassManager(PassManager): | ||
|
|
||
| def __init__(self, passes: list[NeutronEdgePass] = None): | ||
| passes: list[NeutronEdgePass] = passes or [ | ||
| MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(), | ||
| MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(), | ||
| ] | ||
|
|
||
| super().__init__( | ||
| passes, | ||
| steps=10, # Empirical value. At most 10 cycles of passes will be run. | ||
| ) | ||
|
|
||
| def _transform_graph_module(self, module: nn.Module) -> PassResult: | ||
| """Apply the passes to a single graph module.""" | ||
| pass_result: PassResult = super().__call__(module) | ||
|
|
||
| graph_module = pass_result.graph_module | ||
| graph_module.graph.eliminate_dead_code() | ||
| graph_module.recompile() | ||
|
|
||
| return pass_result | ||
|
|
||
| def __call__(self, epm: EdgeProgramManager) -> EdgeProgramManager: | ||
| """Apply the passes to all graph modules in the edge program.""" | ||
| new_programs: dict[str, ExportedProgram] = {} | ||
|
|
||
| for name, program in epm._edge_programs.items(): | ||
| pass_result = self._transform_graph_module(program.graph_module) | ||
|
|
||
| if pass_result.modified: | ||
| # Create a new exported program. | ||
| new_program = ExportedProgram( | ||
| root=pass_result.graph_module, | ||
| graph=pass_result.graph_module.graph, | ||
| graph_signature=_get_updated_graph_signature( | ||
| program.graph_signature, pass_result.graph_module | ||
| ), | ||
| state_dict=program.state_dict, | ||
| range_constraints=_get_updated_range_constraints( | ||
| pass_result.graph_module | ||
| ), | ||
| module_call_graph=copy.deepcopy(program._module_call_graph), | ||
| example_inputs=program.example_inputs, | ||
| constants=program.constants, | ||
| verifiers=[program.verifier], | ||
| ) | ||
| new_program.graph_module.meta.update(program.graph_module.meta) | ||
| new_program.graph_module.meta.update(pass_result.graph_module.meta) | ||
|
|
||
| else: | ||
| # Keep the old exported program. | ||
| new_program = program | ||
|
|
||
| new_programs[name] = new_program | ||
|
|
||
| if len(new_programs) == 0: | ||
| # No passes were run, return the old EdgeProgramManager. | ||
| return epm | ||
|
|
||
| else: | ||
| # Return a new EdgeProgramManager with the updated programs. | ||
| return EdgeProgramManager( | ||
| new_programs, copy.deepcopy(epm._config_methods), epm.compile_config | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit