From 33bb5fe571d3910f238ca9c7e85ea09f6a9a5446 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 27 Aug 2025 07:03:21 +0200 Subject: [PATCH 01/11] add hook order --- docs/source-pytorch/common/hooks.rst | 302 +++++++++++++++++++++++++++ 1 file changed, 302 insertions(+) create mode 100644 docs/source-pytorch/common/hooks.rst diff --git a/docs/source-pytorch/common/hooks.rst b/docs/source-pytorch/common/hooks.rst new file mode 100644 index 0000000000000..14ad033fcea39 --- /dev/null +++ b/docs/source-pytorch/common/hooks.rst @@ -0,0 +1,302 @@ +########################## +Hooks in PyTorch Lightning +########################## + +Hooks in Pytorch Lightning allow you to customize the training, validation, and testing logic of your models. They +provide a way to insert custom behavior at specific points during the training process without modifying the core +training loop. There are several categories of hooks available in PyTorch Lightning: + +1. **Setup/Teardown Hooks**: Called at the beginning and end of training phases +2. **Training Hooks**: Called during the training loop +3. **Validation Hooks**: Called during validation +4. **Test Hooks**: Called during testing +5. **Prediction Hooks**: Called during prediction +6. **Optimizer Hooks**: Called around optimizer operations +7. **Checkpoint Hooks**: Called during checkpoint save/load operations +8. **Exception Hooks**: Called when exceptions occur + +Nearly all hooks can be implemented in three places within your code: + +- **LightningModule**: The main module where you define your model and training logic. +- **Callbacks**: Custom classes that can be passed to the Trainer to handle specific events. +- **Strategy**: Custom strategies for distributed training. + +Importantly, because logic can be place in the same hook but in different places the call order of hooks is in +important to understand. The following order is always used: + +1. Callbacks, called in the order they are passed to the Trainer. +2. ``LightningModule`` +3. Strategy + +.. testcase:: + + from lightning.pytorch import LightningModule, Trainer + from lightning.pytorch.callbacks import Callback + + class MyModel(LightningModule): + def on_train_start(self): + print("Model: Training is starting!") + + class MyCallback(Callback): + def on_train_start(self, trainer, pl_module): + print("Callback: Training is starting!") + + model = MyModel() + callback = MyCallback() + trainer = Trainer(callbacks=[callback]) + trainer.fit(model) + # Output: + # Callback: Training is starting! + # Model: Training is starting! + +.. note:: + There are a few exceptions to this pattern: + + - **on_train_epoch_end**: Non-monitoring callbacks are called first, then ``LightningModule``, then monitoring + callbacks + - **Optimizer hooks** (on_before_backward, on_after_backward, on_before_optimizer_step): Only callbacks and + ``LightningModule`` are called + - Some internal hooks may only call ``LightningModule`` or Strategy + +************************ +Training Loop Hook Order +************************ + +The following diagram shows the execution order of hooks during a typical training loop e.g. calling `trainer.fit()`, +with the source of each hook indicated: + +.. code-block:: text + + Training Process Flow: + + trainer.fit() + │ + ├── setup(stage="fit") + │ └── [Callbacks only] + │ + ├── on_fit_start() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + │ + ├── on_sanity_check_start() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + │ ├── on_validation_start() + │ │ ├── [Callbacks] + │ │ ├── [LightningModule] + │ │ └── [Strategy] + │ ├── on_validation_epoch_start() + │ │ ├── [Callbacks] + │ │ ├── [LightningModule] + │ │ └── [Strategy] + │ │ ├── [for each validation batch] + │ │ │ ├── on_validation_batch_start() + │ │ │ │ ├── [Callbacks] + │ │ │ │ ├── [LightningModule] + │ │ │ │ └── [Strategy] + │ │ │ └── on_validation_batch_end() + │ │ │ ├── [Callbacks] + │ │ │ ├── [LightningModule] + │ │ │ └── [Strategy] + │ │ └── [end validation batches] + │ ├── on_validation_epoch_end() + │ │ ├── [Callbacks] + │ │ ├── [LightningModule] + │ │ └── [Strategy] + │ └── on_validation_end() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + ├── on_sanity_check_end() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + │ + ├── on_train_start() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + │ + ├── [Training Epochs Loop] + │ │ + │ ├── on_train_epoch_start() + │ │ ├── [Callbacks] + │ │ └── [LightningModule] + │ │ + │ ├── [Training Batches Loop] + │ │ │ + │ │ ├── on_train_batch_start() + │ │ │ ├── [Callbacks] + │ │ │ ├── [LightningModule] + │ │ │ └── [Strategy] + │ │ │ + │ │ ├── on_before_zero_grad() + │ │ │ ├── [Callbacks] + │ │ │ └── [LightningModule] + │ │ │ + │ │ ├── [Forward Pass - training_step()] + │ │ │ └── [Strategy only] + │ │ │ + │ │ ├── on_before_backward() + │ │ │ ├── [Callbacks] + │ │ │ └── [LightningModule] + │ │ │ + │ │ ├── [Backward Pass] + │ │ │ └── [Strategy only] + │ │ │ + │ │ ├── on_after_backward() + │ │ │ ├── [Callbacks] + │ │ │ └── [LightningModule] + │ │ │ + │ │ ├── on_before_optimizer_step() + │ │ │ ├── [Callbacks] + │ │ │ └── [LightningModule] + │ │ │ + │ │ ├── [Optimizer Step] + │ │ │ └── [LightningModule only - optimizer_step()] + │ │ │ + │ │ └── on_train_batch_end() + │ │ ├── [Callbacks] + │ │ └── [LightningModule] + │ │ + │ │ [Optional: Validation during training] + │ │ ├── on_validation_start() + │ │ │ ├── [Callbacks] + │ │ │ ├── [LightningModule] + │ │ │ └── [Strategy] + │ │ ├── on_validation_epoch_start() + │ │ │ ├── [Callbacks] + │ │ │ ├── [LightningModule] + │ │ │ └── [Strategy] + │ │ │ ├── [for each validation batch] + │ │ │ │ ├── on_validation_batch_start() + │ │ │ │ │ ├── [Callbacks] + │ │ │ │ │ ├── [LightningModule] + │ │ │ │ │ └── [Strategy] + │ │ │ │ └── on_validation_batch_end() + │ │ │ │ ├── [Callbacks] + │ │ │ │ ├── [LightningModule] + │ │ │ │ └── [Strategy] + │ │ │ └── [end validation batches] + │ │ ├── on_validation_epoch_end() + │ │ │ ├── [Callbacks] + │ │ │ ├── [LightningModule] + │ │ │ └── [Strategy] + │ │ └── on_validation_end() + │ │ ├── [Callbacks] + │ │ ├── [LightningModule] + │ │ └── [Strategy] + │ │ + │ └── on_train_epoch_end() **SPECIAL CASE** + │ ├── [Callbacks - Non-monitoring only] + │ ├── [LightningModule] + │ └── [Callbacks - Monitoring only] + │ + ├── [End Training Epochs] + │ + ├── on_train_end() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + │ + ├── on_fit_end() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + │ + └── teardown(stage="fit") + └── [Callbacks only] + +*********************** +Testing Loop Hook Order +*********************** + +When running tests with ``trainer.test()``: + +.. code-block:: text + + trainer.test() + │ + ├── setup(stage="test") + │ └── [Callbacks only] + ├── on_test_start() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + │ + ├── [Test Epochs Loop] + │ │ + │ ├── on_test_epoch_start() + │ │ ├── [Callbacks] + │ │ ├── [LightningModule] + │ │ └── [Strategy] + │ │ + │ ├── [Test Batches Loop] + │ │ │ + │ │ ├── on_test_batch_start() + │ │ │ ├── [Callbacks] + │ │ │ ├── [LightningModule] + │ │ │ └── [Strategy] + │ │ │ + │ │ └── on_test_batch_end() + │ │ ├── [Callbacks] + │ │ ├── [LightningModule] + │ │ └── [Strategy] + │ │ + │ └── on_test_epoch_end() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + │ + ├── on_test_end() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + └── teardown(stage="test") + └── [Callbacks only] + +************************** +Prediction Loop Hook Order +************************** + +When running predictions with ``trainer.predict()``: + +.. code-block:: text + + trainer.predict() + │ + ├── setup(stage="predict") + │ └── [Callbacks only] + ├── on_predict_start() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + │ + ├── [Prediction Epochs Loop] + │ │ + │ ├── on_predict_epoch_start() + │ │ ├── [Callbacks] + │ │ └── [LightningModule] + │ │ + │ ├── [Prediction Batches Loop] + │ │ │ + │ │ ├── on_predict_batch_start() + │ │ │ ├── [Callbacks] + │ │ │ └── [LightningModule] + │ │ │ + │ │ └── on_predict_batch_end() + │ │ ├── [Callbacks] + │ │ └── [LightningModule] + │ │ + │ └── on_predict_epoch_end() + │ ├── [Callbacks] + │ └── [LightningModule] + │ + ├── on_predict_end() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + └── teardown(stage="predict") + └── [Callbacks only] From 1ffa4bfa752bbedc4354c27933dcc74de099d864 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 27 Aug 2025 07:03:32 +0200 Subject: [PATCH 02/11] add to index --- docs/source-pytorch/glossary/index.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/source-pytorch/glossary/index.rst b/docs/source-pytorch/glossary/index.rst index 5f06d3a7ea3a3..349d2d0326d0c 100644 --- a/docs/source-pytorch/glossary/index.rst +++ b/docs/source-pytorch/glossary/index.rst @@ -20,6 +20,7 @@ FSDP <../advanced/model_parallel/fsdp> GPU <../accelerators/gpu> Half precision <../common/precision> + Hooks <../common/hooks> HPU <../integrations/hpu/index> Inference <../deploy/production_intermediate> Lightning CLI <../cli/lightning_cli> @@ -179,6 +180,13 @@ Glossary :button_link: ../common/precision.html :height: 100 +.. displayitem:: + :header: Hooks + :description: How to customize the training, validation, and testing logic + :col_css: col-md-12 + :button_link: ../common/hooks.html + :height: 100 + .. displayitem:: :header: HPU :description: Habana Gaudi AI Processor Unit for faster training From f02c944aae90a2f4a269a83eee0f1ffd0065550f Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Wed, 27 Aug 2025 12:04:10 +0200 Subject: [PATCH 03/11] Apply suggestions from code review --- docs/source-pytorch/common/hooks.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source-pytorch/common/hooks.rst b/docs/source-pytorch/common/hooks.rst index 14ad033fcea39..0c9022cf9a523 100644 --- a/docs/source-pytorch/common/hooks.rst +++ b/docs/source-pytorch/common/hooks.rst @@ -28,7 +28,7 @@ important to understand. The following order is always used: 2. ``LightningModule`` 3. Strategy -.. testcase:: +.. testcode:: from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import Callback From 945c6e117904005ef8922fba62ac7801e69476b9 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Mon, 1 Sep 2025 12:52:57 +0200 Subject: [PATCH 04/11] fixed --- docs/source-pytorch/common/hooks.rst | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/docs/source-pytorch/common/hooks.rst b/docs/source-pytorch/common/hooks.rst index 0c9022cf9a523..be457e2c4255e 100644 --- a/docs/source-pytorch/common/hooks.rst +++ b/docs/source-pytorch/common/hooks.rst @@ -52,10 +52,8 @@ important to understand. The following order is always used: .. note:: There are a few exceptions to this pattern: - - **on_train_epoch_end**: Non-monitoring callbacks are called first, then ``LightningModule``, then monitoring - callbacks - - **Optimizer hooks** (on_before_backward, on_after_backward, on_before_optimizer_step): Only callbacks and - ``LightningModule`` are called + - **on_train_epoch_end**: Non-monitoring callbacks are called first, then ``LightningModule``, then monitoring callbacks + - **Optimizer hooks** (on_before_backward, on_after_backward, on_before_optimizer_step): Only callbacks and ``LightningModule`` are called - Some internal hooks may only call ``LightningModule`` or Strategy ************************ From dba5219039fff218c18cc7e8f847e728415c1d0d Mon Sep 17 00:00:00 2001 From: Jirka B Date: Mon, 1 Sep 2025 13:03:33 +0200 Subject: [PATCH 05/11] BoringModel --- docs/source-pytorch/common/hooks.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source-pytorch/common/hooks.rst b/docs/source-pytorch/common/hooks.rst index be457e2c4255e..b4e2a3d692cf6 100644 --- a/docs/source-pytorch/common/hooks.rst +++ b/docs/source-pytorch/common/hooks.rst @@ -30,10 +30,11 @@ important to understand. The following order is always used: .. testcode:: - from lightning.pytorch import LightningModule, Trainer + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import Callback + from lightning.pytorch.demos BoringModel - class MyModel(LightningModule): + class MyModel(BoringModel): def on_train_start(self): print("Model: Training is starting!") From faecbfffb3bc612396f04f63265504f0d1547157 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Mon, 1 Sep 2025 13:12:45 +0200 Subject: [PATCH 06/11] typo --- docs/source-pytorch/common/hooks.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source-pytorch/common/hooks.rst b/docs/source-pytorch/common/hooks.rst index b4e2a3d692cf6..2e65574d5c1c4 100644 --- a/docs/source-pytorch/common/hooks.rst +++ b/docs/source-pytorch/common/hooks.rst @@ -32,7 +32,7 @@ important to understand. The following order is always used: from lightning.pytorch import Trainer from lightning.pytorch.callbacks import Callback - from lightning.pytorch.demos BoringModel + from lightning.pytorch.demos import BoringModel class MyModel(BoringModel): def on_train_start(self): From 39de433f9d32f73257e0a70ed9cc5fb2168349da Mon Sep 17 00:00:00 2001 From: Jirka B Date: Mon, 1 Sep 2025 13:19:42 +0200 Subject: [PATCH 07/11] params --- docs/source-pytorch/common/hooks.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source-pytorch/common/hooks.rst b/docs/source-pytorch/common/hooks.rst index 2e65574d5c1c4..956b43b882ce3 100644 --- a/docs/source-pytorch/common/hooks.rst +++ b/docs/source-pytorch/common/hooks.rst @@ -44,7 +44,7 @@ important to understand. The following order is always used: model = MyModel() callback = MyCallback() - trainer = Trainer(callbacks=[callback]) + trainer = Trainer(callbacks=[callback], logger=False, max_epochs=1) trainer.fit(model) # Output: # Callback: Training is starting! From a79c134d715e5ef09b0aa4a47b55d061a5494324 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Mon, 1 Sep 2025 13:31:48 +0200 Subject: [PATCH 08/11] testoutput --- docs/source-pytorch/common/hooks.rst | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/docs/source-pytorch/common/hooks.rst b/docs/source-pytorch/common/hooks.rst index 956b43b882ce3..421a30233550a 100644 --- a/docs/source-pytorch/common/hooks.rst +++ b/docs/source-pytorch/common/hooks.rst @@ -46,9 +46,26 @@ important to understand. The following order is always used: callback = MyCallback() trainer = Trainer(callbacks=[callback], logger=False, max_epochs=1) trainer.fit(model) - # Output: - # Callback: Training is starting! - # Model: Training is starting! + +.. testoutput:: + :hide: + :options: -ELLIPSIS, +NORMALIZE_WHITESPACE + ┏━━━┳━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓ + ┃ ┃ Name ┃ Type ┃ Params ┃ Mode ┃ FLOPs ┃ + ┡━━━╇━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩ + │ 0 │ layer │ Linear │ 66 │ train │ 0 │ + └───┴───────┴────────┴────────┴───────┴───────┘ + Trainable params: 66 + Non-trainable params: 0 + Total params: 66 + Total estimated model params size (MB): 0 + Modules in train mode: 1 + Modules in eval mode: 0 + Total FLOPs: 0 + Callback: Training is starting! + Model: Training is starting! + Epoch 0/0 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 64/64 ... + .. note:: There are a few exceptions to this pattern: From 8c5e264dc665cfe90ec2c80dfa0e7af9b6cedb63 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Mon, 1 Sep 2025 13:40:10 +0200 Subject: [PATCH 09/11] space --- docs/source-pytorch/common/hooks.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source-pytorch/common/hooks.rst b/docs/source-pytorch/common/hooks.rst index 421a30233550a..933600318fe68 100644 --- a/docs/source-pytorch/common/hooks.rst +++ b/docs/source-pytorch/common/hooks.rst @@ -50,6 +50,7 @@ important to understand. The following order is always used: .. testoutput:: :hide: :options: -ELLIPSIS, +NORMALIZE_WHITESPACE + ┏━━━┳━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓ ┃ ┃ Name ┃ Type ┃ Params ┃ Mode ┃ FLOPs ┃ ┡━━━╇━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩ From 0477fe81ae2dabb8d48475da885fef4e7f715f6f Mon Sep 17 00:00:00 2001 From: Jirka B Date: Mon, 1 Sep 2025 14:07:34 +0200 Subject: [PATCH 10/11] ELLIPSIS --- docs/source-pytorch/common/hooks.rst | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/docs/source-pytorch/common/hooks.rst b/docs/source-pytorch/common/hooks.rst index 933600318fe68..7be21a3f67033 100644 --- a/docs/source-pytorch/common/hooks.rst +++ b/docs/source-pytorch/common/hooks.rst @@ -51,21 +51,10 @@ important to understand. The following order is always used: :hide: :options: -ELLIPSIS, +NORMALIZE_WHITESPACE - ┏━━━┳━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓ - ┃ ┃ Name ┃ Type ┃ Params ┃ Mode ┃ FLOPs ┃ - ┡━━━╇━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩ - │ 0 │ layer │ Linear │ 66 │ train │ 0 │ - └───┴───────┴────────┴────────┴───────┴───────┘ - Trainable params: 66 - Non-trainable params: 0 - Total params: 66 - Total estimated model params size (MB): 0 - Modules in train mode: 1 - Modules in eval mode: 0 - Total FLOPs: 0 + ... Callback: Training is starting! Model: Training is starting! - Epoch 0/0 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 64/64 ... + Epoch 0/0 ... • ... .. note:: From acc65379525a3ac7213428e415740c42d340d832 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Mon, 1 Sep 2025 14:08:42 +0200 Subject: [PATCH 11/11] +ELLIPSIS --- docs/source-pytorch/common/hooks.rst | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/source-pytorch/common/hooks.rst b/docs/source-pytorch/common/hooks.rst index 7be21a3f67033..79b1118a0dcf4 100644 --- a/docs/source-pytorch/common/hooks.rst +++ b/docs/source-pytorch/common/hooks.rst @@ -49,12 +49,17 @@ important to understand. The following order is always used: .. testoutput:: :hide: - :options: -ELLIPSIS, +NORMALIZE_WHITESPACE + :options: +ELLIPSIS, +NORMALIZE_WHITESPACE + ┏━━━┳━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓ + ┃ ┃ Name ┃ Type ┃ Params ┃ Mode ┃ FLOPs ┃ + ┡━━━╇━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩ + │ 0 │ layer │ Linear │ 66 │ train │ 0 │ + └───┴───────┴────────┴────────┴───────┴───────┘ ... Callback: Training is starting! Model: Training is starting! - Epoch 0/0 ... • ... + Epoch 0/0 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 64/64 ... .. note::