From 4e18df142d8765248935a055235b6c4583727e60 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Mon, 6 Jan 2025 09:42:57 +0100 Subject: [PATCH 1/4] Fix TBPTT example --- docs/source-pytorch/common/tbptt.rst | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/source-pytorch/common/tbptt.rst b/docs/source-pytorch/common/tbptt.rst index 063ef8c33d319..13ea40ba4839b 100644 --- a/docs/source-pytorch/common/tbptt.rst +++ b/docs/source-pytorch/common/tbptt.rst @@ -36,12 +36,16 @@ hidden states should be kept in-between each time-dimension split. batch_size = 10 hidden_dim = 20 hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device) + # get optimizer + optimizer = self.optimizers() + for split_batch in range(split_batches): # 4. Perform the optimization in a loop loss, hiddens = self.my_rnn(split_batch, hiddens) - self.backward(loss) - self.optimizer.step() - self.optimizer.zero_grad() + + optimizer.zero_grad() + self.manual_backward(loss) + optimizer.step() # 5. "Truncate" hiddens = hiddens.detach() From 58f0ea623ce5c88733860868c3824c3c98de0276 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Mon, 6 Jan 2025 12:26:02 +0100 Subject: [PATCH 2/4] Make example self-contained --- docs/source-pytorch/common/tbptt.rst | 67 ++++++++++++++++++++++------ 1 file changed, 53 insertions(+), 14 deletions(-) diff --git a/docs/source-pytorch/common/tbptt.rst b/docs/source-pytorch/common/tbptt.rst index 13ea40ba4839b..9ece2a337700c 100644 --- a/docs/source-pytorch/common/tbptt.rst +++ b/docs/source-pytorch/common/tbptt.rst @@ -12,52 +12,91 @@ hidden states should be kept in-between each time-dimension split. .. code-block:: python import torch + import torch.nn as nn + import torch.nn.functional as F import torch.optim as optim + from torch.utils.data import Dataset, DataLoader import pytorch_lightning as pl from pytorch_lightning import LightningModule + + class AverageDataset(Dataset): + def __init__(self, dataset_len=300, sequence_len=100): + self.dataset_len = dataset_len + self.sequence_len = sequence_len + self.input_seq = torch.randn(dataset_len, sequence_len, 10) + top, bottom = self.input_seq.chunk(2, -1) + self.output_seq = top + bottom.roll(shifts=1, dims=-1) + + def __len__(self): + return self.dataset_len + + def __getitem__(self, item): + return self.input_seq[item], self.output_seq[item] + + class LitModel(LightningModule): def __init__(self): super().__init__() + self.batch_size = 10 + self.in_features = 10 + self.out_features = 5 + self.hidden_dim = 20 + # 1. Switch to manual optimization self.automatic_optimization = False - self.truncated_bptt_steps = 10 - self.my_rnn = ParityModuleRNN() # Define RNN model using ParityModuleRNN + + self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True) + self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features) + + def forward(self, x, hs): + seq, hs = self.rnn(x, hs) + return self.linear_out(seq), hs # 2. Remove the `hiddens` argument def training_step(self, batch, batch_idx): - # 3. Split the batch in chunks along the time dimension - split_batches = split_batch(batch, self.truncated_bptt_steps) + x, y = batch + split_x, split_y = [ + x.tensor_split(self.truncated_bptt_steps, dim=1), + y.tensor_split(self.truncated_bptt_steps, dim=1) + ] - batch_size = 10 - hidden_dim = 20 - hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device) - # get optimizer + hiddens = None optimizer = self.optimizers() + losses = [] - for split_batch in range(split_batches): - # 4. Perform the optimization in a loop - loss, hiddens = self.my_rnn(split_batch, hiddens) + # 4. Perform the optimization in a loop + for x, y in zip(split_x, split_y): + y_pred, hiddens = self(x, hiddens) + loss = F.mse_loss(y_pred, y) optimizer.zero_grad() self.manual_backward(loss) optimizer.step() # 5. "Truncate" - hiddens = hiddens.detach() + hiddens = [h.detach() for h in hiddens] + losses.append(loss.detach()) + + avg_loss = sum(losses) / len(losses) + self.log("train_loss", avg_loss, prog_bar=True) # 6. Remove the return of `hiddens` # Returning loss in manual optimization is not needed return None def configure_optimizers(self): - return optim.Adam(self.my_rnn.parameters(), lr=0.001) + return optim.Adam(self.parameters(), lr=0.001) + + def train_dataloader(self): + return DataLoader(AverageDataset(), batch_size=self.batch_size) + if __name__ == "__main__": model = LitModel() trainer = pl.Trainer(max_epochs=5) - trainer.fit(model, train_dataloader) # Define your own dataloader + trainer.fit(model) From 5878332a12c3cf380f102082896ed0a071afa47e Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Mon, 6 Jan 2025 12:30:13 +0100 Subject: [PATCH 3/4] Update imports --- docs/source-pytorch/common/tbptt.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source-pytorch/common/tbptt.rst b/docs/source-pytorch/common/tbptt.rst index 9ece2a337700c..04b8ea33b9235 100644 --- a/docs/source-pytorch/common/tbptt.rst +++ b/docs/source-pytorch/common/tbptt.rst @@ -16,8 +16,8 @@ hidden states should be kept in-between each time-dimension split. import torch.nn.functional as F import torch.optim as optim from torch.utils.data import Dataset, DataLoader - import pytorch_lightning as pl - from pytorch_lightning import LightningModule + + import lightning as L class AverageDataset(Dataset): @@ -35,7 +35,7 @@ hidden states should be kept in-between each time-dimension split. return self.input_seq[item], self.output_seq[item] - class LitModel(LightningModule): + class LitModel(L.LightningModule): def __init__(self): super().__init__() @@ -98,5 +98,5 @@ hidden states should be kept in-between each time-dimension split. if __name__ == "__main__": model = LitModel() - trainer = pl.Trainer(max_epochs=5) + trainer = L.Trainer(max_epochs=5) trainer.fit(model) From 999fef60e96f52f8df83acbe5ab17b34367320e2 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Mon, 6 Jan 2025 12:40:18 +0100 Subject: [PATCH 4/4] Add test --- .../tests_pytorch/helpers/advanced_models.py | 51 +++++++++++++++++++ tests/tests_pytorch/helpers/test_models.py | 9 +++- 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/helpers/advanced_models.py b/tests/tests_pytorch/helpers/advanced_models.py index 4fecf516018c1..ade21004dc635 100644 --- a/tests/tests_pytorch/helpers/advanced_models.py +++ b/tests/tests_pytorch/helpers/advanced_models.py @@ -219,3 +219,54 @@ def configure_optimizers(self): def train_dataloader(self): return DataLoader(MNIST(root=_PATH_DATASETS, train=True, download=True), batch_size=128, num_workers=1) + + +class TBPTTModule(LightningModule): + def __init__(self): + super().__init__() + + self.batch_size = 10 + self.in_features = 10 + self.out_features = 5 + self.hidden_dim = 20 + + self.automatic_optimization = False + self.truncated_bptt_steps = 10 + + self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True) + self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features) + + def forward(self, x, hs): + seq, hs = self.rnn(x, hs) + return self.linear_out(seq), hs + + def training_step(self, batch, batch_idx): + x, y = batch + split_x, split_y = [ + x.tensor_split(self.truncated_bptt_steps, dim=1), + y.tensor_split(self.truncated_bptt_steps, dim=1), + ] + + hiddens = None + optimizer = self.optimizers() + losses = [] + + for x, y in zip(split_x, split_y): + y_pred, hiddens = self(x, hiddens) + loss = F.mse_loss(y_pred, y) + + optimizer.zero_grad() + self.manual_backward(loss) + optimizer.step() + + # "Truncate" + hiddens = [h.detach() for h in hiddens] + losses.append(loss.detach()) + + return + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.001) + + def train_dataloader(self): + return DataLoader(AverageDataset(), batch_size=self.batch_size) diff --git a/tests/tests_pytorch/helpers/test_models.py b/tests/tests_pytorch/helpers/test_models.py index 7e44f79413863..cca2fbdc2e3e0 100644 --- a/tests/tests_pytorch/helpers/test_models.py +++ b/tests/tests_pytorch/helpers/test_models.py @@ -17,7 +17,7 @@ from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel -from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN +from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN, TBPTTModule from tests_pytorch.helpers.datamodules import ClassifDataModule, RegressDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel, RegressionModel @@ -49,3 +49,10 @@ def test_models(tmp_path, data_class, model_class): model.to_torchscript() if data_class: model.to_onnx(os.path.join(tmp_path, "my-model.onnx"), input_sample=dm.sample) + + +def test_tbptt(tmp_path): + model = TBPTTModule() + + trainer = Trainer(default_root_dir=tmp_path, max_epochs=1) + trainer.fit(model)