85 changes: 64 additions & 21 deletions docs/source-pytorch/common/tbptt.rst
@@ -12,48 +12,91 @@ hidden states should be kept in-between each time-dimension split.
.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torch.utils.data import Dataset, DataLoader

    import lightning as L


    class AverageDataset(Dataset):
        def __init__(self, dataset_len=300, sequence_len=100):
            self.dataset_len = dataset_len
            self.sequence_len = sequence_len
            self.input_seq = torch.randn(dataset_len, sequence_len, 10)
            top, bottom = self.input_seq.chunk(2, -1)
            # Target: sum of the two feature halves, with the second half
            # rolled by one position along the feature dimension
            self.output_seq = top + bottom.roll(shifts=1, dims=-1)

        def __len__(self):
            return self.dataset_len

        def __getitem__(self, item):
            return self.input_seq[item], self.output_seq[item]


    class LitModel(L.LightningModule):

        def __init__(self):
            super().__init__()

            self.batch_size = 10
            self.in_features = 10
            self.out_features = 5
            self.hidden_dim = 20

            # 1. Switch to manual optimization
            self.automatic_optimization = False

            self.truncated_bptt_steps = 10

            self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
            self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)

        def forward(self, x, hs):
            seq, hs = self.rnn(x, hs)
            return self.linear_out(seq), hs

        # 2. Remove the `hiddens` argument
        def training_step(self, batch, batch_idx):
            # 3. Split the batch in chunks along the time dimension
            x, y = batch
            split_x, split_y = [
                x.tensor_split(self.truncated_bptt_steps, dim=1),
                y.tensor_split(self.truncated_bptt_steps, dim=1),
            ]

            hiddens = None  # the LSTM initializes (h_0, c_0) to zeros when given None
            optimizer = self.optimizers()
            losses = []

            # 4. Perform the optimization in a loop
            for x, y in zip(split_x, split_y):
                y_pred, hiddens = self(x, hiddens)
                loss = F.mse_loss(y_pred, y)

                optimizer.zero_grad()
                self.manual_backward(loss)
                optimizer.step()

                # 5. "Truncate": detach so gradients do not flow across chunks
                hiddens = [h.detach() for h in hiddens]
                losses.append(loss.detach())

            avg_loss = sum(losses) / len(losses)
            self.log("train_loss", avg_loss, prog_bar=True)

            # 6. Remove the return of `hiddens`
            # Returning loss in manual optimization is not needed
            return None

        def configure_optimizers(self):
            return optim.Adam(self.parameters(), lr=0.001)

        def train_dataloader(self):
            return DataLoader(AverageDataset(), batch_size=self.batch_size)


    if __name__ == "__main__":
        model = LitModel()
        trainer = L.Trainer(max_epochs=5)
        trainer.fit(model)
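The key mechanic in this example — splitting the batch along the time dimension and detaching the recurrent state between chunks — is independent of Lightning. A minimal sketch in plain PyTorch, with illustrative shapes and chunk count (not part of this change):

.. code-block:: python

    import torch
    import torch.nn as nn

    rnn = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
    x = torch.randn(4, 100, 10)  # (batch, time, features)
    hiddens = None  # nn.LSTM initializes (h_0, c_0) to zeros when hx is None

    for chunk in x.tensor_split(10, dim=1):  # ten chunks of 10 timesteps each
        out, hiddens = rnn(chunk, hiddens)
        # ... compute a loss on `out` and step an optimizer here ...
        # "Truncate": detach so backprop stops at the chunk boundary
        hiddens = tuple(h.detach() for h in hiddens)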
51 changes: 51 additions & 0 deletions tests/tests_pytorch/helpers/advanced_models.py
@@ -219,3 +219,54 @@ def configure_optimizers(self):

    def train_dataloader(self):
        return DataLoader(MNIST(root=_PATH_DATASETS, train=True, download=True), batch_size=128, num_workers=1)


class TBPTTModule(LightningModule):
    def __init__(self):
        super().__init__()

        self.batch_size = 10
        self.in_features = 10
        self.out_features = 5
        self.hidden_dim = 20

        self.automatic_optimization = False
        self.truncated_bptt_steps = 10

        self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
        self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)

    def forward(self, x, hs):
        seq, hs = self.rnn(x, hs)
        return self.linear_out(seq), hs

    def training_step(self, batch, batch_idx):
        x, y = batch
        split_x, split_y = [
            x.tensor_split(self.truncated_bptt_steps, dim=1),
            y.tensor_split(self.truncated_bptt_steps, dim=1),
        ]

        hiddens = None
        optimizer = self.optimizers()
        losses = []

        for x, y in zip(split_x, split_y):
            y_pred, hiddens = self(x, hiddens)
            loss = F.mse_loss(y_pred, y)

            optimizer.zero_grad()
            self.manual_backward(loss)
            optimizer.step()

            # "Truncate"
            hiddens = [h.detach() for h in hiddens]
            losses.append(loss.detach())

        return

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def train_dataloader(self):
        return DataLoader(AverageDataset(), batch_size=self.batch_size)
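To exercise this helper outside the test suite, something like the following should work (a sketch; `fast_dev_run` pushes a single batch through `training_step` as a smoke test):

.. code-block:: python

    from lightning.pytorch import Trainer

    from tests_pytorch.helpers.advanced_models import TBPTTModule

    Trainer(fast_dev_run=True).fit(TBPTTModule())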
9 changes: 8 additions & 1 deletion tests/tests_pytorch/helpers/test_models.py
@@ -17,7 +17,7 @@
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN, TBPTTModule
from tests_pytorch.helpers.datamodules import ClassifDataModule, RegressDataModule
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.helpers.simple_models import ClassificationModel, RegressionModel
@@ -49,3 +49,10 @@ def test_models(tmp_path, data_class, model_class):
    model.to_torchscript()
    if data_class:
        model.to_onnx(os.path.join(tmp_path, "my-model.onnx"), input_sample=dm.sample)


def test_tbptt(tmp_path):
    model = TBPTTModule()

    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
    trainer.fit(model)
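To run only the new test locally, a `pytest.main` invocation mirroring the usual CLI should suffice (a sketch; assumes a dev checkout of the repo with pytest installed):

.. code-block:: python

    # Equivalent to: pytest tests/tests_pytorch/helpers/test_models.py -k test_tbptt
    import pytest

    pytest.main(["tests/tests_pytorch/helpers/test_models.py", "-k", "test_tbptt"])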