Skip to content

Commit 6cb41ed

Browse files
committed
Addressing PR comments for LSTM
- Adding check that number of delegates=1
1 parent 14e9804 commit 6cb41ed

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

backends/arm/test/models/test_lstm_arm.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,19 @@
1111
import torch
1212

1313
from executorch.backends.arm.test import common, conftest
14-
1514
from executorch.backends.arm.test.tester.arm_tester import ArmTester
1615

17-
from executorch.exir import EdgeCompileConfig
18-
1916
from torch.nn.quantizable.modules import rnn
2017

2118

2219
class TestLSTM(unittest.TestCase):
23-
"""Tests LSTM."""
20+
"""Tests quantizable LSTM module."""
2421

22+
"""
23+
Currently only the quantizable LSTM module has been verified with the arm backend.
24+
There may be plans to update this to use torch.nn.LSTM.
25+
TODO: MLETORCH-622
26+
"""
2527
lstm = rnn.LSTM(10, 20, 2)
2628
lstm = lstm.eval()
2729

@@ -31,10 +33,6 @@ class TestLSTM(unittest.TestCase):
3133

3234
model_inputs = (input_tensor, (h0, c0))
3335

34-
_edge_compile_config = EdgeCompileConfig(
35-
_skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend.
36-
)
37-
3836
def test_lstm_tosa_MI(self):
3937
(
4038
ArmTester(
@@ -43,7 +41,8 @@ def test_lstm_tosa_MI(self):
4341
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
4442
)
4543
.export()
46-
.to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config)
44+
.to_edge_transform_and_lower()
45+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
4746
.to_executorch()
4847
.run_method_and_compare_outputs(inputs=self.model_inputs)
4948
)
@@ -57,7 +56,8 @@ def test_lstm_tosa_BI(self):
5756
)
5857
.quantize()
5958
.export()
60-
.to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config)
59+
.to_edge_transform_and_lower()
60+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
6161
.to_executorch()
6262
.run_method_and_compare_outputs(atol=3e-1, qtol=1, inputs=self.model_inputs)
6363
)
@@ -71,7 +71,8 @@ def test_lstm_u55_BI(self):
7171
)
7272
.quantize()
7373
.export()
74-
.to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config)
74+
.to_edge_transform_and_lower()
75+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
7576
.to_executorch()
7677
.serialize()
7778
)
@@ -89,7 +90,8 @@ def test_lstm_u85_BI(self):
8990
)
9091
.quantize()
9192
.export()
92-
.to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config)
93+
.to_edge_transform_and_lower()
94+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
9395
.to_executorch()
9496
.serialize()
9597
)

0 commit comments

Comments
 (0)