File tree Expand file tree Collapse file tree 1 file changed +5
-1
lines changed
docs/source-pytorch/common Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -10,6 +10,8 @@ hidden states should be kept in-between each time-dimension split.
1010
1111
1212.. code-block :: python
13+
14+ import torch
1315 import torch.optim as optim
1416 import pytorch_lightning as pl
1517 from pytorch_lightning import LightningModule
@@ -31,7 +33,9 @@ hidden states should be kept in-between each time-dimension split.
3133 # 3. Split the batch in chunks along the time dimension
3234 split_batches = split_batch(batch, self .truncated_bptt_steps)
3335
34- hiddens = ... # Choose the initial hidden state
36+ batch_size = 10
37+ hidden_dim = 20
38+ hiddens = torch.zeros(1 , batch_size, hidden_dim, device = self .device)
3539 for split_batch in range (split_batches):
3640 # 4. Perform the optimization in a loop
3741 loss, hiddens = self .my_rnn(split_batch, hiddens)
You can’t perform that action at this time.
0 commit comments