1111from torch .fx import Node
1212from torch .fx .passes .infra .pass_base import PassResult
1313
14+ # Operator aliases for better readability.
15+ AddMM = exir_ops .edge .aten .addmm .default
16+ ViewCopy = exir_ops .edge .aten .view_copy .default
17+ MM = exir_ops .edge .aten .mm .default
18+
1419
1520def insert_qdq_pair_after_node (
1621 graph : torch .fx .Graph , anchor : torch .fx .Node , q_params : tuple
@@ -41,15 +46,17 @@ def insert_qdq_pair_after_node(
4146
4247def _is_dequantize (node_ : Node ) -> bool :
4348 return (
44- node_ .op == "call_function"
49+ hasattr (node_ , "op" )
50+ and node_ .op == "call_function"
4551 and node_ .target
4652 == exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default
4753 )
4854
4955
5056def _is_quantize (node_ : Node ) -> bool :
5157 return (
52- node_ .op == "call_function"
58+ hasattr (node_ , "op" )
59+ and node_ .op == "call_function"
5360 and node_ .target
5461 == exir_ops .edge .quantized_decomposed .quantize_per_tensor .default
5562 )
@@ -82,20 +89,19 @@ class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
8289 ▼
8390 """
8491
85- allowed_auxiliary_nodes = [exir_ops .edge .aten .view_copy .default ]
86-
87- # List of approved nodes to which the <aux_node> can be connected in order for the pass to make the modification.
88- allowed_main_cluster_nodes = [
89- exir_ops .edge .aten .addmm .default ,
90- exir_ops .edge .aten .mm .default ,
91- ]
92+ # Dictionary mapping main cluster nodes to auxiliary nodes, for which this optimization will be applied.
93+ main_cluster_node_to_auxiliary_nodes = {
94+ AddMM : [
95+ ViewCopy ,
96+ ],
97+ MM : [
98+ ViewCopy ,
99+ ],
100+ }
92101
93102 def run (self , graph_module : torch .fx .GraphModule ) -> PassResult :
94103 for aux_node in graph_module .graph .nodes :
95- if (
96- aux_node .op != "call_function"
97- or aux_node .target not in self .allowed_auxiliary_nodes
98- ):
104+ if aux_node .op != "call_function" :
99105 continue
100106
101107 dequantize_node = aux_node .args [0 ]
@@ -109,11 +115,13 @@ def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
109115 continue
110116
111117 main_cluster_node = users [0 ]
112- if (
113- main_cluster_node .op != "call_function"
114- or main_cluster_node .target not in self .allowed_main_cluster_nodes
118+ if main_cluster_node .op != "call_function" :
119+ continue
120+
121+ if aux_node .target not in self .main_cluster_node_to_auxiliary_nodes .get (
122+ main_cluster_node .target , []
115123 ):
116- # Unsupported `main_cluster_node` .
124+ # Unsupported main cluster node and auxiliary node pair .
117125 continue
118126
119127 # Make sure the nodes are part of the same QDQ cluster.
@@ -163,29 +171,33 @@ class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
163171 ▼
164172 """
165173
166- allowed_auxiliary_nodes = [exir_ops .edge .aten .view_copy .default ]
167-
168- # List of approved nodes to which the `<aux_node>` can be connected in order for the pass to make the modification.
169- allowed_main_cluster_nodes = [
170- exir_ops .edge .aten .addmm .default ,
171- exir_ops .edge .aten .mm .default ,
172- ]
174+ # Dictionary mapping main cluster nodes to auxiliary nodes, for which this optimization will be applied.
175+ main_cluster_node_to_auxiliary_nodes = {
176+ AddMM : [
177+ ViewCopy ,
178+ ],
179+ MM : [
180+ ViewCopy ,
181+ ],
182+ }
173183
174184 def run (self , graph_module : torch .fx .GraphModule ) -> PassResult :
175185
176186 for aux_node in graph_module .graph .nodes :
177- if (
178- aux_node .op != "call_function"
179- or aux_node .target not in self .allowed_auxiliary_nodes
180- ):
187+ if aux_node .op != "call_function" :
181188 continue
182189
183190 main_cluster_node = aux_node .args [0 ]
184- if (
185- main_cluster_node .op != "call_function"
186- or main_cluster_node .target not in self .allowed_main_cluster_nodes
191+ if not (
192+ hasattr (main_cluster_node , "op" )
193+ and main_cluster_node .op == "call_function"
194+ ):
195+ continue
196+
197+ if aux_node .target not in self .main_cluster_node_to_auxiliary_nodes .get (
198+ main_cluster_node .target , []
187199 ):
188- # Unsupported `main_cluster_node` .
200+ # Unsupported main cluster node and auxiliary node pair .
189201 continue
190202
191203 users = list (aux_node .users .keys ())
0 commit comments