Skip to content

Commit 999fef6

Browse files
committed
Add test
1 parent 5878332 commit 999fef6

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

tests/tests_pytorch/helpers/advanced_models.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,54 @@ def configure_optimizers(self):
219219

220220
def train_dataloader(self):
221221
return DataLoader(MNIST(root=_PATH_DATASETS, train=True, download=True), batch_size=128, num_workers=1)
222+
223+
224+
class TBPTTModule(LightningModule):
225+
def __init__(self):
226+
super().__init__()
227+
228+
self.batch_size = 10
229+
self.in_features = 10
230+
self.out_features = 5
231+
self.hidden_dim = 20
232+
233+
self.automatic_optimization = False
234+
self.truncated_bptt_steps = 10
235+
236+
self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
237+
self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)
238+
239+
def forward(self, x, hs):
240+
seq, hs = self.rnn(x, hs)
241+
return self.linear_out(seq), hs
242+
243+
def training_step(self, batch, batch_idx):
244+
x, y = batch
245+
split_x, split_y = [
246+
x.tensor_split(self.truncated_bptt_steps, dim=1),
247+
y.tensor_split(self.truncated_bptt_steps, dim=1),
248+
]
249+
250+
hiddens = None
251+
optimizer = self.optimizers()
252+
losses = []
253+
254+
for x, y in zip(split_x, split_y):
255+
y_pred, hiddens = self(x, hiddens)
256+
loss = F.mse_loss(y_pred, y)
257+
258+
optimizer.zero_grad()
259+
self.manual_backward(loss)
260+
optimizer.step()
261+
262+
# "Truncate"
263+
hiddens = [h.detach() for h in hiddens]
264+
losses.append(loss.detach())
265+
266+
return
267+
268+
def configure_optimizers(self):
269+
return torch.optim.Adam(self.parameters(), lr=0.001)
270+
271+
def train_dataloader(self):
272+
return DataLoader(AverageDataset(), batch_size=self.batch_size)

tests/tests_pytorch/helpers/test_models.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from lightning.pytorch import Trainer
1818
from lightning.pytorch.demos.boring_classes import BoringModel
1919

20-
from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN
20+
from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN, TBPTTModule
2121
from tests_pytorch.helpers.datamodules import ClassifDataModule, RegressDataModule
2222
from tests_pytorch.helpers.runif import RunIf
2323
from tests_pytorch.helpers.simple_models import ClassificationModel, RegressionModel
@@ -49,3 +49,10 @@ def test_models(tmp_path, data_class, model_class):
4949
model.to_torchscript()
5050
if data_class:
5151
model.to_onnx(os.path.join(tmp_path, "my-model.onnx"), input_sample=dm.sample)
52+
53+
54+
def test_tbptt(tmp_path):
55+
model = TBPTTModule()
56+
57+
trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
58+
trainer.fit(model)

0 commit comments

Comments
 (0)