11# Copyright 2024-2025 Arm Limited and/or its affiliates.
2- # All rights reserved.
32#
43# This source code is licensed under the BSD-style license found in the
54# LICENSE file in the root directory of this source tree.
65
76# pyre-unsafe
87
98import itertools
10-
9+ import operator
1110from typing import List
1211
1312import torch
2221
2322class AnnotateDecomposedMatmulPass (ExportPass ):
2423 """
25- torch.matmul can be decomposed in many ways, for instance:
24+ torch.matmul and its equivalent operator @ can be decomposed in many ways, for instance:
2625 dq -> matmul -> q can become
2726 dq -> repeat -> view -> bmm -> view -> dq which makes quantization folding
2827 difficult. This helper function finds all matmul partitions and annotates its
@@ -50,6 +49,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
5049 graph_module .graph ,
5150 [
5251 torch .matmul ,
52+ operator .matmul ,
5353 ],
5454 None ,
5555 )
@@ -70,17 +70,14 @@ def call(self, graph_module: GraphModule) -> PassResult:
7070 if quantized_input :
7171 matmul_args = matmul_node .all_input_nodes
7272 for node in matmul_args :
73+ # Find the dq-node connected to this mm/bmm arg
7374 input_node = self ._match_partition_to_node (
7475 node , partition .input_nodes
7576 )
76-
77- # Remove partition input dq-node
78- input_node .replace_all_uses_with (input_node .all_input_nodes [0 ])
79- graph_module .graph .erase_node (input_node )
8077 input_node_qargs = QuantArgs .from_operator (
8178 input_node .target , input_node .args
8279 )
83-
80+ # Insert new dq-node just before the mm/bmm with input_node's qparams
8481 with graph_module .graph .inserting_before (matmul_node ):
8582 # Create new dq-node before matmul
8683 dq_node = create_node (
@@ -90,6 +87,13 @@ def call(self, graph_module: GraphModule) -> PassResult:
9087 dq_node .args = (node , * input_node_qargs )
9188 matmul_node .replace_input_with (node , dq_node )
9289
90+ for partition_input in partition .input_nodes :
91+ # Remove partition input dq-node
92+ partition_input .replace_all_uses_with (
93+ partition_input .all_input_nodes [0 ]
94+ )
95+ graph_module .graph .erase_node (partition_input )
96+
9397 partition_output = list (partition .output_nodes [0 ].users )[0 ]
9498 quantized_output = partition_output .target == q_op
9599 if quantized_output :
0 commit comments