Docs on hook call order #21120
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,302 @@ | ||
########################## | ||
Hooks in PyTorch Lightning | ||
########################## | ||
|
||
Hooks in PyTorch Lightning allow you to customize the training, validation, and testing logic of your models. They | ||
provide a way to insert custom behavior at specific points during the training process without modifying the core | ||
training loop. There are several categories of hooks available in PyTorch Lightning: | ||
|
||
1. **Setup/Teardown Hooks**: Called at the beginning and end of training phases | ||
2. **Training Hooks**: Called during the training loop | ||
3. **Validation Hooks**: Called during validation | ||
4. **Test Hooks**: Called during testing | ||
5. **Prediction Hooks**: Called during prediction | ||
6. **Optimizer Hooks**: Called around optimizer operations | ||
7. **Checkpoint Hooks**: Called during checkpoint save/load operations | ||
8. **Exception Hooks**: Called when exceptions occur | ||
|
||
Nearly all hooks can be implemented in three places within your code: | ||
|
||
- **LightningModule**: The main module where you define your model and training logic. | ||
- **Callbacks**: Custom classes that can be passed to the Trainer to handle specific events. | ||
- **Strategy**: Custom strategies for distributed training. | ||
|
||
Because the same hook can be implemented in several of these places, it is important to understand the order in | ||
which the implementations are called. Unless noted otherwise, the following order is used: | ||
|
||
1. Callbacks, called in the order they are passed to the Trainer. | ||
2. ``LightningModule`` | ||
3. Strategy | ||
|
||
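.. testcode:: | ||
|
||
    import torch | ||
    from torch.utils.data import DataLoader, TensorDataset | ||
|
||
    from lightning.pytorch import LightningModule, Trainer | ||
    from lightning.pytorch.callbacks import Callback | ||
|
||
    class MyModel(LightningModule): | ||
        def __init__(self): | ||
            super().__init__() | ||
            self.layer = torch.nn.Linear(4, 1) | ||
|
||
        def training_step(self, batch, batch_idx): | ||
            (x,) = batch | ||
            # Dummy loss so that fit() has something to optimize | ||
            return self.layer(x).mean() | ||
|
||
        def configure_optimizers(self): | ||
            return torch.optim.SGD(self.parameters(), lr=0.1) | ||
|
||
        def on_train_start(self): | ||
            print("Model: Training is starting!") | ||
|
||
    class MyCallback(Callback): | ||
        def on_train_start(self, trainer, pl_module): | ||
            print("Callback: Training is starting!") | ||
|
||
    model = MyModel() | ||
    callback = MyCallback() | ||
    train_data = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4) | ||
    trainer = Trainer(callbacks=[callback], fast_dev_run=True) | ||
    trainer.fit(model, train_dataloaders=train_data) | ||
    # Output (among Lightning's own log messages): | ||
    # Callback: Training is starting! | ||
    # Model: Training is starting! | ||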
|
||
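The dispatch rule described above can also be illustrated without Lightning. The following is a minimal plain-Python sketch, not Lightning's actual implementation: ``RecordingSource``, ``call_hook``, and the list-based ``log`` are hypothetical names introduced only for illustration, and the sketch assumes that a source which does not define a hook is simply skipped.

```python
class RecordingSource:
    """Hypothetical stand-in for a callback, a LightningModule, or a strategy."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def on_train_start(self):
        # Record which source ran the hook instead of doing real work
        self.log.append(f"{self.name}.on_train_start")


def call_hook(hook_name, callbacks, module, strategy):
    """Call `hook_name` on callbacks (in list order), then the module, then the strategy."""
    for source in [*callbacks, module, strategy]:
        hook = getattr(source, hook_name, None)
        if callable(hook):  # assumption: sources without the hook are skipped
            hook()


log = []
callbacks = [RecordingSource("callback_a", log), RecordingSource("callback_b", log)]
module = RecordingSource("module", log)
strategy = RecordingSource("strategy", log)

call_hook("on_train_start", callbacks, module, strategy)
for entry in log:
    print(entry)
```

Reordering the ``callbacks`` list changes only the relative order of the callback entries; the module and strategy entries always follow them in this sketch.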
.. note:: | ||
There are a few exceptions to this pattern: | ||
|
||
- **on_train_epoch_end**: Non-monitoring callbacks are called first, then ``LightningModule``, then monitoring | ||
callbacks | ||
- **Optimizer hooks** (on_before_backward, on_after_backward, on_before_optimizer_step): Only callbacks and | ||
``LightningModule`` are called | ||
- Some internal hooks may only call ``LightningModule`` or Strategy | ||
|
||
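The ``on_train_epoch_end`` exception in the note above can be sketched the same way. This is a plain-Python illustration under stated assumptions, not Lightning's code: the ``monitoring`` flag, the two callback classes, and ``run_on_train_epoch_end`` are hypothetical, and treating checkpointing and early stopping as the "monitoring" callbacks is an assumption made for the example names.

```python
class SimpleCallback:
    """Hypothetical minimal callback; not Lightning's Callback class."""

    monitoring = False

    def __init__(self, name):
        self.name = name

    def on_train_epoch_end(self, order):
        order.append(self.name)


class MonitoringCallback(SimpleCallback):
    """Stand-in for a callback that monitors metrics (assumed: checkpointing, early stopping)."""

    monitoring = True


def run_on_train_epoch_end(callbacks, order):
    # 1. Non-monitoring callbacks, in the order they were passed
    for cb in callbacks:
        if not cb.monitoring:
            cb.on_train_epoch_end(order)
    # 2. The LightningModule's own hook
    order.append("LightningModule")
    # 3. Monitoring callbacks last, again in the order they were passed
    for cb in callbacks:
        if cb.monitoring:
            cb.on_train_epoch_end(order)


order = []
callbacks = [
    MonitoringCallback("checkpoint"),
    SimpleCallback("progress"),
    MonitoringCallback("early_stopping"),
    SimpleCallback("lr_logger"),
]
run_on_train_epoch_end(callbacks, order)
print(order)
```

In this sketch the monitoring callbacks run last even though ``checkpoint`` was passed first, so they would see anything the module recorded in its own ``on_train_epoch_end``.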
************************ | ||
Training Loop Hook Order | ||
************************ | ||
|
||
The following diagram shows the execution order of hooks during a typical training loop, e.g. when calling ``trainer.fit()``, | ||
with the source of each hook indicated: | ||
|
||
.. code-block:: text | ||
|
||
Training Process Flow: | ||
|
||
trainer.fit() | ||
│ | ||
├── setup(stage="fit") | ||
│ └── [Callbacks only] | ||
Comment on lines +83 to +86: Just a suggestion: perhaps adding this could also be helpful? It also shows one of the known issues, #19658.
|
||
│ | ||
├── on_fit_start() | ||
│ ├── [Callbacks] | ||
│ ├── [LightningModule] | ||
│ └── [Strategy] | ||
│ | ||
├── on_sanity_check_start() | ||
Comment on lines +91 to +93: Perhaps we could also add when the checkpoint and state_dict are loaded?
|
||
│ ├── [Callbacks] | ||
│ ├── [LightningModule] | ||
│ └── [Strategy] | ||
│ ├── on_validation_start() | ||
│ │ ├── [Callbacks] | ||
│ │ ├── [LightningModule] | ||
│ │ └── [Strategy] | ||
│ ├── on_validation_epoch_start() | ||
│ │ ├── [Callbacks] | ||
│ │ ├── [LightningModule] | ||
│ │ └── [Strategy] | ||
│ │ ├── [for each validation batch] | ||
│ │ │ ├── on_validation_batch_start() | ||
│ │ │ │ ├── [Callbacks] | ||
│ │ │ │ ├── [LightningModule] | ||
│ │ │ │ └── [Strategy] | ||
│ │ │ └── on_validation_batch_end() | ||
│ │ │ ├── [Callbacks] | ||
│ │ │ ├── [LightningModule] | ||
│ │ │ └── [Strategy] | ||
│ │ └── [end validation batches] | ||
│ ├── on_validation_epoch_end() | ||
│ │ ├── [Callbacks] | ||
│ │ ├── [LightningModule] | ||
│ │ └── [Strategy] | ||
│ └── on_validation_end() | ||
│ ├── [Callbacks] | ||
│ ├── [LightningModule] | ||
│ └── [Strategy] | ||
├── on_sanity_check_end() | ||
│ ├── [Callbacks] | ||
│ ├── [LightningModule] | ||
│ └── [Strategy] | ||
│ | ||
├── on_train_start() | ||
│ ├── [Callbacks] | ||
│ ├── [LightningModule] | ||
│ └── [Strategy] | ||
│ | ||
├── [Training Epochs Loop] | ||
│ │ | ||
│ ├── on_train_epoch_start() | ||
│ │ ├── [Callbacks] | ||
│ │ └── [LightningModule] | ||
│ │ | ||
│ ├── [Training Batches Loop] | ||
│ │ │ | ||
│ │ ├── on_train_batch_start() | ||
│ │ │ ├── [Callbacks] | ||
│ │ │ ├── [LightningModule] | ||
│ │ │ └── [Strategy] | ||
│ │ │ | ||
│ │ ├── on_before_zero_grad() | ||
│ │ │ ├── [Callbacks] | ||
│ │ │ └── [LightningModule] | ||
│ │ │ | ||
│ │ ├── [Forward Pass - training_step()] | ||
│ │ │ └── [Strategy only] | ||
Comment on lines +146 to +151: I think the
|
||
│ │ │ | ||
│ │ ├── on_before_backward() | ||
│ │ │ ├── [Callbacks] | ||
│ │ │ └── [LightningModule] | ||
│ │ │ | ||
│ │ ├── [Backward Pass] | ||
│ │ │ └── [Strategy only] | ||
│ │ │ | ||
│ │ ├── on_after_backward() | ||
│ │ │ ├── [Callbacks] | ||
│ │ │ └── [LightningModule] | ||
│ │ │ | ||
│ │ ├── on_before_optimizer_step() | ||
│ │ │ ├── [Callbacks] | ||
│ │ │ └── [LightningModule] | ||
Comment on lines +164 to +166: Perhaps also add
|
||
│ │ │ | ||
│ │ ├── [Optimizer Step] | ||
│ │ │ └── [LightningModule only - optimizer_step()] | ||
│ │ │ | ||
│ │ └── on_train_batch_end() | ||
│ │ ├── [Callbacks] | ||
│ │ └── [LightningModule] | ||
│ │ | ||
│ │ [Optional: Validation during training] | ||
│ │ ├── on_validation_start() | ||
│ │ │ ├── [Callbacks] | ||
│ │ │ ├── [LightningModule] | ||
│ │ │ └── [Strategy] | ||
│ │ ├── on_validation_epoch_start() | ||
│ │ │ ├── [Callbacks] | ||
│ │ │ ├── [LightningModule] | ||
│ │ │ └── [Strategy] | ||
│ │ │ ├── [for each validation batch] | ||
│ │ │ │ ├── on_validation_batch_start() | ||
│ │ │ │ │ ├── [Callbacks] | ||
│ │ │ │ │ ├── [LightningModule] | ||
│ │ │ │ │ └── [Strategy] | ||
│ │ │ │ └── on_validation_batch_end() | ||
│ │ │ │ ├── [Callbacks] | ||
│ │ │ │ ├── [LightningModule] | ||
│ │ │ │ └── [Strategy] | ||
│ │ │ └── [end validation batches] | ||
│ │ ├── on_validation_epoch_end() | ||
│ │ │ ├── [Callbacks] | ||
│ │ │ ├── [LightningModule] | ||
│ │ │ └── [Strategy] | ||
│ │ └── on_validation_end() | ||
│ │ ├── [Callbacks] | ||
│ │ ├── [LightningModule] | ||
│ │ └── [Strategy] | ||
│ │ | ||
│ └── on_train_epoch_end() **SPECIAL CASE** | ||
│ ├── [Callbacks - Non-monitoring only] | ||
│ ├── [LightningModule] | ||
│ └── [Callbacks - Monitoring only] | ||
Comment on lines +204 to +206: Perhaps
||
│ | ||
├── [End Training Epochs] | ||
│ | ||
├── on_train_end() | ||
│ ├── [Callbacks] | ||
│ ├── [LightningModule] | ||
│ └── [Strategy] | ||
│ | ||
├── on_fit_end() | ||
│ ├── [Callbacks] | ||
│ ├── [LightningModule] | ||
│ └── [Strategy] | ||
│ | ||
└── teardown(stage="fit") | ||
Comment on lines +215 to +220: From my own inspection of the code, I believe that the strategy.teardown should be called before
|
||
└── [Callbacks only] | ||
Comment on lines +220 to +221: I believe that the lightning module also has a
|
||
|
||
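To make the nesting in the diagram above concrete, the validation portion can be unrolled into a flat call sequence. The sketch below is plain Python and only mirrors the diagram as written; ``validation_hook_sequence``, ``expand``, and ``SOURCES`` are hypothetical helpers, and the per-batch ``validation_step`` call itself is left out.

```python
SOURCES = ("Callbacks", "LightningModule", "Strategy")


def validation_hook_sequence(num_batches):
    """Yield validation hook names in the order shown in the diagram."""
    yield "on_validation_start"
    yield "on_validation_epoch_start"
    for batch_idx in range(num_batches):
        yield f"on_validation_batch_start[{batch_idx}]"
        yield f"on_validation_batch_end[{batch_idx}]"
    yield "on_validation_epoch_end"
    yield "on_validation_end"


def expand(hooks):
    """Pair every hook with each source, in callbacks -> module -> strategy order."""
    for hook in hooks:
        for source in SOURCES:
            yield hook, source


calls = list(expand(validation_hook_sequence(num_batches=2)))
for hook, source in calls:
    print(f"{hook:<32} -> {source}")
```

With two batches this produces 8 hooks and 24 calls in total; the same unrolling applies to the sanity check, which runs the validation hooks between ``on_sanity_check_start()`` and ``on_sanity_check_end()``.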
*********************** | ||
Testing Loop Hook Order | ||
*********************** | ||
|
||
When running tests with ``trainer.test()``: | ||
|
||
.. code-block:: text | ||
|
||
trainer.test() | ||
│ | ||
├── setup(stage="test") | ||
│ └── [Callbacks only] | ||
├── on_test_start() | ||
│ ├── [Callbacks] | ||
│ ├── [LightningModule] | ||
│ └── [Strategy] | ||
│ | ||
├── [Test Epochs Loop] | ||
│ │ | ||
│ ├── on_test_epoch_start() | ||
│ │ ├── [Callbacks] | ||
│ │ ├── [LightningModule] | ||
│ │ └── [Strategy] | ||
│ │ | ||
│ ├── [Test Batches Loop] | ||
│ │ │ | ||
│ │ ├── on_test_batch_start() | ||
│ │ │ ├── [Callbacks] | ||
│ │ │ ├── [LightningModule] | ||
│ │ │ └── [Strategy] | ||
│ │ │ | ||
│ │ └── on_test_batch_end() | ||
│ │ ├── [Callbacks] | ||
│ │ ├── [LightningModule] | ||
│ │ └── [Strategy] | ||
│ │ | ||
│ └── on_test_epoch_end() | ||
│ ├── [Callbacks] | ||
│ ├── [LightningModule] | ||
│ └── [Strategy] | ||
│ | ||
├── on_test_end() | ||
│ ├── [Callbacks] | ||
│ ├── [LightningModule] | ||
│ └── [Strategy] | ||
└── teardown(stage="test") | ||
└── [Callbacks only] | ||
|
||
************************** | ||
Prediction Loop Hook Order | ||
************************** | ||
|
||
When running predictions with ``trainer.predict()``: | ||
|
||
.. code-block:: text | ||
|
||
trainer.predict() | ||
│ | ||
├── setup(stage="predict") | ||
│ └── [Callbacks only] | ||
├── on_predict_start() | ||
│ ├── [Callbacks] | ||
│ ├── [LightningModule] | ||
│ └── [Strategy] | ||
│ | ||
├── [Prediction Epochs Loop] | ||
│ │ | ||
│ ├── on_predict_epoch_start() | ||
│ │ ├── [Callbacks] | ||
│ │ └── [LightningModule] | ||
│ │ | ||
│ ├── [Prediction Batches Loop] | ||
│ │ │ | ||
│ │ ├── on_predict_batch_start() | ||
│ │ │ ├── [Callbacks] | ||
│ │ │ └── [LightningModule] | ||
│ │ │ | ||
│ │ └── on_predict_batch_end() | ||
│ │ ├── [Callbacks] | ||
│ │ └── [LightningModule] | ||
│ │ | ||
│ └── on_predict_epoch_end() | ||
│ ├── [Callbacks] | ||
│ └── [LightningModule] | ||
│ | ||
├── on_predict_end() | ||
│ ├── [Callbacks] | ||
│ ├── [LightningModule] | ||
│ └── [Strategy] | ||
└── teardown(stage="predict") | ||
└── [Callbacks only] |