Skip to content

Commit 267f024

Browse files
authored
Arm backend: Ensure input nodes are in correct order for bmm/matmul (#8095)
- Ensures inputs come are processed in the correct order - Remove flaky test decorator for test_bmm and test_mm
1 parent 493e290 commit 267f024

File tree

3 files changed

+35
-16
lines changed

3 files changed

+35
-16
lines changed

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@
66

77
import itertools
88

9+
from typing import List
10+
911
import torch
1012
from executorch.backends.arm._passes.arm_pass_utils import create_node
11-
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
13+
14+
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, QuantArgs
1215
from executorch.exir.dialects._ops import ops as exir_ops
1316
from executorch.exir.pass_base import ExportPass, PassResult
1417
from torch.fx import GraphModule
@@ -24,6 +27,22 @@ class AnnotateDecomposedMatmulPass(ExportPass):
2427
matmul-op (can be mm or bmm).
2528
"""
2629

30+
def _match_partition_to_node(
31+
self, node: torch.fx.Node, partitioned_inputs: List[torch.fx.Node]
32+
) -> torch.fx.Node:
33+
"""
34+
The partition.input_nodes order is not guaranteed. Compare these
35+
with the matmul node inputs coming in and return the nodes
36+
in the correct order.
37+
"""
38+
if not node or node in partitioned_inputs or node.op == "placeholder":
39+
return node
40+
else:
41+
return self._match_partition_to_node(
42+
node.all_input_nodes[0], partitioned_inputs
43+
)
44+
raise RuntimeError(f"Cannot find an input node which matches, {node}.")
45+
2746
def call(self, graph_module: GraphModule) -> PassResult:
2847
matmul_partitions = get_source_partitions(
2948
graph_module.graph,
@@ -45,28 +64,36 @@ def call(self, graph_module: GraphModule) -> PassResult:
4564
matmul_node = [
4665
node for node in partition.nodes if node.target in matmul_targets
4766
][0]
67+
4868
if quantized_input:
4969
matmul_args = matmul_node.all_input_nodes
50-
for i in range(len(matmul_args)):
51-
input_node = partition.input_nodes[i]
52-
matmul_input_node = matmul_args[i]
70+
for node in matmul_args:
71+
input_node = self._match_partition_to_node(
72+
node, partition.input_nodes
73+
)
74+
5375
# Remove partition input dq-node
5476
input_node.replace_all_uses_with(input_node.all_input_nodes[0])
5577
graph_module.graph.erase_node(input_node)
56-
input_node_qargs = input_node.args[1:]
78+
input_node_qargs = QuantArgs.from_operator(
79+
input_node.target, input_node.args
80+
)
81+
5782
with graph_module.graph.inserting_before(matmul_node):
5883
# Create new dq-node before matmul
5984
dq_node = create_node(
6085
graph=graph_module.graph,
6186
op_target=dq_op,
6287
)
63-
dq_node.args = (matmul_input_node, *input_node_qargs)
64-
matmul_node.replace_input_with(matmul_input_node, dq_node)
88+
dq_node.args = (node, *input_node_qargs)
89+
matmul_node.replace_input_with(node, dq_node)
6590

6691
partition_output = list(partition.output_nodes[0].users)[0]
6792
quantized_output = partition_output.target == q_op
6893
if quantized_output:
69-
output_node_qargs = partition_output.args[1:]
94+
output_node_qargs = QuantArgs.from_operator(
95+
partition_output.target, partition_output.args
96+
)
7097
with graph_module.graph.inserting_after(matmul_node):
7198
# Create q-node after matmul
7299
q_node = create_node(

backends/arm/test/ops/test_bmm.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,19 +134,16 @@ def test_matmul_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
134134
self._test_bmm_tosa_MI_pipeline(self.MatMul(), test_data)
135135

136136
@parameterized.expand(MatMul.test_data_generators)
137-
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
138137
def test_matmul_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
139138
test_data = test_data_generator()
140139
self._test_bmm_tosa_BI_pipeline(self.MatMul(), test_data)
141140

142141
@parameterized.expand(BMM.test_data_generators)
143-
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
144142
def test_bmm_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
145143
test_data = test_data_generator()
146144
self._test_bmm_tosa_BI_pipeline(self.BMM(), test_data)
147145

148146
@parameterized.expand(BMMSingleInput.test_data_generators)
149-
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
150147
def test_bmm_single_input_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
151148
test_data = test_data_generator()
152149
self._test_bmm_tosa_BI_pipeline(self.BMMSingleInput(), test_data)
@@ -162,7 +159,6 @@ def test_bmm_u55_BI_xfails(self, test_data_generator: Callable[[], Tuple]):
162159

163160
@parameterized.expand(BMM.test_data_generators)
164161
@pytest.mark.corstone_fvp
165-
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
166162
def test_bmm_u85_BI(self, test_data_generator: Callable[[], Tuple]):
167163
test_data = test_data_generator()
168164
self._test_bmm_ethosu_BI_pipeline(
@@ -183,7 +179,6 @@ def test_bmm_single_input_u55_BI_xfails(
183179

184180
@parameterized.expand(BMMSingleInput.test_data_generators)
185181
@pytest.mark.corstone_fvp
186-
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
187182
def test_bmm_single_input_u85_BI(self, test_data_generator: Callable[[], Tuple]):
188183
test_data = test_data_generator()
189184
self._test_bmm_ethosu_BI_pipeline(

backends/arm/test/ops/test_mm.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ def test_mm_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
126126
self._test_mm_tosa_BI_pipeline(self.MM(), test_data)
127127

128128
@parameterized.expand(MMSingleInput.test_data_generators)
129-
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
130129
def test_mm_single_input_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
131130
test_data = test_data_generator()
132131
self._test_mm_tosa_BI_pipeline(self.MMSingleInput(), test_data)
@@ -150,15 +149,13 @@ def test_mm_single_input_u55_BI(self, test_data_generator: Callable[[], Tuple]):
150149
)
151150

152151
@parameterized.expand(MM.test_data_generators)
153-
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
154152
def test_mm_u85_BI(self, test_data_generator: Callable[[], Tuple]):
155153
test_data = test_data_generator()
156154
self._test_mm_ethosu_BI_pipeline(
157155
common.get_u85_compile_spec(), self.MM(), test_data
158156
)
159157

160158
@parameterized.expand(MMSingleInput.test_data_generators)
161-
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
162159
def test_mm_single_input_u85_BI(self, test_data_generator: Callable[[], Tuple]):
163160
test_data = test_data_generator()
164161
self._test_mm_ethosu_BI_pipeline(

0 commit comments

Comments
 (0)