Commit 58f0ea6

Make example self-contained
1 parent 4e18df1 commit 58f0ea6

File tree

1 file changed: +53 −14 lines changed

docs/source-pytorch/common/tbptt.rst

Lines changed: 53 additions & 14 deletions
@@ -12,52 +12,91 @@ hidden states should be kept in-between each time-dimension split.
 .. code-block:: python
 
     import torch
+    import torch.nn as nn
+    import torch.nn.functional as F
     import torch.optim as optim
+    from torch.utils.data import Dataset, DataLoader
     import pytorch_lightning as pl
     from pytorch_lightning import LightningModule
 
+
+    class AverageDataset(Dataset):
+        def __init__(self, dataset_len=300, sequence_len=100):
+            self.dataset_len = dataset_len
+            self.sequence_len = sequence_len
+            self.input_seq = torch.randn(dataset_len, sequence_len, 10)
+            top, bottom = self.input_seq.chunk(2, -1)
+            self.output_seq = top + bottom.roll(shifts=1, dims=-1)
+
+        def __len__(self):
+            return self.dataset_len
+
+        def __getitem__(self, item):
+            return self.input_seq[item], self.output_seq[item]
+
+
     class LitModel(LightningModule):
 
         def __init__(self):
             super().__init__()
 
+            self.batch_size = 10
+            self.in_features = 10
+            self.out_features = 5
+            self.hidden_dim = 20
+
             # 1. Switch to manual optimization
             self.automatic_optimization = False
-
             self.truncated_bptt_steps = 10
-            self.my_rnn = ParityModuleRNN()  # Define RNN model using ParityModuleRNN
+
+            self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
+            self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)
+
+        def forward(self, x, hs):
+            seq, hs = self.rnn(x, hs)
+            return self.linear_out(seq), hs
 
         # 2. Remove the `hiddens` argument
        def training_step(self, batch, batch_idx):
-
             # 3. Split the batch in chunks along the time dimension
-            split_batches = split_batch(batch, self.truncated_bptt_steps)
+            x, y = batch
+            split_x, split_y = [
+                x.tensor_split(self.truncated_bptt_steps, dim=1),
+                y.tensor_split(self.truncated_bptt_steps, dim=1)
+            ]
 
-            batch_size = 10
-            hidden_dim = 20
-            hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device)
-            # get optimizer
+            hiddens = None
             optimizer = self.optimizers()
+            losses = []
 
-            for split_batch in range(split_batches):
-                # 4. Perform the optimization in a loop
-                loss, hiddens = self.my_rnn(split_batch, hiddens)
+            # 4. Perform the optimization in a loop
+            for x, y in zip(split_x, split_y):
+                y_pred, hiddens = self(x, hiddens)
+                loss = F.mse_loss(y_pred, y)
 
                 optimizer.zero_grad()
                 self.manual_backward(loss)
                 optimizer.step()
 
                 # 5. "Truncate"
-                hiddens = hiddens.detach()
+                hiddens = [h.detach() for h in hiddens]
+                losses.append(loss.detach())
+
+            avg_loss = sum(losses) / len(losses)
+            self.log("train_loss", avg_loss, prog_bar=True)
 
             # 6. Remove the return of `hiddens`
             # Returning loss in manual optimization is not needed
             return None
 
         def configure_optimizers(self):
-            return optim.Adam(self.my_rnn.parameters(), lr=0.001)
+            return optim.Adam(self.parameters(), lr=0.001)
+
+        def train_dataloader(self):
+            return DataLoader(AverageDataset(), batch_size=self.batch_size)
+
 
     if __name__ == "__main__":
         model = LitModel()
         trainer = pl.Trainer(max_epochs=5)
-        trainer.fit(model, train_dataloader)  # Define your own dataloader
+        trainer.fit(model)
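
For readers trying the new example, the chunking and state-detaching pattern in `training_step` can be exercised in isolation. The snippet below is a minimal sketch, not part of the committed docs; the shapes are an assumption that mirrors the `AverageDataset` defaults (10 sequences of 100 steps with 10 features). It only illustrates how `tensor_split` chunks the time dimension and how detaching the LSTM state keeps gradients from crossing chunk boundaries.

    import torch
    import torch.nn as nn

    # Assumed shapes mirroring the example: batch of 10 sequences,
    # 100 time steps, 10 input features.
    x = torch.randn(10, 100, 10)

    # Step 3 of the example: split into `truncated_bptt_steps` chunks along time (dim=1).
    chunks = x.tensor_split(10, dim=1)
    print(len(chunks), chunks[0].shape)  # 10 chunks, each of shape (10, 10, 10)

    rnn = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
    hiddens = None
    for chunk in chunks:
        out, hiddens = rnn(chunk, hiddens)
        # Step 5 ("truncate"): detach so backprop stops at the chunk boundary.
        hiddens = tuple(h.detach() for h in hiddens)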
