Skip to content

Commit 5262534

Browse files
Alan ChuAlan Chu
authored andcommitted
make example easily copy and runnable
1 parent 6f7f206 commit 5262534

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

docs/source-pytorch/common/tbptt.rst

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ hidden states should be kept in-between each time-dimension split.
1010

1111

1212
.. code-block:: python
13+
import torch.optim as optim
14+
import pytorch_lightning as pl
15+
from pytorch_lightning import LightningModule
1316
1417
class LitModel(LightningModule):
1518
@@ -20,25 +23,33 @@ hidden states should be kept in-between each time-dimension split.
2023
self.automatic_optimization = False
2124
2225
self.truncated_bptt_steps = 10
23-
self.my_rnn = ...
26+
self.my_rnn = ParityModuleRNN() # Define RNN model using ParityModuleRNN
2427
2528
# 2. Remove the `hiddens` argument
2629
def training_step(self, batch, batch_idx):
2730
2831
# 3. Split the batch in chunks along the time dimension
2932
split_batches = split_batch(batch, self.truncated_bptt_steps)
3033
31-
hiddens = ... # 3. Choose the initial hidden state
34+
hiddens = ... # Choose the initial hidden state
3235
for split_batch in range(split_batches):
3336
# 4. Perform the optimization in a loop
3437
loss, hiddens = self.my_rnn(split_batch, hiddens)
3538
self.backward(loss)
36-
optimizer.step()
37-
optimizer.zero_grad()
39+
self.optimizer.step()
40+
self.optimizer.zero_grad()
3841
3942
# 5. "Truncate"
4043
hiddens = hiddens.detach()
4144
4245
# 6. Remove the return of `hiddens`
4346
# Returning loss in manual optimization is not needed
4447
return None
48+
49+
def configure_optimizers(self):
50+
return optim.Adam(self.my_rnn.parameters(), lr=0.001)
51+
52+
if __name__ == "__main__":
53+
model = LitModel()
54+
trainer = pl.Trainer(max_epochs=5)
55+
trainer.fit(model, train_dataloader) # Define your own dataloader

0 commit comments

Comments
 (0)