@@ -219,3 +219,54 @@ def configure_optimizers(self):
 
     def train_dataloader(self):
         return DataLoader(MNIST(root=_PATH_DATASETS, train=True, download=True), batch_size=128, num_workers=1)
+
+
+class TBPTTModule(LightningModule):
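+    """LSTM model trained with truncated backpropagation through time (TBPTT) via manual optimization."""
+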
+    def __init__(self):
+        super().__init__()
+
+        self.batch_size = 10
+        self.in_features = 10
+        self.out_features = 5
+        self.hidden_dim = 20
+
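+        # Switch to manual optimization so the optimizer can be stepped once per truncated chunk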
+        self.automatic_optimization = False
+        self.truncated_bptt_steps = 10
+
+        self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
+        self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)
+
+    def forward(self, x, hs):
+        seq, hs = self.rnn(x, hs)
+        return self.linear_out(seq), hs
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
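+        # Split the full sequences into `truncated_bptt_steps` chunks along the time dimension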
+        split_x, split_y = [
+            x.tensor_split(self.truncated_bptt_steps, dim=1),
+            y.tensor_split(self.truncated_bptt_steps, dim=1),
+        ]
+
+        hiddens = None
+        optimizer = self.optimizers()
+        losses = []
+
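+        # Run one forward/backward/optimizer step per chunk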
+        for x, y in zip(split_x, split_y):
+            y_pred, hiddens = self(x, hiddens)
+            loss = F.mse_loss(y_pred, y)
+
+            optimizer.zero_grad()
+            self.manual_backward(loss)
+            optimizer.step()
+
+            # "Truncate": detach the hidden state so gradients do not propagate across chunk boundaries
+            hiddens = [h.detach() for h in hiddens]
+            losses.append(loss.detach())
+
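+        # With manual optimization, nothing needs to be returned from training_step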
+        return
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=0.001)
+
+    def train_dataloader(self):
+        return DataLoader(AverageDataset(), batch_size=self.batch_size)