##########################
Hooks in PyTorch Lightning
##########################

Hooks in PyTorch Lightning allow you to customize the training, validation, and testing logic of your models. They
provide a way to insert custom behavior at specific points during the training process without modifying the core
training loop. There are several categories of hooks available in PyTorch Lightning:

1. **Setup/Teardown Hooks**: Called at the beginning and end of training phases
2. **Training Hooks**: Called during the training loop
3. **Validation Hooks**: Called during validation
4. **Test Hooks**: Called during testing
5. **Prediction Hooks**: Called during prediction
6. **Optimizer Hooks**: Called around optimizer operations
7. **Checkpoint Hooks**: Called during checkpoint save/load operations
8. **Exception Hooks**: Called when exceptions occur

Nearly all hooks can be implemented in three places within your code:

- **LightningModule**: The main module where you define your model and training logic.
- **Callbacks**: Custom classes that can be passed to the Trainer to handle specific events.
- **Strategy**: Custom strategies for distributed training.

Importantly, because logic for the same hook can live in several of these places, it is important to understand the
order in which the implementations are called. The following order is always used:

1. Callbacks, called in the order they are passed to the Trainer.
2. ``LightningModule``
3. Strategy

.. testcode::

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import Callback
    from lightning.pytorch.demos import BoringModel


    class MyModel(BoringModel):
        def on_train_start(self):
            print("Model: Training is starting!")


    class MyCallback(Callback):
        def on_train_start(self, trainer, pl_module):
            print("Callback: Training is starting!")


    model = MyModel()
    callback = MyCallback()
    trainer = Trainer(callbacks=[callback], logger=False, max_epochs=1)
    trainer.fit(model)

.. testoutput::
    :hide:
    :options: +ELLIPSIS, +NORMALIZE_WHITESPACE

    ┏━━━┳━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓
    ┃   ┃ Name  ┃ Type   ┃ Params ┃ Mode  ┃ FLOPs ┃
    ┡━━━╇━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩
    │ 0 │ layer │ Linear │ 66     │ train │ 0     │
    └───┴───────┴────────┴────────┴───────┴───────┘
    ...
    Callback: Training is starting!
    Model: Training is starting!
    Epoch 0/0 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 64/64 ...


.. note::
    There are a few exceptions to this pattern:

    - ``on_train_epoch_end``: non-monitoring callbacks are called first, then the ``LightningModule``, then monitoring callbacks (e.g. ``ModelCheckpoint``, ``EarlyStopping``)
    - Optimizer hooks (``on_before_backward``, ``on_after_backward``, ``on_before_optimizer_step``): only callbacks and the ``LightningModule`` are called
    - Some internal hooks may only call the ``LightningModule`` or the Strategy

************************
Training Loop Hook Order
************************

The following diagram shows the execution order of hooks during a typical training loop, i.e. a call to
``trainer.fit()``, with the source of each hook indicated:

