1111import torch
1212
1313from executorch .backends .arm .test import common , conftest
14-
1514from executorch .backends .arm .test .tester .arm_tester import ArmTester
1615
17- from executorch .exir import EdgeCompileConfig
18-
1916from torch .nn .quantizable .modules import rnn
2017
2118
2219class TestLSTM (unittest .TestCase ):
23- """Tests LSTM."""
20+ """Tests quantizable LSTM module ."""
2421
22+ """
23+ Currently only the quantizable LSTM module has been verified with the arm backend.
24+ There may be plans to update this to use torch.nn.LSTM.
25+ TODO: MLETORCH-622
26+ """
2527 lstm = rnn .LSTM (10 , 20 , 2 )
2628 lstm = lstm .eval ()
2729
@@ -31,10 +33,6 @@ class TestLSTM(unittest.TestCase):
3133
3234 model_inputs = (input_tensor , (h0 , c0 ))
3335
34- _edge_compile_config = EdgeCompileConfig (
35- _skip_dim_order = True , # TODO(T182928844): Delegate dim order op to backend.
36- )
37-
3836 def test_lstm_tosa_MI (self ):
3937 (
4038 ArmTester (
@@ -43,7 +41,8 @@ def test_lstm_tosa_MI(self):
4341 compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" ),
4442 )
4543 .export ()
46- .to_edge_transform_and_lower (edge_compile_config = self ._edge_compile_config )
44+ .to_edge_transform_and_lower ()
45+ .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
4746 .to_executorch ()
4847 .run_method_and_compare_outputs (inputs = self .model_inputs )
4948 )
@@ -57,7 +56,8 @@ def test_lstm_tosa_BI(self):
5756 )
5857 .quantize ()
5958 .export ()
60- .to_edge_transform_and_lower (edge_compile_config = self ._edge_compile_config )
59+ .to_edge_transform_and_lower ()
60+ .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
6161 .to_executorch ()
6262 .run_method_and_compare_outputs (atol = 3e-1 , qtol = 1 , inputs = self .model_inputs )
6363 )
@@ -71,7 +71,8 @@ def test_lstm_u55_BI(self):
7171 )
7272 .quantize ()
7373 .export ()
74- .to_edge_transform_and_lower (edge_compile_config = self ._edge_compile_config )
74+ .to_edge_transform_and_lower ()
75+ .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
7576 .to_executorch ()
7677 .serialize ()
7778 )
@@ -89,7 +90,8 @@ def test_lstm_u85_BI(self):
8990 )
9091 .quantize ()
9192 .export ()
92- .to_edge_transform_and_lower (edge_compile_config = self ._edge_compile_config )
93+ .to_edge_transform_and_lower ()
94+ .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
9395 .to_executorch ()
9496 .serialize ()
9597 )
0 commit comments