Skip to content

Commit 2516049

Browse files
authored
Arm backend: Add LSTM unit test to arm backend (#7787)
Add LSTM unit test to arm backend Adding check that number of delegates=1
1 parent a31d97b commit 2516049

File tree

1 file changed

+101
-0
lines changed

1 file changed

+101
-0
lines changed
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2025 Arm Limited and/or its affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
9+
import unittest
10+
11+
import torch
12+
13+
from executorch.backends.arm.test import common, conftest
14+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
15+
16+
from torch.nn.quantizable.modules import rnn
17+
18+
19+
class TestLSTM(unittest.TestCase):
20+
"""Tests quantizable LSTM module."""
21+
22+
"""
23+
Currently only the quantizable LSTM module has been verified with the arm backend.
24+
There may be plans to update this to use torch.nn.LSTM.
25+
TODO: MLETORCH-622
26+
"""
27+
lstm = rnn.LSTM(10, 20, 2)
28+
lstm = lstm.eval()
29+
30+
input_tensor = torch.randn(5, 3, 10)
31+
h0 = torch.randn(2, 3, 20)
32+
c0 = torch.randn(2, 3, 20)
33+
34+
model_inputs = (input_tensor, (h0, c0))
35+
36+
def test_lstm_tosa_MI(self):
37+
(
38+
ArmTester(
39+
self.lstm,
40+
example_inputs=self.model_inputs,
41+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
42+
)
43+
.export()
44+
.to_edge_transform_and_lower()
45+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
46+
.to_executorch()
47+
.run_method_and_compare_outputs(inputs=self.model_inputs)
48+
)
49+
50+
def test_lstm_tosa_BI(self):
51+
(
52+
ArmTester(
53+
self.lstm,
54+
example_inputs=self.model_inputs,
55+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
56+
)
57+
.quantize()
58+
.export()
59+
.to_edge_transform_and_lower()
60+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
61+
.to_executorch()
62+
.run_method_and_compare_outputs(atol=3e-1, qtol=1, inputs=self.model_inputs)
63+
)
64+
65+
def test_lstm_u55_BI(self):
66+
tester = (
67+
ArmTester(
68+
self.lstm,
69+
example_inputs=self.model_inputs,
70+
compile_spec=common.get_u55_compile_spec(),
71+
)
72+
.quantize()
73+
.export()
74+
.to_edge_transform_and_lower()
75+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
76+
.to_executorch()
77+
.serialize()
78+
)
79+
if conftest.is_option_enabled("corstone_fvp"):
80+
tester.run_method_and_compare_outputs(
81+
atol=3e-1, qtol=1, inputs=self.model_inputs
82+
)
83+
84+
def test_lstm_u85_BI(self):
85+
tester = (
86+
ArmTester(
87+
self.lstm,
88+
example_inputs=self.model_inputs,
89+
compile_spec=common.get_u85_compile_spec(),
90+
)
91+
.quantize()
92+
.export()
93+
.to_edge_transform_and_lower()
94+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
95+
.to_executorch()
96+
.serialize()
97+
)
98+
if conftest.is_option_enabled("corstone_fvp"):
99+
tester.run_method_and_compare_outputs(
100+
atol=3e-1, qtol=1, inputs=self.model_inputs
101+
)

0 commit comments

Comments
 (0)