@@ -12,52 +12,91 @@ hidden states should be kept in-between each time-dimension split.
 .. code-block:: python
 
     import torch
+    import torch.nn as nn
+    import torch.nn.functional as F
     import torch.optim as optim
+    from torch.utils.data import Dataset, DataLoader
     import pytorch_lightning as pl
     from pytorch_lightning import LightningModule
 
+
+    class AverageDataset(Dataset):
+        def __init__(self, dataset_len=300, sequence_len=100):
+            self.dataset_len = dataset_len
+            self.sequence_len = sequence_len
+            self.input_seq = torch.randn(dataset_len, sequence_len, 10)
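+            # Synthetic target: the first half of the features plus the second
+            # half rolled by one position along the feature dimension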
+            top, bottom = self.input_seq.chunk(2, -1)
+            self.output_seq = top + bottom.roll(shifts=1, dims=-1)
+
+        def __len__(self):
+            return self.dataset_len
+
+        def __getitem__(self, item):
+            return self.input_seq[item], self.output_seq[item]
+
+
     class LitModel(LightningModule):
 
         def __init__(self):
             super().__init__()
 
+            self.batch_size = 10
+            self.in_features = 10
+            self.out_features = 5
+            self.hidden_dim = 20
+
             # 1. Switch to manual optimization
             self.automatic_optimization = False
-
             self.truncated_bptt_steps = 10
-            self.my_rnn = ParityModuleRNN()  # Define RNN model using ParityModuleRNN
+
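+            # Single-layer LSTM (batch_first: inputs are (batch, time, features)) with a linear readout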
+            self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
+            self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)
+
+        def forward(self, x, hs):
+            seq, hs = self.rnn(x, hs)
+            return self.linear_out(seq), hs
 
         # 2. Remove the `hiddens` argument
         def training_step(self, batch, batch_idx):
-
             # 3. Split the batch in chunks along the time dimension
-            split_batches = split_batch(batch, self.truncated_bptt_steps)
+            x, y = batch
+            split_x, split_y = [
+                x.tensor_split(self.truncated_bptt_steps, dim=1),
+                y.tensor_split(self.truncated_bptt_steps, dim=1)
+            ]
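+            # tensor_split produces `truncated_bptt_steps` chunks along the time dimension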
 
-            batch_size = 10
-            hidden_dim = 20
-            hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device)
-            # get optimizer
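+            # With hiddens=None, the LSTM starts from zero-initialized hidden and cell states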
+            hiddens = None
             optimizer = self.optimizers()
+            losses = []
 
-            for split_batch in range(split_batches):
-                # 4. Perform the optimization in a loop
-                loss, hiddens = self.my_rnn(split_batch, hiddens)
+            # 4. Perform the optimization in a loop
+            for x, y in zip(split_x, split_y):
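+                # Carry the hidden state over from the previous chunk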
+                y_pred, hiddens = self(x, hiddens)
+                loss = F.mse_loss(y_pred, y)
 
                 optimizer.zero_grad()
                 self.manual_backward(loss)
                 optimizer.step()
 
                 # 5. "Truncate"
-                hiddens = hiddens.detach()
+                hiddens = [h.detach() for h in hiddens]
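+                # Detaching stops gradients from flowing across chunk boundaries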
+                losses.append(loss.detach())
+
+            avg_loss = sum(losses) / len(losses)
+            self.log("train_loss", avg_loss, prog_bar=True)
 
             # 6. Remove the return of `hiddens`
             # Returning loss in manual optimization is not needed
             return None
 
         def configure_optimizers(self):
-            return optim.Adam(self.my_rnn.parameters(), lr=0.001)
+            return optim.Adam(self.parameters(), lr=0.001)
+
+        def train_dataloader(self):
+            return DataLoader(AverageDataset(), batch_size=self.batch_size)
+
 
     if __name__ == "__main__":
         model = LitModel()
         trainer = pl.Trainer(max_epochs=5)
-        trainer.fit(model, train_dataloader)  # Define your own dataloader
+        trainer.fit(model)