From 3e6515c20f91d1557b25b1c854be7cbb01c0aa1f Mon Sep 17 00:00:00 2001 From: Martin Pavella Date: Mon, 1 Sep 2025 11:39:38 +0200 Subject: [PATCH] NXP backend: Update the pass `move_auxiliary_operator_into_separate_qdq_cluster` to allow only specific pairs of main and auxiliary operators. --- ...operator_into_separate_qdq_cluster_pass.py | 76 +++++++++++-------- 1 file changed, 44 insertions(+), 32 deletions(-) diff --git a/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py b/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py index 7eba60cf2ec..d88684b86f0 100644 --- a/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py +++ b/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py @@ -11,6 +11,11 @@ from torch.fx import Node from torch.fx.passes.infra.pass_base import PassResult +# Operator aliases for better readability. +AddMM = exir_ops.edge.aten.addmm.default +ViewCopy = exir_ops.edge.aten.view_copy.default +MM = exir_ops.edge.aten.mm.default + def insert_qdq_pair_after_node( graph: torch.fx.Graph, anchor: torch.fx.Node, q_params: tuple @@ -41,7 +46,8 @@ def insert_qdq_pair_after_node( def _is_dequantize(node_: Node) -> bool: return ( - node_.op == "call_function" + hasattr(node_, "op") + and node_.op == "call_function" and node_.target == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default ) @@ -49,7 +55,8 @@ def _is_dequantize(node_: Node) -> bool: def _is_quantize(node_: Node) -> bool: return ( - node_.op == "call_function" + hasattr(node_, "op") + and node_.op == "call_function" and node_.target == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default ) @@ -82,20 +89,19 @@ class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass): ▼ """ - allowed_auxiliary_nodes = [exir_ops.edge.aten.view_copy.default] - - # List of approved nodes to which the can be connected in order for the pass to make the modification. - allowed_main_cluster_nodes = [ - exir_ops.edge.aten.addmm.default, - exir_ops.edge.aten.mm.default, - ] + # Dictionary mapping main cluster nodes to auxiliary nodes, for which this optimization will be applied. + main_cluster_node_to_auxiliary_nodes = { + AddMM: [ + ViewCopy, + ], + MM: [ + ViewCopy, + ], + } def run(self, graph_module: torch.fx.GraphModule) -> PassResult: for aux_node in graph_module.graph.nodes: - if ( - aux_node.op != "call_function" - or aux_node.target not in self.allowed_auxiliary_nodes - ): + if aux_node.op != "call_function": continue dequantize_node = aux_node.args[0] @@ -109,11 +115,13 @@ def run(self, graph_module: torch.fx.GraphModule) -> PassResult: continue main_cluster_node = users[0] - if ( - main_cluster_node.op != "call_function" - or main_cluster_node.target not in self.allowed_main_cluster_nodes + if main_cluster_node.op != "call_function": + continue + + if aux_node.target not in self.main_cluster_node_to_auxiliary_nodes.get( + main_cluster_node.target, [] ): - # Unsupported `main_cluster_node`. + # Unsupported main cluster node and auxiliary node pair. continue # Make sure the nodes are part of the same QDQ cluster. @@ -163,29 +171,33 @@ class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass): ▼ """ - allowed_auxiliary_nodes = [exir_ops.edge.aten.view_copy.default] - - # List of approved nodes to which the `` can be connected in order for the pass to make the modification. - allowed_main_cluster_nodes = [ - exir_ops.edge.aten.addmm.default, - exir_ops.edge.aten.mm.default, - ] + # Dictionary mapping main cluster nodes to auxiliary nodes, for which this optimization will be applied. + main_cluster_node_to_auxiliary_nodes = { + AddMM: [ + ViewCopy, + ], + MM: [ + ViewCopy, + ], + } def run(self, graph_module: torch.fx.GraphModule) -> PassResult: for aux_node in graph_module.graph.nodes: - if ( - aux_node.op != "call_function" - or aux_node.target not in self.allowed_auxiliary_nodes - ): + if aux_node.op != "call_function": continue main_cluster_node = aux_node.args[0] - if ( - main_cluster_node.op != "call_function" - or main_cluster_node.target not in self.allowed_main_cluster_nodes + if not ( + hasattr(main_cluster_node, "op") + and main_cluster_node.op == "call_function" + ): + continue + + if aux_node.target not in self.main_cluster_node_to_auxiliary_nodes.get( + main_cluster_node.target, [] ): - # Unsupported `main_cluster_node`. + # Unsupported main cluster node and auxiliary node pair. continue users = list(aux_node.users.keys())