 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ ... @@
     operator_test,
     OperatorTest,
 )
+from torch.nn.quantizable.modules.rnn import LSTM as QuantizableLSTM
+
+
+def _get_lstm_cls(use_quantizable_lstm: bool):
+    return QuantizableLSTM if use_quantizable_lstm else torch.nn.LSTM
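Reviewer context: `QuantizableLSTM` is designed as a drop-in replacement for `torch.nn.LSTM`, with the same constructor arguments and forward contract, which is what lets `_get_lstm_cls` swap the class without touching the rest of `Model`. A minimal sketch of that equivalence (shapes borrowed from the tests below; not part of this diff):

```python
import torch
from torch.nn.quantizable.modules.rnn import LSTM as QuantizableLSTM

# Both classes accept the same constructor arguments and return
# (output, (h_n, c_n)) from forward.
float_lstm = torch.nn.LSTM(input_size=64, hidden_size=32, num_layers=2, batch_first=True)
quant_lstm = QuantizableLSTM(input_size=64, hidden_size=32, num_layers=2, batch_first=True)

x = torch.randn(1, 10, 64)  # (batch, seq_len, input_size)
out_f, _ = float_lstm(x)
out_q, _ = quant_lstm(x)
assert out_f.shape == out_q.shape == (1, 10, 32)
```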
 
 
 class Model(torch.nn.Module):
@@ -27,9 +33,11 @@ def __init__(
         batch_first=True,
         dropout=0.0,
         bidirectional=False,
+        use_quantizable_lstm: bool = False,
     ):
         super().__init__()
-        self.lstm = torch.nn.LSTM(
+        lstm_cls = _get_lstm_cls(use_quantizable_lstm)
+        self.lstm = lstm_cls(
             input_size=input_size,
             hidden_size=hidden_size,
             num_layers=num_layers,
@@ -47,116 +55,144 @@ def forward(self, x):
 class LSTM(OperatorTest):
     @dtype_test
     def test_lstm_dtype(self, flow: TestFlow, dtype) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(num_layers=2).to(dtype),
+            Model(num_layers=2, use_quantizable_lstm=use_quantizable_lstm).to(dtype),
             ((torch.rand(1, 10, 64) * 10).to(dtype),),  # (batch=1, seq_len, input_size)
             flow,
         )
 
     @dtype_test
     def test_lstm_no_bias_dtype(self, flow: TestFlow, dtype) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(num_layers=2, bias=False).to(dtype),
+            Model(
+                num_layers=2, bias=False, use_quantizable_lstm=use_quantizable_lstm
+            ).to(dtype),
             ((torch.rand(1, 10, 64) * 10).to(dtype),),
             flow,
         )
 
     def test_lstm_feature_sizes(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(input_size=32, hidden_size=16),
+            Model(
+                input_size=32,
+                hidden_size=16,
+                use_quantizable_lstm=use_quantizable_lstm,
+            ),
             (torch.randn(1, 8, 32),),  # (batch=1, seq_len, input_size)
             flow,
         )
         self._test_op(
-            Model(input_size=128, hidden_size=64),
+            Model(
+                input_size=128,
+                hidden_size=64,
+                use_quantizable_lstm=use_quantizable_lstm,
+            ),
             (torch.randn(1, 12, 128),),
             flow,
         )
         self._test_op(
-            Model(input_size=256, hidden_size=128),
+            Model(
+                input_size=256,
+                hidden_size=128,
+                use_quantizable_lstm=use_quantizable_lstm,
+            ),
             (torch.randn(1, 6, 256),),
             flow,
         )
         self._test_op(
-            Model(input_size=16, hidden_size=32),
+            Model(
+                input_size=16,
+                hidden_size=32,
+                use_quantizable_lstm=use_quantizable_lstm,
+            ),
             (torch.randn(1, 5, 16),),
             flow,
         )
 
     def test_lstm_batch_sizes(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(8, 10, 64),),
             flow,
         )
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(32, 10, 64),),
             flow,
         )
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(100, 10, 64),),
             flow,
         )
 
     def test_lstm_seq_lengths(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 5, 64),),
             flow,
         )
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 20, 64),),
             flow,
         )
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 50, 64),),
             flow,
         )
 
     def test_lstm_batch_first_false(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(batch_first=False),
+            Model(batch_first=False, use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(10, 1, 64),),  # (seq_len, batch=1, input_size)
             flow,
         )
 
     def test_lstm_num_layers(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(num_layers=2),
+            Model(num_layers=2, use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 10, 64),),
             flow,
         )
         self._test_op(
-            Model(num_layers=3),
+            Model(num_layers=3, use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 10, 64),),
             flow,
         )
 
     def test_lstm_bidirectional(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(bidirectional=True),
+            Model(bidirectional=True, use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 10, 64),),
             flow,
         )
 
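Shape context for the bidirectional case: the output feature dimension doubles while the per-direction state shapes are unchanged. A quick sketch with the sizes this test exercises (not part of the diff; the quantizable variant follows the same contract):

```python
import torch

# Same contract for torch.nn.LSTM and the quantizable variant.
bi = torch.nn.LSTM(input_size=64, hidden_size=32, batch_first=True, bidirectional=True)
out, (h_n, c_n) = bi(torch.randn(1, 10, 64))

assert out.shape == (1, 10, 64)  # 2 * hidden_size: forward + backward halves
assert h_n.shape == (2, 1, 32)   # (num_layers * num_directions, batch, hidden_size)
```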
     def test_lstm_with_dropout(self, flow: TestFlow) -> None:
         # Note: Dropout is only effective with num_layers > 1
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(num_layers=2, dropout=0.2),
+            Model(num_layers=2, dropout=0.2, use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 10, 64),),
             flow,
         )
 
     def test_lstm_with_initial_states(self, flow: TestFlow) -> None:
         # Create a model that accepts initial states
         class ModelWithStates(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, use_quantizable_lstm: bool = False):
                 super().__init__()
-                self.lstm = torch.nn.LSTM(
+                lstm_cls = _get_lstm_cls(use_quantizable_lstm)
+                self.lstm = lstm_cls(
                     input_size=64,
                     hidden_size=32,
                     num_layers=2,
@@ -169,9 +205,10 @@ def forward(self, x, h0, c0):
         batch_size = 1
         num_layers = 2
         hidden_size = 32
+        use_quantizable_lstm = flow.quantize
 
         self._test_op(
-            ModelWithStates(),
+            ModelWithStates(use_quantizable_lstm=use_quantizable_lstm),
             (
                 torch.randn(batch_size, 10, 64),  # input
                 torch.randn(num_layers, batch_size, hidden_size),  # h0
@@ -183,9 +220,10 @@ def forward(self, x, h0, c0):
     def test_lstm_return_hidden_states(self, flow: TestFlow) -> None:
         # Create a model that returns both output and hidden states
         class ModelWithHiddenStates(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, use_quantizable_lstm: bool = False):
                 super().__init__()
-                self.lstm = torch.nn.LSTM(
+                lstm_cls = _get_lstm_cls(use_quantizable_lstm)
+                self.lstm = lstm_cls(
                     input_size=64,
                     hidden_size=32,
                     num_layers=2,
@@ -200,9 +238,10 @@ def forward(self, x):
         batch_size = 1
         seq_len = 10
         input_size = 64
+        use_quantizable_lstm = flow.quantize
 
         self._test_op(
-            ModelWithHiddenStates(),
+            ModelWithHiddenStates(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(batch_size, seq_len, input_size),),
             flow,
         )
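For reference, the state-passing contract that both stateful tests rely on, sketched standalone with the shapes used above (not part of this diff):

```python
import torch
from torch.nn.quantizable.modules.rnn import LSTM as QuantizableLSTM

lstm = QuantizableLSTM(input_size=64, hidden_size=32, num_layers=2, batch_first=True)

x = torch.randn(1, 10, 64)   # (batch, seq_len, input_size)
h0 = torch.randn(2, 1, 32)   # (num_layers, batch, hidden_size)
c0 = torch.randn(2, 1, 32)

# Identical calling convention to torch.nn.LSTM: optional (h0, c0) in,
# (output, (h_n, c_n)) out.
output, (h_n, c_n) = lstm(x, (h0, c0))
assert output.shape == (1, 10, 32)
assert h_n.shape == c_n.shape == (2, 1, 32)
```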