66
77import itertools
88
9+ from typing import List
10+
911import torch
1012from executorch .backends .arm ._passes .arm_pass_utils import create_node
11- from executorch .backends .arm .tosa_quant_utils import dq_op , q_op
13+
14+ from executorch .backends .arm .tosa_quant_utils import dq_op , q_op , QuantArgs
1215from executorch .exir .dialects ._ops import ops as exir_ops
1316from executorch .exir .pass_base import ExportPass , PassResult
1417from torch .fx import GraphModule
@@ -24,6 +27,22 @@ class AnnotateDecomposedMatmulPass(ExportPass):
2427 matmul-op (can be mm or bmm).
2528 """
2629
30+ def _match_partition_to_node (
31+ self , node : torch .fx .Node , partitioned_inputs : List [torch .fx .Node ]
32+ ) -> torch .fx .Node :
33+ """
34+ The partition.input_nodes order is not guaranteed. Compare these
35+ with the matmul node inputs coming in and return the nodes
36+ in the correct order.
37+ """
38+ if not node or node in partitioned_inputs or node .op == "placeholder" :
39+ return node
40+ else :
41+ return self ._match_partition_to_node (
42+ node .all_input_nodes [0 ], partitioned_inputs
43+ )
44+ raise RuntimeError (f"Cannot find an input node which matches, { node } ." )
45+
2746 def call (self , graph_module : GraphModule ) -> PassResult :
2847 matmul_partitions = get_source_partitions (
2948 graph_module .graph ,
@@ -45,28 +64,36 @@ def call(self, graph_module: GraphModule) -> PassResult:
4564 matmul_node = [
4665 node for node in partition .nodes if node .target in matmul_targets
4766 ][0 ]
67+
4868 if quantized_input :
4969 matmul_args = matmul_node .all_input_nodes
50- for i in range (len (matmul_args )):
51- input_node = partition .input_nodes [i ]
52- matmul_input_node = matmul_args [i ]
70+ for node in matmul_args :
71+ input_node = self ._match_partition_to_node (
72+ node , partition .input_nodes
73+ )
74+
5375 # Remove partition input dq-node
5476 input_node .replace_all_uses_with (input_node .all_input_nodes [0 ])
5577 graph_module .graph .erase_node (input_node )
56- input_node_qargs = input_node .args [1 :]
78+ input_node_qargs = QuantArgs .from_operator (
79+ input_node .target , input_node .args
80+ )
81+
5782 with graph_module .graph .inserting_before (matmul_node ):
5883 # Create new dq-node before matmul
5984 dq_node = create_node (
6085 graph = graph_module .graph ,
6186 op_target = dq_op ,
6287 )
63- dq_node .args = (matmul_input_node , * input_node_qargs )
64- matmul_node .replace_input_with (matmul_input_node , dq_node )
88+ dq_node .args = (node , * input_node_qargs )
89+ matmul_node .replace_input_with (node , dq_node )
6590
6691 partition_output = list (partition .output_nodes [0 ].users )[0 ]
6792 quantized_output = partition_output .target == q_op
6893 if quantized_output :
69- output_node_qargs = partition_output .args [1 :]
94+ output_node_qargs = QuantArgs .from_operator (
95+ partition_output .target , partition_output .args
96+ )
7097 with graph_module .graph .inserting_after (matmul_node ):
7198 # Create q-node after matmul
7299 q_node = create_node (
0 commit comments