Commit ef2a826

Alan Chu authored and committed

address comments

1 parent bbf7c12 commit ef2a826

File tree: 1 file changed, +5 −1 lines changed


docs/source-pytorch/common/tbptt.rst

Lines changed: 5 additions & 1 deletion
@@ -10,6 +10,8 @@ hidden states should be kept in-between each time-dimension split.
 
 
 .. code-block:: python
+
+    import torch
     import torch.optim as optim
     import pytorch_lightning as pl
     from pytorch_lightning import LightningModule
@@ -31,7 +33,9 @@ hidden states should be kept in-between each time-dimension split.
             # 3. Split the batch in chunks along the time dimension
             split_batches = split_batch(batch, self.truncated_bptt_steps)
 
-            hiddens = ...  # Choose the initial hidden state
+            batch_size = 10
+            hidden_dim = 20
+            hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device)
             for split_batch in range(split_batches):
                 # 4. Perform the optimization in a loop
                 loss, hiddens = self.my_rnn(split_batch, hiddens)
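
For context on the added lines: `torch.zeros(1, batch_size, hidden_dim)` follows the `(num_layers, batch, hidden_size)` hidden-state shape convention of PyTorch's recurrent modules, with `1` because the example assumes a single-layer RNN. A minimal sketch of that convention, using an `nn.GRU` as a stand-in for the docs' `my_rnn` (the GRU itself, `input_dim`, and `seq_len` are illustrative assumptions, not part of the commit):

.. code-block:: python

    import torch
    import torch.nn as nn

    batch_size, hidden_dim = 10, 20  # values from the diff
    input_dim, seq_len = 5, 7  # assumed, only to make the sketch runnable

    # Illustrative stand-in for the docs' `my_rnn`: a single-layer GRU, whose
    # hidden state has shape (num_layers, batch_size, hidden_size), hence the
    # torch.zeros(1, batch_size, hidden_dim) initialization added in this commit.
    rnn = nn.GRU(input_size=input_dim, hidden_size=hidden_dim, num_layers=1)
    hiddens = torch.zeros(1, batch_size, hidden_dim)  # zero initial hidden state

    # One time-dimension split of the batch: (seq_len, batch_size, input_dim).
    split = torch.randn(seq_len, batch_size, input_dim)

    # The returned hidden state is what the TBPTT loop carries into the next split.
    output, hiddens = rnn(split, hiddens)
    print(output.shape)   # torch.Size([7, 10, 20])
    print(hiddens.shape)  # torch.Size([1, 10, 20])

In a full TBPTT step the hidden state is typically also detached between splits so gradients do not flow across the truncation boundary; the docs example elides that detail with `...`.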
