Commit 42923f3
Cortex-M backend: Add mul and linear tests
Minor included fixes:
- Make quantized_linear_fusion_pass an XNNPACK pass to initialize it with an exported program
- Add TO_EXECUTORCH as a valid stage after RUN_PASSES
- Add ramp_tensor function to simplify creating dummy data

Signed-off-by: Adrian Lundell <[email protected]>
Change-Id: Id13be6427390483aa1df1b76fc363ae4d0eae876
1 parent f24351a commit 42923f3

File tree: 6 files changed (+381 −14 lines)

backends/cortex_m/passes/quantized_linear_fusion_pass.py

Lines changed: 5 additions & 4 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -19,17 +20,18 @@
 )
 
 from executorch.backends.transforms.utils import create_mutable_buffer, get_param_tensor
+
+from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass
 from torch.fx import Node
 from torch.fx.passes.infra.pass_manager import PassResult
 
 logger = logging.getLogger("quantized_linear_fusion_pass")
 logger.setLevel(logging.INFO)
 
 
-class QuantizedLinearFusionPass(ExportPass):
+class QuantizedLinearFusionPass(XNNPACKPass):
     """
     Cortex-M backend pass that fuses quantized linear-like patterns.
     Fuses: dequantize -> [linear/addmm/fc_ops] -> quantize
@@ -44,8 +46,7 @@ class QuantizedLinearFusionPass(ExportPass):
     requires_exported_program = True
 
     def __init__(self, exported_program: ExportedProgram):
-        super().__init__()
-        self._exported_program = exported_program
+        super().__init__(exported_program)
         self.nodes_to_erase = []
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
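With this change the pass inherits from XNNPACKPass, so the exported program is handed to the base class instead of being stashed on a private attribute. A minimal usage sketch, assuming the base class keeps the program available to the pass while call() still consumes the graph module (the driver function below is illustrative, not part of this commit):

    from executorch.backends.cortex_m.passes.quantized_linear_fusion_pass import (
        QuantizedLinearFusionPass,
    )

    def run_fusion(exported_program):
        # The pass is now constructed with the ExportedProgram it will run on.
        fusion = QuantizedLinearFusionPass(exported_program)
        # call() rewrites the program's graph module and returns a PassResult.
        result = fusion.call(exported_program.graph_module)
        return result.graph_module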

backends/cortex_m/test/ops/test_add.py

Lines changed: 13 additions & 9 deletions
@@ -6,7 +6,11 @@
 
 import torch
 from executorch.backends.arm.test.common import parametrize
-from executorch.backends.cortex_m.test.tester import CortexMTester, McuTestCase
+from executorch.backends.cortex_m.test.tester import (
+    CortexMTester,
+    McuTestCase,
+    ramp_tensor,
+)
 from executorch.backends.test.suite.operators.test_add import Model, ModelAlpha
 
 
@@ -80,19 +84,19 @@ class CortexMAlphaAdd(ModelAlpha):
     ),
     "self_rank_2_pos": McuTestCase(
         CortexMSelfAdd(),
-        (torch.linspace(0, 1000, 10).reshape((10, 1)),),
+        (ramp_tensor(0, 1000, (10, 1)),),
     ),
     "self_rank_3_neg": McuTestCase(
         CortexMSelfAdd(),
-        (torch.linspace(-100, 0, 8).reshape((2, 2, 2)),),
+        (ramp_tensor(-100, 0, (2, 2, 2)),),
     ),
     "self_rank_4_small": McuTestCase(
         CortexMSelfAdd(),
-        (torch.linspace(-0.1, 0.1, 16).reshape(2, 2, 2, 2),),
+        (ramp_tensor(-0.1, 0.1, (2, 2, 2, 2)),),
     ),
     "self_rank_5": McuTestCase(
         CortexMSelfAdd(),
-        (torch.linspace(-5, 5, 32).reshape(2, 2, 2, 2, 2),),
+        (ramp_tensor(-5, 5, (2, 2, 2, 2, 2)),),
     ),
     "scalar_scalar": McuTestCase(
         CortexMScalarAdd(),
@@ -117,15 +121,15 @@ class CortexMAlphaAdd(ModelAlpha):
     "broadcast_3": McuTestCase(
         CortexMTensorAdd(),
         (
-            torch.linspace(-2, 2, 4).reshape(2, 1, 2, 1),
-            torch.linspace(-5, 5, 4).reshape(1, 2, 1, 2),
+            ramp_tensor(-2, 2, (2, 1, 2, 1)),
+            ramp_tensor(-5, 5, (1, 2, 1, 2)),
         ),
     ),
     "alpha": McuTestCase(
         CortexMAlphaAdd(0.5),
         (
-            torch.linspace(-10, 10, 20).reshape(4, 5),
-            torch.linspace(-20, 20, 20).reshape(4, 5),
+            ramp_tensor(-10, 10, (4, 5)),
+            ramp_tensor(-20, 20, (4, 5)),
         ),
     ),
 }
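The ramp_tensor helper added to the tester is not shown in this diff. Judging from the call sites it replaces, a minimal sketch could look like this (an assumed implementation, inferred from the substitutions above):

    import math
    import torch

    def ramp_tensor(start: float, end: float, shape: tuple) -> torch.Tensor:
        # Evenly spaced values from start to end, reshaped to the target shape;
        # mirrors the torch.linspace(...).reshape(...) calls it replaced.
        return torch.linspace(start, end, math.prod(shape)).reshape(shape)

For example, ramp_tensor(-10, 10, (4, 5)) should reproduce the removed torch.linspace(-10, 10, 20).reshape(4, 5).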
Lines changed: 211 additions & 0 deletions (new file)
@@ -0,0 +1,211 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch
from executorch.backends.arm.test.common import parametrize
from executorch.backends.cortex_m.test.tester import (
    CortexMTester,
    McuTestCase,
    ramp_tensor,
)

class CortexMMm(torch.nn.Module):
    def forward(self, x, y):
        return torch.mm(x, y)

    ops_before_transforms = {
        "executorch_exir_dialects_edge__ops_aten_mm_default": 1,
        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3,
    }

    ops_after_transforms = {
        "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1,
        "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1,
        "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
    }
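    # The two dictionaries above give the expected edge-op counts before and
    # after the Cortex-M passes run: the dequantize -> mm -> quantize chain
    # should collapse into a single fused cortex_m.quantized_linear node.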

class CortexMBmm(torch.nn.Module):
    def forward(self, x, y):
        return torch.bmm(x, y)

    ops_before_transforms = {
        "executorch_exir_dialects_edge__ops_aten_bmm_default": 1,
        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3,
    }

    ops_after_transforms = {
        "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1,
        "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1,
        "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
    }


class CortexMAddmm(torch.nn.Module):
    def forward(self, x, y, z, alpha=None, beta=None):
        return torch.addmm(beta, x, alpha, y, z)
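        # Note: this appears to exercise the deprecated positional-scalar
        # overload addmm(beta, input, alpha, mat1, mat2); alpha and beta come
        # from the trailing entries of the test case's example_inputs.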

    ops_before_transforms = {
        "executorch_exir_dialects_edge__ops_aten_addmm_default": 1,
        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3,
    }

    ops_after_transforms = {
        "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1,
        "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1,
        "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
    }


class CortexMAt(CortexMMm):
    def forward(self, x, y):
        return x @ y


class CortexMMatmul(CortexMMm):
    def forward(self, x, y):
        return torch.matmul(x, y)


class CortexMLinear(CortexMMatmul):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.linear = torch.nn.Linear(*args, bias=False)

    def forward(self, x):
        return self.linear(x)


class CortexMLinearBias(CortexMAddmm):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.linear = torch.nn.Linear(*args, bias=True)

    def forward(self, x):
        return self.linear(x)

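# CortexMAt, CortexMMatmul, CortexMLinear and CortexMLinearBias inherit their
# ops_before_transforms/ops_after_transforms dictionaries from CortexMMm and
# CortexMAddmm, respectively.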
test_cases = {
    "mm": McuTestCase(
        model=CortexMMm(),
        example_inputs=(
            ramp_tensor(0, 10, (1, 16)),
            ramp_tensor(0, 10, (16, 16)),
        ),
    ),
    "bmm": McuTestCase(
        model=CortexMBmm(),
        example_inputs=(
            ramp_tensor(0, 10, (1, 16, 16)),
            ramp_tensor(0, 10, (1, 16, 16)),
        ),
    ),
    "addmm": McuTestCase(
        model=CortexMAddmm(),
        example_inputs=(
            ramp_tensor(0, 10, (1, 16)),
            ramp_tensor(0, 10, (16, 16)),
            ramp_tensor(0, 10, (16, 16)),
            2,
            4,
        ),
    ),
    "addmm_scalars": McuTestCase(
        model=CortexMAddmm(),
        example_inputs=(
            ramp_tensor(0, 10, (1, 16)),
            ramp_tensor(0, 10, (16, 16)),
            ramp_tensor(0, 10, (16, 16)),
        ),
    ),
    "@-operator": McuTestCase(
        model=CortexMAt(),
        example_inputs=(
            ramp_tensor(0, 10, (1, 16)),
            ramp_tensor(0, 10, (16, 16)),
        ),
    ),
    "matmul": McuTestCase(
        model=CortexMMatmul(),
        example_inputs=(
            ramp_tensor(0, 10, (1, 16)),
            ramp_tensor(0, 10, (16, 16)),
        ),
    ),
    "linear_rank1": McuTestCase(
        model=CortexMLinear(2, 3),
        example_inputs=(ramp_tensor(-1, 1, (2,)),),
    ),
    "linear_rank2_pos": McuTestCase(
        model=CortexMLinear(8, 3),
        example_inputs=(ramp_tensor(0, 10, (2, 8)),),
    ),
    "linear_rank3_neg": McuTestCase(
        model=CortexMLinear(5, 3),
        example_inputs=(ramp_tensor(-40, 0, (4, 2, 5)),),
    ),
    "linear_rank4": McuTestCase(
        model=CortexMLinear(16, 32),
        example_inputs=(ramp_tensor(-100, 100, (2, 1, 2, 16)),),
    ),
    "linear_rank5": McuTestCase(
        model=CortexMLinear(4, 3),
        example_inputs=(ramp_tensor(-2, 2, (5, 2, 1, 2, 4)),),
    ),
    "linear_bias": McuTestCase(
        model=CortexMLinearBias(61, 37),
        example_inputs=(ramp_tensor(0, 10, (8, 61)),),
    ),
}

dialect_xfails = {
    "mm": ("torch.mm ops are currently not quantized", RuntimeError),
    "bmm": ("torch.bmm ops are currently not quantized", RuntimeError),
    "addmm": ("torch.addmm ops are currently not quantized", RuntimeError),
    "addmm_scalars": ("torch.addmm ops are currently not quantized", RuntimeError),
    "matmul": ("torch.matmul ops are currently not quantized", RuntimeError),
    "@-operator": ("@ ops are currently not quantized", RuntimeError),
    "linear_rank1": ("Only rank 2 linear ops are fused currently", RuntimeError),
    "linear_rank2_pos": ("name 'int32' is not defined", NameError),
    "linear_rank3_neg": ("Only rank 2 linear ops are fused currently", RuntimeError),
    "linear_rank4": ("Only rank 2 linear ops are fused currently", RuntimeError),
    "linear_rank5": ("Only rank 2 linear ops are fused currently", RuntimeError),
    "linear_bias": ("name 'int32' is not defined", NameError),
}


@parametrize("test_case", test_cases, dialect_xfails)
def test_dialect_linear(test_case):
    tester = CortexMTester(test_case.model, test_case.example_inputs)
    tester.test_dialect(
        test_case.model.ops_before_transforms, test_case.model.ops_after_transforms
    )


implementation_xfails = {
    "mm": ("torch.mm ops are currently not quantized", RuntimeError),
    "bmm": ("torch.bmm ops are currently not quantized", RuntimeError),
    "addmm": ("torch.addmm ops are currently not quantized", RuntimeError),
    "addmm_scalars": ("torch.addmm ops are currently not quantized", RuntimeError),
    "matmul": ("torch.matmul ops are currently not quantized", RuntimeError),
    "@-operator": ("@ ops are currently not quantized", RuntimeError),
    "linear_rank1": ("Only rank 2 linear ops are fused currently", RuntimeError),
    "linear_rank2_pos": ("Output 0 does not match reference output.", AssertionError),
    "linear_rank3_neg": ("Only rank 2 linear ops are fused currently", RuntimeError),
    "linear_rank4": ("Only rank 2 linear ops are fused currently", RuntimeError),
    "linear_rank5": ("Only rank 2 linear ops are fused currently", RuntimeError),
    "linear_bias": ("Output 0 does not match reference output.", AssertionError),
}


@parametrize("test_case", test_cases, implementation_xfails)
def test_implementation_linear(test_case):
    tester = CortexMTester(test_case.model, test_case.example_inputs)
    tester.test_implementation()
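The xfail tables map a case name to an (expected failure reason, exception type) pair. A sketch of how such a mapping is typically turned into pytest xfail marks (illustrative only; the real parametrize helper lives in executorch.backends.arm.test.common):

    import pytest

    def parametrize_sketch(arg_name, test_data, xfails=None):
        # Build one pytest.param per named case, attaching an xfail mark
        # where the xfails table names the case.
        params = []
        for name, case in test_data.items():
            marks = []
            if xfails and name in xfails:
                reason, error = xfails[name]
                marks.append(pytest.mark.xfail(reason=reason, raises=error))
            params.append(pytest.param(case, id=name, marks=marks))
        return pytest.mark.parametrize(arg_name, params)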
