Skip to content

Commit f556b96

Browse files
authored
[Backend Tester] Add LSTM tests (#13238)
Add tests for the LSTM module. This is done in the context of #12898.
1 parent bd793c3 commit f556b96

File tree

1 file changed

+208
-0
lines changed

1 file changed

+208
-0
lines changed
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
10+
import torch
11+
from executorch.backends.test.suite.flow import TestFlow
12+
13+
from executorch.backends.test.suite.operators import (
14+
dtype_test,
15+
operator_test,
16+
OperatorTest,
17+
)
18+
19+
20+
class Model(torch.nn.Module):
    """Wrap ``torch.nn.LSTM`` and return only its output sequence.

    Every constructor argument is forwarded unchanged to ``torch.nn.LSTM``;
    the defaults below are the ones the tests in this file rely on.
    """

    def __init__(
        self,
        input_size: int = 64,
        hidden_size: int = 32,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = True,
        dropout: float = 0.0,
        bidirectional: bool = False,
    ) -> None:
        super().__init__()
        self.lstm = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            batch_first=batch_first,
            dropout=dropout,
            bidirectional=bidirectional,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the LSTM on ``x`` and return the output, not the hidden states."""
        # nn.LSTM returns (output, (h_n, c_n)); only the first element is kept.
        return self.lstm(x)[0]
44+
45+
46+
@operator_test
class LSTM(OperatorTest):
    """Operator tests for ``torch.nn.LSTM``.

    Each test builds a module and a tuple of example inputs and hands them,
    together with the test flow, to ``self._test_op``.
    """

    @dtype_test
    def test_lstm_dtype(self, flow: TestFlow, dtype) -> None:
        """Two-layer LSTM with model and input cast to ``dtype``."""
        self._test_op(
            Model(num_layers=2).to(dtype),
            ((torch.rand(1, 10, 64) * 10).to(dtype),),  # (batch=1, seq_len, input_size)
            flow,
        )

    @dtype_test
    def test_lstm_no_bias_dtype(self, flow: TestFlow, dtype) -> None:
        """Two-layer LSTM without bias, with model and input cast to ``dtype``."""
        self._test_op(
            Model(num_layers=2, bias=False).to(dtype),
            ((torch.rand(1, 10, 64) * 10).to(dtype),),
            flow,
        )

    def test_lstm_feature_sizes(self, flow: TestFlow) -> None:
        """Vary input and hidden feature sizes."""
        # (input_size, hidden_size, seq_len)
        cases = (
            (32, 16, 8),
            (128, 64, 12),
            (256, 128, 6),
            (16, 32, 5),
        )
        for input_size, hidden_size, seq_len in cases:
            self._test_op(
                Model(input_size=input_size, hidden_size=hidden_size),
                (torch.randn(1, seq_len, input_size),),  # (batch=1, seq_len, input_size)
                flow,
            )

    def test_lstm_batch_sizes(self, flow: TestFlow) -> None:
        """Vary the batch dimension with the default model."""
        for batch_size in (8, 32, 100):
            self._test_op(
                Model(),
                (torch.randn(batch_size, 10, 64),),
                flow,
            )

    def test_lstm_seq_lengths(self, flow: TestFlow) -> None:
        """Vary the sequence length with the default model."""
        for seq_len in (5, 20, 50):
            self._test_op(
                Model(),
                (torch.randn(1, seq_len, 64),),
                flow,
            )

    def test_lstm_batch_first_false(self, flow: TestFlow) -> None:
        """Sequence-major input layout (``batch_first=False``)."""
        self._test_op(
            Model(batch_first=False),
            (torch.randn(10, 1, 64),),  # (seq_len, batch=1, input_size)
            flow,
        )

    def test_lstm_num_layers(self, flow: TestFlow) -> None:
        """Stacked LSTMs with two and three layers."""
        for num_layers in (2, 3):
            self._test_op(
                Model(num_layers=num_layers),
                (torch.randn(1, 10, 64),),
                flow,
            )

    def test_lstm_bidirectional(self, flow: TestFlow) -> None:
        """Bidirectional single-layer LSTM."""
        self._test_op(
            Model(bidirectional=True),
            (torch.randn(1, 10, 64),),
            flow,
        )

    def test_lstm_with_dropout(self, flow: TestFlow) -> None:
        """Two-layer LSTM configured with inter-layer dropout."""
        # Note: Dropout is only effective with num_layers > 1
        # eval() disables the stochastic dropout so the module's output is
        # deterministic; this is a no-op if the harness already switches the
        # model to eval mode (not visible from this file).
        self._test_op(
            Model(num_layers=2, dropout=0.2).eval(),
            (torch.randn(1, 10, 64),),
            flow,
        )

    def test_lstm_with_initial_states(self, flow: TestFlow) -> None:
        """LSTM whose initial hidden and cell states are passed as inputs."""
        batch_size = 1
        seq_len = 10
        input_size = 64
        num_layers = 2
        hidden_size = 32

        # Create a model that accepts initial states
        class ModelWithStates(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lstm = torch.nn.LSTM(
                    input_size=input_size,
                    hidden_size=hidden_size,
                    num_layers=num_layers,
                    batch_first=True,
                )

            def forward(self, x, h0, c0):
                return self.lstm(x, (h0, c0))[0]  # Return only the output

        self._test_op(
            ModelWithStates(),
            (
                torch.randn(batch_size, seq_len, input_size),  # input
                torch.randn(num_layers, batch_size, hidden_size),  # h0
                torch.randn(num_layers, batch_size, hidden_size),  # c0
            ),
            flow,
        )

    def test_lstm_return_hidden_states(self, flow: TestFlow) -> None:
        """LSTM that returns the output together with the final states."""
        batch_size = 1
        seq_len = 10
        input_size = 64
        hidden_size = 32
        num_layers = 2

        # Create a model that returns both output and hidden states
        class ModelWithHiddenStates(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lstm = torch.nn.LSTM(
                    input_size=input_size,
                    hidden_size=hidden_size,
                    num_layers=num_layers,
                    batch_first=True,
                )

            def forward(self, x):
                # Flatten the (output, (h_n, c_n)) result into three tensors.
                output, (h_n, c_n) = self.lstm(x)
                return output, h_n, c_n

        self._test_op(
            ModelWithHiddenStates(),
            (torch.randn(batch_size, seq_len, input_size),),
            flow,
        )

0 commit comments

Comments
 (0)