33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
55
6- # pyre-unsafe
76
87import itertools
98import operator
@@ -52,7 +51,7 @@ def _match_partition_to_node(
5251 raise RuntimeError (f"Cannot find an input node which matches, { node } ." )
5352
5453 def call (self , graph_module : GraphModule ) -> PassResult :
55- matmul_partitions = get_source_partitions (
54+ matmul_partitions_map = get_source_partitions (
5655 graph_module .graph ,
5756 [
5857 torch .matmul ,
@@ -61,7 +60,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
6160 None ,
6261 )
6362 matmul_partitions = list (
64- itertools .chain .from_iterable (matmul_partitions .values ())
63+ itertools .chain .from_iterable (matmul_partitions_map .values ())
6564 )
6665 matmul_targets = {
6766 exir_ops .edge .aten .bmm .default ,
@@ -89,7 +88,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
8988 # Create new dq-node before matmul
9089 dq_node = create_node (
9190 graph = graph_module .graph ,
92- op_target = cast (EdgeOpOverload , input_node .target ), # type: ignore[arg-type]
91+ op_target = cast (EdgeOpOverload , input_node .target ),
9392 )
9493 dq_node .args = (node , * input_node .args [1 :])
9594 matmul_node .replace_input_with (node , dq_node )
@@ -110,7 +109,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
110109 # Create q-node after matmul
111110 q_node = create_node (
112111 graph = graph_module .graph ,
113- op_target = cast (EdgeOpOverload , partition_output .target ), # type: ignore[arg-type]
112+ op_target = cast (EdgeOpOverload , partition_output .target ),
114113 )
115114 matmul_node .replace_all_uses_with (q_node )
116115 q_node .args = (matmul_node , * partition_output .args [1 :])
0 commit comments