55# LICENSE file in the root directory of this source tree.
66
77import itertools
8- from typing import Any , Dict , List
98
109import torch
1110from executorch .backends .arm ._passes .arm_pass_utils import create_node
1211from executorch .backends .arm .tosa_quant_utils import dq_op , q_op
1312from executorch .exir .dialects ._ops import ops as exir_ops
1413from executorch .exir .pass_base import ExportPass , PassResult
1514from torch .fx import GraphModule
16- from torch .fx .passes .utils .source_matcher_utils import (
17- get_source_partitions ,
18- SourcePartition ,
19- )
15+ from torch .fx .passes .utils .source_matcher_utils import get_source_partitions
2016
2117
2218class AnnotateDecomposedMatmulPass (ExportPass ):
@@ -28,8 +24,8 @@ class AnnotateDecomposedMatmulPass(ExportPass):
2824 matmul-op (can be mm or bmm).
2925 """
3026
31- def call (self , graph_module : GraphModule ):
32- matmul_partitions : Dict [ Any , List [ SourcePartition ]] = get_source_partitions (
27+ def call (self , graph_module : GraphModule ) -> PassResult :
28+ matmul_partitions = get_source_partitions (
3329 graph_module .graph ,
3430 [
3531 torch .matmul ,
@@ -56,7 +52,7 @@ def call(self, graph_module: GraphModule):
5652 input_node = partition .input_nodes [i ]
5753 matmul_input_node = matmul_args [i ]
5854 # Remove partition input dq-node
59- input_node .replace_all_uses_with (input_node .args [0 ])
55+ input_node .replace_all_uses_with (input_node .all_input_nodes [0 ])
6056 graph_module .graph .erase_node (input_node )
6157 input_node_qargs = input_node .args [1 :]
6258 with graph_module .graph .inserting_before (matmul_node ):
@@ -81,7 +77,9 @@ def call(self, graph_module: GraphModule):
8177 matmul_node .replace_all_uses_with (q_node )
8278 q_node .args = (matmul_node , * output_node_qargs )
8379 # Remove partition output q-node
84- partition_output .replace_all_uses_with (partition_output .args [0 ])
80+ partition_output .replace_all_uses_with (
81+ partition_output .all_input_nodes [0 ]
82+ )
8583 graph_module .graph .erase_node (partition_output )
8684
8785 # retrace the graph to update the fake tensor types
0 commit comments