.. code-block:: text

    Training Process Flow:

    trainer.fit()
    ├── setup(stage="fit")
    │   ├── [Callbacks]
    │   └── [LightningModule]
    ├── on_fit_start()
    │   ├── [Callbacks]
    │   ├── [LightningModule]
    │   └── [Strategy]
    ├── on_sanity_check_start()
    │   ├── [Callbacks]
    │   ├── [LightningModule]
    │   └── [Strategy]
    ├── [Sanity Check Validation]
    │   ├── on_validation_start()
    │   │   ├── [Callbacks]
    │   │   ├── [LightningModule]
    │   │   └── [Strategy]
    │   ├── on_validation_epoch_start()
    │   │   ├── [Callbacks]
    │   │   ├── [LightningModule]
    │   │   └── [Strategy]
    │   ├── [for each validation batch]
    │   │   ├── on_validation_batch_start()
    │   │   │   ├── [Callbacks]
    │   │   │   ├── [LightningModule]
    │   │   │   └── [Strategy]
    │   │   └── on_validation_batch_end()
    │   │       ├── [Callbacks]
    │   │       ├── [LightningModule]
    │   │       └── [Strategy]
    │   ├── on_validation_epoch_end()
    │   │   ├── [Callbacks]
    │   │   ├── [LightningModule]
    │   │   └── [Strategy]
    │   └── on_validation_end()
    │       ├── [Callbacks]
    │       ├── [LightningModule]
    │       └── [Strategy]
    ├── on_sanity_check_end()
    │   ├── [Callbacks]
    │   ├── [LightningModule]
    │   └── [Strategy]
    ├── on_train_start()
    │   ├── [Callbacks]
    │   ├── [LightningModule]
    │   └── [Strategy]
    ├── [Training Epochs Loop]
    │   │
    │   ├── on_train_epoch_start()
    │   │   ├── [Callbacks]
    │   │   └── [LightningModule]
    │   │
    │   ├── [Training Batches Loop]
    │   │   │
    │   │   ├── on_train_batch_start()
    │   │   │   ├── [Callbacks]
    │   │   │   ├── [LightningModule]
    │   │   │   └── [Strategy]
    │   │   │
    │   │   ├── [Forward Pass - training_step()]
    │   │   │   └── [Strategy only]
    │   │   │
    │   │   ├── on_before_zero_grad()
    │   │   │   ├── [Callbacks]
    │   │   │   └── [LightningModule]
    │   │   │
    │   │   ├── on_before_backward()
    │   │   │   ├── [Callbacks]
    │   │   │   └── [LightningModule]
    │   │   │
    │   │   ├── [Backward Pass]
    │   │   │   └── [Strategy only]
    │   │   │
    │   │   ├── on_after_backward()
    │   │   │   ├── [Callbacks]
    │   │   │   └── [LightningModule]
    │   │   │
    │   │   ├── on_before_optimizer_step()
    │   │   │   ├── [Callbacks]
    │   │   │   └── [LightningModule]
    │   │   │
    │   │   ├── [Optimizer Step]
    │   │   │   └── [LightningModule only - optimizer_step()]
    │   │   │
    │   │   └── on_train_batch_end()
    │   │       ├── [Callbacks]
    │   │       └── [LightningModule]
    │   │
    │   ├── [Optional: Validation during training]
    │   │   ├── on_validation_start()
    │   │   │   ├── [Callbacks]
    │   │   │   ├── [LightningModule]
    │   │   │   └── [Strategy]
    │   │   ├── on_validation_epoch_start()
    │   │   │   ├── [Callbacks]
    │   │   │   ├── [LightningModule]
    │   │   │   └── [Strategy]
    │   │   ├── [for each validation batch]
    │   │   │   ├── on_validation_batch_start()
    │   │   │   │   ├── [Callbacks]
    │   │   │   │   ├── [LightningModule]
    │   │   │   │   └── [Strategy]
    │   │   │   └── on_validation_batch_end()
    │   │   │       ├── [Callbacks]
    │   │   │       ├── [LightningModule]
    │   │   │       └── [Strategy]
    │   │   ├── on_validation_epoch_end()
    │   │   │   ├── [Callbacks]
    │   │   │   ├── [LightningModule]
    │   │   │   └── [Strategy]
    │   │   └── on_validation_end()
    │   │       ├── [Callbacks]
    │   │       ├── [LightningModule]
    │   │       └── [Strategy]
    │   │
    │   └── on_train_epoch_end() **SPECIAL CASE**
    │       ├── [Callbacks - non-monitoring only]
    │       ├── [LightningModule]
    │       └── [Callbacks - monitoring only, e.g. ModelCheckpoint, EarlyStopping]
    │
    ├── [End Training Epochs]
    ├── on_train_end()
    │   ├── [Callbacks]
    │   ├── [LightningModule]
    │   └── [Strategy]
    ├── on_fit_end()
    │   ├── [Callbacks]
    │   └── [LightningModule]
    └── teardown(stage="fit")
        ├── [Callbacks]
        └── [LightningModule]

***********************
Testing Loop Hook Order
***********************

When running tests with ``trainer.test()``:

.. code-block:: text

    trainer.test()
    ├── setup(stage="test")
    │   ├── [Callbacks]
    │   └── [LightningModule]
    ├── on_test_start()
    │   ├── [Callbacks]
    │   ├── [LightningModule]
    │   └── [Strategy]
    ├── [Test Epochs Loop]
    │   │
    │   ├── on_test_epoch_start()
    │   │   ├── [Callbacks]
    │   │   ├── [LightningModule]
    │   │   └── [Strategy]
    │   │
    │   ├── [Test Batches Loop]
    │   │   │
    │   │   ├── on_test_batch_start()
    │   │   │   ├── [Callbacks]
    │   │   │   ├── [LightningModule]
    │   │   │   └── [Strategy]
    │   │   │
    │   │   └── on_test_batch_end()
    │   │       ├── [Callbacks]
    │   │       ├── [LightningModule]
    │   │       └── [Strategy]
    │   │
    │   └── on_test_epoch_end()
    │       ├── [Callbacks]
    │       ├── [LightningModule]
    │       └── [Strategy]
    ├── on_test_end()
    │   ├── [Callbacks]
    │   ├── [LightningModule]
    │   └── [Strategy]
    └── teardown(stage="test")
        ├── [Callbacks]
        └── [LightningModule]

**************************
Prediction Loop Hook Order
**************************

When running predictions with ``trainer.predict()``:

.. code-block:: text

    trainer.predict()
    ├── setup(stage="predict")
    │   ├── [Callbacks]
    │   └── [LightningModule]
    ├── on_predict_start()
    │   ├── [Callbacks]
    │   ├── [LightningModule]
    │   └── [Strategy]
    ├── [Prediction Epochs Loop]
    │   │
    │   ├── on_predict_epoch_start()
    │   │   ├── [Callbacks]
    │   │   └── [LightningModule]
    │   │
    │   ├── [Prediction Batches Loop]
    │   │   │
    │   │   ├── on_predict_batch_start()
    │   │   │   ├── [Callbacks]
    │   │   │   └── [LightningModule]
    │   │   │
    │   │   └── on_predict_batch_end()
    │   │       ├── [Callbacks]
    │   │       └── [LightningModule]
    │   │
    │   └── on_predict_epoch_end()
    │       ├── [Callbacks]
    │       └── [LightningModule]
    ├── on_predict_end()
    │   ├── [Callbacks]
    │   ├── [LightningModule]
    │   └── [Strategy]
    └── teardown(stage="predict")
        ├── [Callbacks]
        └── [LightningModule]