@@ -10,6 +10,9 @@ hidden states should be kept in-between each time-dimension split.


.. code-block:: python
+    import torch.optim as optim
+    import pytorch_lightning as pl
+    from pytorch_lightning import LightningModule

    class LitModel(LightningModule):

@@ -20,25 +23,34 @@ hidden states should be kept in-between each time-dimension split.
            self.automatic_optimization = False

            self.truncated_bptt_steps = 10
-            self.my_rnn = ...
+            self.my_rnn = ParityModuleRNN()  # define an RNN model that returns (loss, hiddens)

        # 2. Remove the `hiddens` argument
        def training_step(self, batch, batch_idx):

            # 3. Split the batch in chunks along the time dimension
            split_batches = split_batch(batch, self.truncated_bptt_steps)

-            hiddens = ...  # 3. Choose the initial hidden state
+            hiddens = ...  # Choose the initial hidden state
+            opt = self.optimizers()  # in manual optimization, fetch the optimizer explicitly
            for split_batch in split_batches:
                # 4. Perform the optimization in a loop
                loss, hiddens = self.my_rnn(split_batch, hiddens)
                self.manual_backward(loss)  # preferred over self.backward in manual optimization
-                optimizer.step()
-                optimizer.zero_grad()
+                opt.step()
+                opt.zero_grad()

                # 5. "Truncate"
                hiddens = hiddens.detach()

            # 6. Remove the return of `hiddens`
            # Returning loss in manual optimization is not needed
            return None
+
+        def configure_optimizers(self):
+            return optim.Adam(self.my_rnn.parameters(), lr=0.001)
+
+    if __name__ == "__main__":
+        model = LitModel()
+        trainer = pl.Trainer(max_epochs=5)
+        trainer.fit(model, train_dataloader)  # define your own dataloader
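
The snippet calls a `split_batch` helper that the diff never defines. A minimal sketch, assuming the batch is an `(inputs, targets)` pair of tensors laid out `batch_first` with time on `dim=1` (the helper body and layout are assumptions, not part of the original), could look like:

.. code-block:: python

    import torch

    # Hypothetical helper: chunk both tensors along the time dimension so that
    # each chunk covers at most `tbptt_steps` timesteps.
    def split_batch(batch, tbptt_steps):
        x, y = batch
        return list(zip(x.split(tbptt_steps, dim=1), y.split(tbptt_steps, dim=1)))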
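
`ParityModuleRNN` comes from Lightning's test helpers rather than the public API; any module that takes a chunk plus the previous hidden state and returns `(loss, hiddens)` fits the loop above. A hypothetical stand-in under the same `(inputs, targets)` chunk layout:

.. code-block:: python

    import torch.nn as nn

    # Hypothetical stand-in for ParityModuleRNN with the (loss, hiddens) contract.
    class SimpleRNNModule(nn.Module):
        def __init__(self, input_size=8, hidden_size=16):
            super().__init__()
            self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
            self.head = nn.Linear(hidden_size, input_size)

        def forward(self, chunk, hiddens):
            x, y = chunk
            out, hiddens = self.rnn(x, hiddens)  # hiddens=None starts from zeros
            loss = nn.functional.mse_loss(self.head(out), y)
            return loss, hiddens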