From 1fc2cfaa509dc70cb491e8b3cf74e6c57502945f Mon Sep 17 00:00:00 2001
From: Alan Chu
Date: Fri, 15 Nov 2024 01:37:29 -0800
Subject: [PATCH 01/13] Add doc for TBPTT

---
 .../common/lightning_module.rst | 50 +++++++++++++++++++
 1 file changed, 50 insertions(+)

diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst
index 15e3af75d7aec..2fc6aedeb2273 100644
--- a/docs/source-pytorch/common/lightning_module.rst
+++ b/docs/source-pytorch/common/lightning_module.rst
@@ -1052,6 +1052,56 @@ Set and access example_input_array, which basically represents a single batch.
         # generate some images using the example_input_array
         gen_images = self.generator(self.example_input_array)
 
+truncated_bptt_steps
+~~~~~~~~~~~~~~~~~~~~
+
+Truncated Backpropagation Through Time (TBPTT) performs perform backpropagation every k steps of
+a much longer sequence. This is made possible by passing training batches
+split along the time dimension into chunks of size k to the
+``training_step``. In order to keep the same forward propagation behavior, the
+hidden state should be carried over between the time-dimension splits.
+
+(`Williams et al. "An efficient gradient-based algorithm for on-line training of
+recurrent network trajectories."
+`_)
+
+`Tutorial `_
+
+.. code-block:: python
+    import lightning as L
+
+    class LitModel(L.LightningModule):
+
+        def __init__(self):
+            super().__init__()
+
+            # 1. Switch to manual optimization
+            self.automatic_optimization = False
+
+            self.truncated_bptt_steps = 10
+            self.my_rnn = ...
+
+        # 2. Remove the `hiddens` argument
+        def training_step(self, batch, batch_idx):
+
+            # 3. Split the batch in chunks along the time dimension
+            split_batches = split_batch(batch, self.truncated_bptt_steps)
+
+            hiddens = ...  # 3. Choose the initial hidden state
+            for split_batch in split_batches:
+                # 4. Perform the optimization in a loop
+                loss, hiddens = self.my_rnn(split_batch, hiddens)
+                self.manual_backward(loss)
+                optimizer.step()
+                optimizer.zero_grad()
+
+                # 5. "Truncate"
+                hiddens = hiddens.detach()
+
+            # 6. Remove the return of `hiddens`
+            # Returning loss in manual optimization is not needed
+            return None
+
 --------------
 
 .. _lightning_hooks:

From efae604f4097a30bf8f154d2532469197a5ecf5e Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 15 Nov 2024 09:39:37 +0000
Subject: [PATCH 02/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 docs/source-pytorch/common/lightning_module.rst | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst
index 2fc6aedeb2273..0b045081730d0 100644
--- a/docs/source-pytorch/common/lightning_module.rst
+++ b/docs/source-pytorch/common/lightning_module.rst
@@ -1074,19 +1074,19 @@ recurrent network trajectories."
 
         def __init__(self):
             super().__init__()
-
+
             # 1. Switch to manual optimization
             self.automatic_optimization = False
-
+
             self.truncated_bptt_steps = 10
             self.my_rnn = ...
-
+
         # 2. Remove the `hiddens` argument
         def training_step(self, batch, batch_idx):
-
+
             # 3. Split the batch in chunks along the time dimension
             split_batches = split_batch(batch, self.truncated_bptt_steps)
-
+
             hiddens = ...  # 3. Choose the initial hidden state
             for split_batch in split_batches:
                 # 4. Perform the optimization in a loop
                 loss, hiddens = self.my_rnn(split_batch, hiddens)
                self.manual_backward(loss)
                 optimizer.step()
                 optimizer.zero_grad()
-
+
                 # 5. "Truncate"
                 hiddens = hiddens.detach()
-
+
             # 6. Remove the return of `hiddens`
             # Returning loss in manual optimization is not needed
             return None

From b296ec0f095940aed365d7c64c985536313a29ef Mon Sep 17 00:00:00 2001
From: Alan Chu
Date: Fri, 15 Nov 2024 22:06:29 -0800
Subject: [PATCH 03/13] remove url to prevent linting error

---
 docs/source-pytorch/common/lightning_module.rst | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst
index 0b045081730d0..0f80e7d72c86f 100644
--- a/docs/source-pytorch/common/lightning_module.rst
+++ b/docs/source-pytorch/common/lightning_module.rst
@@ -1061,11 +1061,6 @@ split along the time dimension into chunks of size k to the
 ``training_step``. In order to keep the same forward propagation behavior, the
 hidden state should be carried over between the time-dimension splits.
 
-(`Williams et al. "An efficient gradient-based algorithm for on-line training of
-recurrent network trajectories."
-`_)
-
-`Tutorial `_
 
 .. code-block:: python
     import lightning as L

From 7f9673dc27a18e57416d39ae55b6eba8db88e419 Mon Sep 17 00:00:00 2001
From: Alan Chu
Date: Sat, 16 Nov 2024 04:45:21 -0800
Subject: [PATCH 04/13] attempt to fix linter

---
 docs/source-pytorch/common/lightning_module.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst
index 0f80e7d72c86f..bc664d2f1194f 100644
--- a/docs/source-pytorch/common/lightning_module.rst
+++ b/docs/source-pytorch/common/lightning_module.rst
@@ -1062,7 +1062,7 @@ split along the time dimension into chunks of size k to the
 hidden state should be carried over between the time-dimension splits.
 
 
-.. code-block:: python
+.. testcode:: python
     import lightning as L
 
     class LitModel(L.LightningModule):

From 54c89a8add8c38701404c15867a5077901932d2e Mon Sep 17 00:00:00 2001
From: Alan Chu
Date: Mon, 25 Nov 2024 18:37:37 +0000
Subject: [PATCH 05/13] add tbptt.rst file

---
 .../common/lightning_module.rst | 45 -------------------
 docs/source-pytorch/common/tbptt.rst | 44 ++++++++++++++++++
 2 files changed, 44 insertions(+), 45 deletions(-)
 create mode 100644 docs/source-pytorch/common/tbptt.rst

diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst
index bc664d2f1194f..15e3af75d7aec 100644
--- a/docs/source-pytorch/common/lightning_module.rst
+++ b/docs/source-pytorch/common/lightning_module.rst
@@ -1052,51 +1052,6 @@ Set and access example_input_array, which basically represents a single batch.
         # generate some images using the example_input_array
         gen_images = self.generator(self.example_input_array)
 
-truncated_bptt_steps
-~~~~~~~~~~~~~~~~~~~~
-
-Truncated Backpropagation Through Time (TBPTT) performs perform backpropagation every k steps of
-a much longer sequence. This is made possible by passing training batches
-split along the time dimension into chunks of size k to the
-``training_step``. In order to keep the same forward propagation behavior, the
-hidden state should be carried over between the time-dimension splits.
-
-
-.. testcode:: python
-    import lightning as L
-
-    class LitModel(L.LightningModule):
-
-        def __init__(self):
-            super().__init__()
-
-            # 1. Switch to manual optimization
-            self.automatic_optimization = False
-
-            self.truncated_bptt_steps = 10
-            self.my_rnn = ...
-
-        # 2. Remove the `hiddens` argument
-        def training_step(self, batch, batch_idx):
-
-            # 3. Split the batch in chunks along the time dimension
-            split_batches = split_batch(batch, self.truncated_bptt_steps)
-
-            hiddens = ...  # 3. Choose the initial hidden state
-            for split_batch in split_batches:
-                # 4. Perform the optimization in a loop
-                loss, hiddens = self.my_rnn(split_batch, hiddens)
-                self.manual_backward(loss)
-                optimizer.step()
-                optimizer.zero_grad()
-
-                # 5. "Truncate"
-                hiddens = hiddens.detach()
-
-            # 6. Remove the return of `hiddens`
-            # Returning loss in manual optimization is not needed
-            return None
-
 --------------
 
 .. _lightning_hooks:

diff --git a/docs/source-pytorch/common/tbptt.rst b/docs/source-pytorch/common/tbptt.rst
new file mode 100644
index 0000000000000..88b6055b19fe6
--- /dev/null
+++ b/docs/source-pytorch/common/tbptt.rst
@@ -0,0 +1,44 @@
+truncated_bptt_steps
+~~~~~~~~~~~~~~~~~~~~
+
+Truncated Backpropagation Through Time (TBPTT) performs perform backpropagation every k steps of
+a much longer sequence. This is made possible by passing training batches
+split along the time dimension into chunks of size k to the
+``training_step``. In order to keep the same forward propagation behavior, the
+hidden state should be carried over between the time-dimension splits.
+
+
+.. testcode:: python
+    import lightning as L
+
+    class LitModel(L.LightningModule):
+
+        def __init__(self):
+            super().__init__()
+
+            # 1. Switch to manual optimization
+            self.automatic_optimization = False
+
+            self.truncated_bptt_steps = 10
+            self.my_rnn = ...
+
+        # 2. Remove the `hiddens` argument
+        def training_step(self, batch, batch_idx):
+
+            # 3. Split the batch in chunks along the time dimension
+            split_batches = split_batch(batch, self.truncated_bptt_steps)
+
+            hiddens = ...  # 3. Choose the initial hidden state
+            for split_batch in split_batches:
+                # 4. Perform the optimization in a loop
+                loss, hiddens = self.my_rnn(split_batch, hiddens)
+                self.manual_backward(loss)
+                optimizer.step()
+                optimizer.zero_grad()
+
+                # 5. "Truncate"
+                hiddens = hiddens.detach()
+
+            # 6. Remove the return of `hiddens`
+            # Returning loss in manual optimization is not needed
+            return None
\ No newline at end of file

From 5c367c55ef71a4e442e4bd427dc068d8d61cf92b Mon Sep 17 00:00:00 2001
From: Alan Chu
Date: Mon, 25 Nov 2024 19:11:34 +0000
Subject: [PATCH 06/13] adjust doc:

---
 docs/source-pytorch/common/tbptt.rst | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/docs/source-pytorch/common/tbptt.rst b/docs/source-pytorch/common/tbptt.rst
index 88b6055b19fe6..70791b2e7008b 100644
--- a/docs/source-pytorch/common/tbptt.rst
+++ b/docs/source-pytorch/common/tbptt.rst
@@ -1,5 +1,6 @@
-truncated_bptt_steps
-~~~~~~~~~~~~~~~~~~~~
+##############
+Truncated Backpropagation Through Time (TBPTT)
+##############
 
 Truncated Backpropagation Through Time (TBPTT) performs perform backpropagation every k steps of
 a much longer sequence. This is made possible by passing training batches
 split along the time dimension into chunks of size k to the
@@ -8,7 +9,7 @@ split along the time dimension into chunks of size k to the
 hidden state should be carried over between the time-dimension splits.
 
 
-.. testcode:: python
+.. code-block:: python
     import lightning as L
 
     class LitModel(L.LightningModule):

From 610809c5c8e1f08a3887774fd25da3b6d7c1e6e7 Mon Sep 17 00:00:00 2001
From: Alan Chu
Date: Mon, 25 Nov 2024 19:28:07 +0000
Subject: [PATCH 07/13] nit

---
 docs/source-pytorch/common/tbptt.rst | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/docs/source-pytorch/common/tbptt.rst b/docs/source-pytorch/common/tbptt.rst
index 70791b2e7008b..b98499bdb0443 100644
--- a/docs/source-pytorch/common/tbptt.rst
+++ b/docs/source-pytorch/common/tbptt.rst
@@ -10,9 +10,8 @@ hidden state should be carried over between the time-dimension splits.
 
 .. code-block:: python
-    import lightning as L
-
-    class LitModel(L.LightningModule):
+
+    class LitModel(LightningModule):
 
         def __init__(self):
             super().__init__()

From 48dbdd29a905e7a74c3dbe19b55589de8f7b26c7 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 25 Nov 2024 19:28:42 +0000
Subject: [PATCH 08/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 docs/source-pytorch/common/tbptt.rst | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/source-pytorch/common/tbptt.rst b/docs/source-pytorch/common/tbptt.rst
index b98499bdb0443..6c23523033ef5 100644
--- a/docs/source-pytorch/common/tbptt.rst
+++ b/docs/source-pytorch/common/tbptt.rst
@@ -10,7 +10,7 @@ hidden state should be carried over between the time-dimension splits.
 
 .. code-block:: python
-
+
     class LitModel(LightningModule):
 
         def __init__(self):
@@ -41,4 +41,4 @@ hidden state should be carried over between the time-dimension splits.
 
             # 6. Remove the return of `hiddens`
             # Returning loss in manual optimization is not needed
-            return None
\ No newline at end of file
+            return None

From 5262534be63c59d6860e6ac40f35cc1f5a31872d Mon Sep 17 00:00:00 2001
From: Alan Chu
Date: Tue, 26 Nov 2024 01:37:55 +0000
Subject: [PATCH 09/13] make example easily copy and runnable

---
 docs/source-pytorch/common/tbptt.rst | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/docs/source-pytorch/common/tbptt.rst b/docs/source-pytorch/common/tbptt.rst
index 6c23523033ef5..7b5531da98fbc 100644
--- a/docs/source-pytorch/common/tbptt.rst
+++ b/docs/source-pytorch/common/tbptt.rst
@@ -10,6 +10,9 @@ hidden state should be carried over between the time-dimension splits.
 
 .. code-block:: python
+    import torch.optim as optim
+    import pytorch_lightning as pl
+    from pytorch_lightning import LightningModule
 
     class LitModel(LightningModule):
 
@@ -20,7 +23,7 @@ hidden state should be carried over between the time-dimension splits.
             self.automatic_optimization = False
 
             self.truncated_bptt_steps = 10
-            self.my_rnn = ...
+            self.my_rnn = ParityModuleRNN()  # Define RNN model using ParityModuleRNN
 
         # 2. Remove the `hiddens` argument
         def training_step(self, batch, batch_idx):
 
@@ -28,13 +31,13 @@ hidden state should be carried over between the time-dimension splits.
             # 3. Split the batch in chunks along the time dimension
             split_batches = split_batch(batch, self.truncated_bptt_steps)
 
-            hiddens = ...  # 3. Choose the initial hidden state
+            hiddens = ...  # Choose the initial hidden state
             for split_batch in split_batches:
                 # 4. Perform the optimization in a loop
                 loss, hiddens = self.my_rnn(split_batch, hiddens)
                 self.manual_backward(loss)
-                optimizer.step()
-                optimizer.zero_grad()
+                self.optimizers().step()
+                self.optimizers().zero_grad()
 
"Truncate" hiddens = hiddens.detach() @@ -42,3 +45,11 @@ hidden states should be kept in-between each time-dimension split. # 6. Remove the return of `hiddens` # Returning loss in manual optimization is not needed return None + + def configure_optimizers(self): + return optim.Adam(self.my_rnn.parameters(), lr=0.001) + + if __name__ == "__main__": + model = LitModel() + trainer = pl.Trainer(max_epochs=5) + trainer.fit(model, train_dataloader) # Define your own dataloader \ No newline at end of file From ce9790e62592d15ae7e899e1bdbf4d2bad119a64 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Nov 2024 01:38:58 +0000 Subject: [PATCH 10/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source-pytorch/common/tbptt.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source-pytorch/common/tbptt.rst b/docs/source-pytorch/common/tbptt.rst index 7b5531da98fbc..91cd72126413b 100644 --- a/docs/source-pytorch/common/tbptt.rst +++ b/docs/source-pytorch/common/tbptt.rst @@ -52,4 +52,4 @@ hidden states should be kept in-between each time-dimension split. if __name__ == "__main__": model = LitModel() trainer = pl.Trainer(max_epochs=5) - trainer.fit(model, train_dataloader) # Define your own dataloader \ No newline at end of file + trainer.fit(model, train_dataloader) # Define your own dataloader From ef2a826a8285571681985d1c3c197846fe93e500 Mon Sep 17 00:00:00 2001 From: Alan Chu Date: Tue, 26 Nov 2024 15:50:27 +0000 Subject: [PATCH 11/13] address comments --- docs/source-pytorch/common/tbptt.rst | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/source-pytorch/common/tbptt.rst b/docs/source-pytorch/common/tbptt.rst index 91cd72126413b..6b2ec1bdc6262 100644 --- a/docs/source-pytorch/common/tbptt.rst +++ b/docs/source-pytorch/common/tbptt.rst @@ -10,6 +10,8 @@ hidden states should be kept in-between each time-dimension split. .. code-block:: python + + import torch import torch.optim as optim import pytorch_lightning as pl from pytorch_lightning import LightningModule @@ -31,7 +33,9 @@ hidden states should be kept in-between each time-dimension split. # 3. Split the batch in chunks along the time dimension split_batches = split_batch(batch, self.truncated_bptt_steps) - hiddens = ... # Choose the initial hidden state + batch_size = 10 + hidden_dim = 20 + hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device) for split_batch in range(split_batches): # 4. Perform the optimization in a loop loss, hiddens = self.my_rnn(split_batch, hiddens) From 8aee16a553f6dfb2414efd64f0c12a1f1a600821 Mon Sep 17 00:00:00 2001 From: Alan Chu Date: Mon, 9 Dec 2024 18:07:39 +0000 Subject: [PATCH 12/13] fix doc test warning --- docs/source-pytorch/common/index.rst | 7 +++++++ docs/source-pytorch/common/tbptt.rst | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/source-pytorch/common/index.rst b/docs/source-pytorch/common/index.rst index 738e971aec532..42f7adcc2ed24 100644 --- a/docs/source-pytorch/common/index.rst +++ b/docs/source-pytorch/common/index.rst @@ -202,6 +202,13 @@ How-to Guides :col_css: col-md-4 :height: 180 +.. displayitem:: + :header: Truncated Back-Propagation Through Time + :description: Efficiently step through time when training recurrent models + :button_link: ../common/tbtt.html + :col_css: col-md-4 + :height: 180 + .. 
 .. raw:: html
 
         </div>
     </div>

diff --git a/docs/source-pytorch/common/tbptt.rst b/docs/source-pytorch/common/tbptt.rst
index 6b2ec1bdc6262..6a39f8ec4759a 100644
--- a/docs/source-pytorch/common/tbptt.rst
+++ b/docs/source-pytorch/common/tbptt.rst
@@ -1,6 +1,6 @@
-##############
+##############################################
 Truncated Backpropagation Through Time (TBPTT)
-##############
+##############################################
 
 Truncated Backpropagation Through Time (TBPTT) performs perform backpropagation every k steps of
 a much longer sequence. This is made possible by passing training batches
 split along the time dimension into chunks of size k to the

From 0d2a38415571377f3ab3785d5ad73c32f764b2ca Mon Sep 17 00:00:00 2001
From: Luca Antiga
Date: Tue, 10 Dec 2024 10:32:27 +0100
Subject: [PATCH 13/13] Update docs/source-pytorch/common/tbptt.rst

---
 docs/source-pytorch/common/tbptt.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source-pytorch/common/tbptt.rst b/docs/source-pytorch/common/tbptt.rst
index 6a39f8ec4759a..063ef8c33d319 100644
--- a/docs/source-pytorch/common/tbptt.rst
+++ b/docs/source-pytorch/common/tbptt.rst
@@ -2,7 +2,7 @@ Truncated Backpropagation Through Time (TBPTT)
 ##############################################
 
-Truncated Backpropagation Through Time (TBPTT) performs perform backpropagation every k steps of
+Truncated Backpropagation Through Time (TBPTT) performs backpropagation every k steps of
 a much longer sequence. This is made possible by passing training batches
 split along the time dimension into chunks of size k to the
 ``training_step``. In order to keep the same forward propagation behavior, the
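A closing note on the example: ``training_step`` calls a ``split_batch`` helper that is never defined anywhere in the patches above. The sketch below is a hypothetical implementation, assuming the batch is a single tensor whose time dimension is ``dim=1``; a real implementation would depend on the batch layout.

.. code-block:: python

    import torch


    def split_batch(batch, split_size):
        # Hypothetical helper: chunk the sequence along the time dimension
        # (dim=1); the last chunk may be shorter than `split_size`.
        return list(torch.split(batch, split_size, dim=1))

Each chunk then drives one optimizer step inside the ``for split_batch in split_batches:`` loop, with the detached hidden state carried over from one chunk to the next. If the batch were an ``(inputs, targets)`` tuple, both tensors would need to be split the same way.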