Skip to content

Commit 4c58010

Browse files
Cortex_m backend: Add MV3 and lstm test (#15393)
Adding as proof of concept, skipping until functionally supported. Note that LSTM is one op in CMSIS-NN, which is why it is added in the op folder. Signed-off-by: Adrian Lundell <[email protected]>
1 parent 4675292 commit 4c58010

File tree

4 files changed

+189
-2
lines changed

4 files changed

+189
-2
lines changed

backends/cortex_m/test/build_test_runner.sh

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,20 @@ et_root_dir=$(realpath "${script_dir}/../../..")
1414
build_executorch="${et_root_dir}/backends/arm/scripts/build_executorch.sh"
1515
${build_executorch}
1616

17-
# Build executor runner with all portable ops selected and semi hosting
17+
# Build executor runner with selected aten ops and semi hosting
1818
build_dir="${et_root_dir}/arm_test"
1919
build_executor_runner="${et_root_dir}/backends/arm/scripts/build_executor_runner.sh"
2020
build_root_test_dir="${et_root_dir}/arm_test/arm_semihosting_executor_runner_corstone-300"
2121

22-
${build_executor_runner} --pte=semihosting --target=ethos-u55-128 --output="${build_root_test_dir}"
22+
# aten ops selected for the executor runner build, as one comma-separated
# string. The backslash-newlines inside the double quotes are removed by the
# shell, so the entries must stay unindented to keep whitespace out of the list.
select_ops_list="\
aten::add.out,\
aten::clamp.out,\
aten::convolution.out,\
aten::div.out,\
aten::mean.out,\
aten::mul.out,\
aten::relu.out,\
aten::view_copy.out,\
dim_order_ops::_to_dim_order_copy.out"

# Quote the script path so a checkout directory containing spaces does not
# get word-split into several arguments.
"${build_executor_runner}" --pte=semihosting --target=ethos-u55-128 --output="${build_root_test_dir}" --select_ops_list="${select_ops_list}"
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import pytest
7+
import torch
8+
9+
from executorch.backends.cortex_m.test.tester import CortexMTester, McuTestCase
10+
from torchvision import models
11+
12+
13+
# TODO: Update as more ops are converted to CMSIS-NN ops.
# Edge-dialect op name -> expected value (presumably the number of
# occurrences in the MobileNetV3-small graph) before the transforms run.
# Passed as the first argument to CortexMTester.test_dialect in
# test_dialect_mv3.
ops_before_transforms: dict[str, int] = {
    "executorch_exir_dialects_edge__ops_aten_add_Tensor": 34,
    "executorch_exir_dialects_edge__ops_aten_addmm_default": 2,
    "executorch_exir_dialects_edge__ops_aten_clamp_default": 56,
    "executorch_exir_dialects_edge__ops_aten_convolution_default": 52,
    "executorch_exir_dialects_edge__ops_aten_div_Tensor": 28,
    "executorch_exir_dialects_edge__ops_aten_mean_dim": 10,
    "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 28,
    "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 2,
    "executorch_exir_dialects_edge__ops_aten_relu_default": 14,
    "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
    "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 56,
    "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 178,
    "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 109,
}
29+
# Edge-dialect op name -> expected value (presumably the number of
# occurrences in the MobileNetV3-small graph) after the transforms run.
# Passed as the second argument to CortexMTester.test_dialect in
# test_dialect_mv3. Entries set to 0 are ops expected to be gone entirely;
# the cortex_m_* entries are the ops expected in their place.
ops_after_transforms: dict[str, int] = {
    "executorch_exir_dialects_edge__ops_aten_add_Tensor": 28,  # Not lowered due to broadcasting
    "executorch_exir_dialects_edge__ops_aten_addmm_default": 0,
    "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 6,
    "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 2,
    "executorch_exir_dialects_edge__ops_aten_clamp_default": 56,
    "executorch_exir_dialects_edge__ops_aten_convolution_default": 52,
    "executorch_exir_dialects_edge__ops_aten_div_Tensor": 28,
    "executorch_exir_dialects_edge__ops_aten_mean_dim": 10,
    "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 28,
    "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 0,
    "executorch_exir_dialects_edge__ops_aten_relu_default": 14,
    "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
    "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 56,
    "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 0,
    "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 0,
    "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 162,
    "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 101,
}
48+
49+
# Randomly initialised MobileNetV3-small (weights=None, so no pretrained
# weights are requested) and a single random 1x3x224x224 input tensor.
model = models.mobilenet_v3_small(weights=None)
example_input = torch.randn(1, 3, 224, 224)


# Test-case name -> McuTestCase. The network built above is reused here
# rather than constructing a second, separately initialised copy.
test_cases = {
    "mobilenet_v3_small": McuTestCase(
        model=model,
        example_inputs=(example_input,),
    ),
}
59+
60+
61+
@pytest.mark.skip("Skip until add + linear fix are upstreamed.")
@pytest.mark.parametrize(
    "test_case", list(test_cases.values()), ids=list(test_cases.keys())
)
def test_dialect_mv3(test_case: McuTestCase) -> None:
    """Run the dialect check for MobileNetV3 against the expected op tables.

    ``test_case`` is supplied by parametrizing over the module-level
    ``test_cases`` dict; without that, pytest would look for a fixture named
    ``test_case`` and error as soon as the skip marker is removed.

    Args:
        test_case: Model and example inputs to hand to ``CortexMTester``.
    """
    tester = CortexMTester(test_case.model, test_case.example_inputs)
    tester.test_dialect(
        ops_before_transforms,
        ops_after_transforms,
        qtol=1,
    )
69+
70+
71+
@pytest.mark.skip("Skip until add + linear fix are upstreamed.")
@pytest.mark.parametrize(
    "test_case", list(test_cases.values()), ids=list(test_cases.keys())
)
def test_implementation_mv3(test_case: McuTestCase) -> None:
    """Run the implementation check for MobileNetV3.

    ``test_case`` is supplied by parametrizing over the module-level
    ``test_cases`` dict; without that, pytest would look for a fixture named
    ``test_case`` and error as soon as the skip marker is removed.

    Args:
        test_case: Model and example inputs to hand to ``CortexMTester``.
    """
    tester = CortexMTester(test_case.model, test_case.example_inputs)
    tester.test_implementation(qtol=1)
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
import pytest
8+
import torch
9+
from executorch.backends.cortex_m.test.tester import (
10+
CortexMTester,
11+
McuTestCase,
12+
ramp_tensor,
13+
)
14+
15+
16+
class CortexMLSTM(torch.nn.Module):
    """Thin wrapper around ``torch.nn.LSTM`` that returns only the output.

    The tuple returned by the wrapped LSTM is unpacked and its second element
    (the hidden/cell state) is discarded, so the module has a single tensor
    output.
    """

    # Edge-dialect op name -> expected value (presumably the op count before
    # the transforms run); read by test_dialect_lstm via test_case.model.
    ops_before_transforms: dict[str, int] = {
        "executorch_exir_dialects_edge__ops_aten_full_default": 2,
        "executorch_exir_dialects_edge__ops_aten_squeeze_copy_dims": 4,
        "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 2,
        "executorch_exir_dialects_edge__ops_aten_view_copy_default": 6,
        "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 3,
        "executorch_exir_dialects_edge__ops_aten_addmm_default": 3,
        "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor": 2,
        "executorch_exir_dialects_edge__ops_aten_add_Tensor": 4,
        "executorch_exir_dialects_edge__ops_aten_split_with_sizes_copy_default": 2,
        "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 6,
        "executorch_exir_dialects_edge__ops_aten_tanh_default": 4,
        "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 6,
        "executorch_exir_dialects_edge__ops_aten_cat_default": 1,
    }

    # Left empty for now; the tests using it are skipped as not implemented.
    ops_after_transforms: dict[str, int] = {}

    def __init__(self, input_size: int = 4, hidden_size: int = 3) -> None:
        """Create the wrapped LSTM.

        Args:
            input_size: Forwarded to ``torch.nn.LSTM(input_size=...)``.
            hidden_size: Forwarded to ``torch.nn.LSTM(hidden_size=...)``.
        """
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the LSTM on ``x`` and return only its output tensor."""
        y, _ = self.lstm(x)
        return y
42+
43+
44+
class CortexMQuantizableLSTM(torch.nn.Module):
    """Thin wrapper around ``torch.ao.nn.quantizable.LSTM`` returning only the output.

    The tuple returned by the wrapped LSTM is unpacked and its second element
    (the hidden/cell state) is discarded, so the module has a single tensor
    output.
    """

    # Edge-dialect op name -> expected value (presumably the op count before
    # the transforms run); read by test_dialect_lstm via test_case.model.
    ops_before_transforms: dict[str, int] = {
        "executorch_exir_dialects_edge__ops_aten_add_Tensor": 4,
        "executorch_exir_dialects_edge__ops_aten_addmm_default": 4,
        "executorch_exir_dialects_edge__ops_aten_cat_default": 1,
        "executorch_exir_dialects_edge__ops_aten_full_default": 1,
        "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 6,
        "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 4,
        "executorch_exir_dialects_edge__ops_aten_select_copy_int": 2,
        "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 6,
        "executorch_exir_dialects_edge__ops_aten_split_with_sizes_copy_default": 2,
        "executorch_exir_dialects_edge__ops_aten_squeeze_copy_dims": 1,
        "executorch_exir_dialects_edge__ops_aten_tanh_default": 4,
        "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 34,
        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 27,
    }

    # Left empty for now; the tests using it are skipped as not implemented.
    ops_after_transforms: dict[str, int] = {}

    def __init__(self, input_size: int = 4, hidden_size: int = 3) -> None:
        """Create the wrapped quantizable LSTM.

        Args:
            input_size: Forwarded to the LSTM constructor as ``input_size``.
            hidden_size: Forwarded to the LSTM constructor as ``hidden_size``.
        """
        super().__init__()
        self.lstm = torch.ao.nn.quantizable.LSTM(
            input_size=input_size, hidden_size=hidden_size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the LSTM on ``x`` and return only its output tensor."""
        y, _ = self.lstm(x)
        return y
73+
74+
75+
# Test-case name -> McuTestCase. Both cases are built with the same
# ramp_tensor(-1, 1, (2, 1, 4)) arguments; the trailing 4 matches the default
# input_size of the two LSTM wrappers.
test_cases = {
    "lstm_fp32": McuTestCase(
        model=CortexMLSTM(),
        example_inputs=(ramp_tensor(-1, 1, (2, 1, 4)),),
    ),
    "lstm_quantizable": McuTestCase(
        model=CortexMQuantizableLSTM(),
        example_inputs=(ramp_tensor(-1, 1, (2, 1, 4)),),
    ),
}
85+
86+
87+
@pytest.mark.skip("Not implemented yet.")
@pytest.mark.parametrize(
    "test_case", list(test_cases.values()), ids=list(test_cases.keys())
)
def test_dialect_lstm(test_case: McuTestCase) -> None:
    """Run the dialect check using the op tables stored on the model class.

    ``test_case`` is supplied by parametrizing over the module-level
    ``test_cases`` dict; without that, pytest would look for a fixture named
    ``test_case`` and error as soon as the skip marker is removed.

    Args:
        test_case: Model and example inputs to hand to ``CortexMTester``.
    """
    tester = CortexMTester(test_case.model, test_case.example_inputs)
    tester.test_dialect(
        test_case.model.ops_before_transforms, test_case.model.ops_after_transforms
    )
93+
94+
95+
@pytest.mark.skip("Not implemented yet.")
@pytest.mark.parametrize(
    "test_case", list(test_cases.values()), ids=list(test_cases.keys())
)
def test_implementation_lstm(test_case: McuTestCase) -> None:
    """Run the implementation check for each LSTM test case.

    ``test_case`` is supplied by parametrizing over the module-level
    ``test_cases`` dict; without that, pytest would look for a fixture named
    ``test_case`` and error as soon as the skip marker is removed.

    Args:
        test_case: Model and example inputs to hand to ``CortexMTester``.
    """
    tester = CortexMTester(test_case.model, test_case.example_inputs)
    tester.test_implementation()

0 commit comments

Comments
 (0)