Skip to content

Commit 9a7fb42

Browse files
Arm backend: Fix torch.matmul() failures for 2D tensor inputs (#14624)
- ConvertMmToBmmPass converts an MM node to BMM nodes, turns input and output tensors from rank-2 to rank-3 via unsqueeze/squeeze, and inserts q-dq before and after BMM node when necessary. - After ConvertMmToBmmPass: ``` x -> q -> dq -> unsqueeze -> q_2 -> dq_2 -> \ bmm -> q_4 -> dq_4 / y -> q_1 -> dq_1 -> unsqueeze -> q_3 -> dq_3 -> ``` - Therefore, if the original matmul was 2D, the bmm already has DQ nodes on its inputs and Q node on its output. If AnnotateDecomposedMatmulPass (#10654) is still applied in this case, it produces illegal sequences such as: x -> q -> unsqueeze -> q_2 (invalid) - Fix by checking whether the BMM is already surrounded by DQ nodes on its inputs and Q nodes on its output. Change-Id: I9949d59b0b4a96fa34a88b0734014567ea6f24cc cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 Signed-off-by: Yufeng Shi <[email protected]> Co-authored-by: Oscar Andersson <[email protected]>
1 parent 75ebd05 commit 9a7fb42

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ def call(self, graph_module: GraphModule) -> PassResult:
7373
node for node in partition.nodes if node.target in matmul_targets
7474
][0]
7575

76-
if quantized_input:
76+
if quantized_input and not all(
77+
input_node.target in DQ_OPS
78+
for input_node in matmul_node.all_input_nodes
79+
):
7780
matmul_args = matmul_node.all_input_nodes
7881
for node in matmul_args:
7982
# Find the dq-node connected to this mm/bmm arg
@@ -99,7 +102,9 @@ def call(self, graph_module: GraphModule) -> PassResult:
99102

100103
partition_output = list(partition.output_nodes[0].users)[0]
101104
quantized_output = partition_output.target in Q_OPS
102-
if quantized_output:
105+
if quantized_output and not all(
106+
user.target in Q_OPS for user in matmul_node.users
107+
):
103108
with graph_module.graph.inserting_after(matmul_node):
104109
# Create q-node after matmul
105110
q_node = create_node(

backends/arm/test/ops/test_matmul.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
class MatMul(torch.nn.Module):
2424
test_data_generators = {
25+
"rand_rand_2d": lambda: (torch.rand(5, 5), torch.rand(5, 2)),
2526
"rand_rand_3d": lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)),
2627
"rand_rand_4d": lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)),
2728
}
@@ -32,6 +33,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
3233

3334
class MatMulSingleInput(torch.nn.Module):
3435
test_data_generators = {
36+
"rand_2d": lambda: (torch.rand(5, 5),),
3537
"rand_3d": lambda: (torch.rand(2, 5, 5),),
3638
"rand_4d": lambda: (torch.rand(1, 2, 5, 5),),
3739
}
@@ -42,6 +44,11 @@ def forward(self, x: torch.Tensor):
4244

4345
class MatMulCombo(torch.nn.Module):
4446
test_data_generators = {
47+
"rand_rand_rand_2d": lambda: (
48+
torch.rand(5, 5),
49+
torch.rand(5, 2),
50+
torch.rand(2, 5),
51+
),
4552
"rand_rand_rand_3d": lambda: (
4653
torch.rand(2, 5, 5),
4754
torch.rand(2, 5, 2),

0 commit comments

Comments
 (0)