Commit 3d56296

SkafteNicki and Borda authored
Docs on hook call order (#21120)
* add hook order
* add to index
* Apply suggestions from code review
* BoringModel
* testoutput

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka B <[email protected]>
1 parent db77fa7 commit 3d56296

2 files changed: +321 -0 lines changed

Lines changed: 313 additions & 0 deletions
@@ -0,0 +1,313 @@
##########################
Hooks in PyTorch Lightning
##########################

Hooks in PyTorch Lightning allow you to customize the training, validation, and testing logic of your models. They
provide a way to insert custom behavior at specific points during the training process without modifying the core
training loop. There are several categories of hooks available in PyTorch Lightning:

1. **Setup/Teardown Hooks**: Called at the beginning and end of training phases
2. **Training Hooks**: Called during the training loop
3. **Validation Hooks**: Called during validation
4. **Test Hooks**: Called during testing
5. **Prediction Hooks**: Called during prediction
6. **Optimizer Hooks**: Called around optimizer operations
7. **Checkpoint Hooks**: Called during checkpoint save/load operations
8. **Exception Hooks**: Called when exceptions occur

Nearly all hooks can be implemented in three places within your code:

- **LightningModule**: The main module where you define your model and training logic.
- **Callbacks**: Custom classes that can be passed to the Trainer to handle specific events.
- **Strategy**: Custom strategies for distributed training.

Because the same hook can be implemented in more than one of these places, it is important to understand the order
in which the implementations are called. Unless noted otherwise, the following order is used:

1. Callbacks, called in the order they are passed to the Trainer.
2. ``LightningModule``
3. Strategy

.. testcode::

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import Callback
    from lightning.pytorch.demos import BoringModel

    class MyModel(BoringModel):
        def on_train_start(self):
            print("Model: Training is starting!")

    class MyCallback(Callback):
        def on_train_start(self, trainer, pl_module):
            print("Callback: Training is starting!")

    model = MyModel()
    callback = MyCallback()
    trainer = Trainer(callbacks=[callback], logger=False, max_epochs=1)
    trainer.fit(model)

.. testoutput::
    :hide:
    :options: +ELLIPSIS, +NORMALIZE_WHITESPACE

    ┏━━━┳━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓
    ┃   ┃ Name  ┃ Type   ┃ Params ┃ Mode  ┃ FLOPs ┃
    ┡━━━╇━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩
    │ 0 │ layer │ Linear │     66 │ train │     0 │
    └───┴───────┴────────┴────────┴───────┴───────┘
    ...
    Callback: Training is starting!
    Model: Training is starting!
    Epoch 0/0 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 64/64 ...

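
The default order can also be pictured without any framework code. The sketch below is a simplified model, not Lightning's implementation: ``Recorder`` and ``dispatch_hook`` are hypothetical names invented for this illustration, and the placeholders passed to the callbacks stand in for the real trainer and module arguments.

```python
from typing import Any, List


class Recorder:
    """Hypothetical stand-in for a callback, a LightningModule, or a strategy."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log

    def on_train_start(self, *args: Any) -> None:
        # Record which object ran the hook so the call order can be inspected.
        self.log.append(f"{self.name}.on_train_start")


def dispatch_hook(hook_name: str, callbacks: List[Any], module: Any, strategy: Any) -> None:
    """Call the hook on callbacks (in list order), then the module, then the strategy."""
    for callback in callbacks:
        fn = getattr(callback, hook_name, None)
        if callable(fn):
            # Callback hooks take (trainer, pl_module); None is a placeholder trainer here.
            fn(None, module)
    for target in (module, strategy):
        fn = getattr(target, hook_name, None)
        if callable(fn):
            fn()


log: List[str] = []
callbacks = [Recorder("callback_a", log), Recorder("callback_b", log)]
dispatch_hook("on_train_start", callbacks, Recorder("module", log), Recorder("strategy", log))
print(log)
```

Swapping the two entries in ``callbacks`` changes only the relative order of the callback entries in ``log``; the module and strategy entries stay last.
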
.. note::
    There are a few exceptions to this pattern:

    - ``on_train_epoch_end``: Non-monitoring callbacks are called first, then ``LightningModule``, then monitoring
      callbacks (such as ``EarlyStopping`` and ``ModelCheckpoint``)
    - **Optimizer hooks** (``on_before_backward``, ``on_after_backward``, ``on_before_optimizer_step``): Only
      callbacks and ``LightningModule`` are called
    - Some internal hooks may only call ``LightningModule`` or Strategy

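
The ``on_train_epoch_end`` exception can be sketched in the same framework-free way. This is an illustration under stated assumptions: ``FakeCallback`` and its ``monitoring`` flag are hypothetical and exist only in this example; Lightning decides internally which callbacks count as monitoring callbacks.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FakeCallback:
    """Hypothetical callback stand-in; `monitoring` is an illustrative flag only."""

    name: str
    monitoring: bool = False


def train_epoch_end_order(callbacks: List[FakeCallback]) -> List[str]:
    """Return the order in which on_train_epoch_end implementations would run."""
    # 1. Non-monitoring callbacks, in the order they were passed in.
    order = [cb.name for cb in callbacks if not cb.monitoring]
    # 2. The LightningModule's own hook.
    order.append("LightningModule")
    # 3. Monitoring callbacks, again in their original relative order.
    order.extend(cb.name for cb in callbacks if cb.monitoring)
    return order


callbacks = [
    FakeCallback("checkpoint", monitoring=True),
    FakeCallback("progress_bar"),
    FakeCallback("early_stopping", monitoring=True),
    FakeCallback("lr_logger"),
]
order = train_epoch_end_order(callbacks)
print(order)
```

A practical consequence of this ordering is that monitoring callbacks run after the ``LightningModule`` hook, so they can react to whatever the module did in its own ``on_train_epoch_end``.
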
************************
Training Loop Hook Order
************************

The following diagram shows the execution order of hooks during a typical training loop, e.g. when calling
``trainer.fit()``, with the source of each hook indicated:

.. code-block:: text

    Training Process Flow:

    trainer.fit()

    ├── setup(stage="fit")
    │   └── [Callbacks only]

    ├── on_fit_start()
    │   ├── [Callbacks]
    │   ├── [LightningModule]
    │   └── [Strategy]

    ├── on_sanity_check_start()
    │   ├── [Callbacks]
    │   ├── [LightningModule]
    │   └── [Strategy]
    │   ├── on_validation_start()
    │   │   ├── [Callbacks]
    │   │   ├── [LightningModule]
    │   │   └── [Strategy]
    │   ├── on_validation_epoch_start()
    │   │   ├── [Callbacks]
    │   │   ├── [LightningModule]
    │   │   └── [Strategy]
    │   │   ├── [for each validation batch]
    │   │   │   ├── on_validation_batch_start()
    │   │   │   │   ├── [Callbacks]
    │   │   │   │   ├── [LightningModule]
    │   │   │   │   └── [Strategy]
    │   │   │   └── on_validation_batch_end()
    │   │   │       ├── [Callbacks]
    │   │   │       ├── [LightningModule]
    │   │   │       └── [Strategy]
    │   │   └── [end validation batches]
    │   ├── on_validation_epoch_end()
    │   │   ├── [Callbacks]
    │   │   ├── [LightningModule]
    │   │   └── [Strategy]
    │   └── on_validation_end()
    │       ├── [Callbacks]
    │       ├── [LightningModule]
    │       └── [Strategy]
    ├── on_sanity_check_end()
    │   ├── [Callbacks]
    │   ├── [LightningModule]
    │   └── [Strategy]

    ├── on_train_start()
    │   ├── [Callbacks]
    │   ├── [LightningModule]
    │   └── [Strategy]

    ├── [Training Epochs Loop]
    │   │
    │   ├── on_train_epoch_start()
    │   │   ├── [Callbacks]
    │   │   └── [LightningModule]
    │   │
    │   ├── [Training Batches Loop]
    │   │   │
    │   │   ├── on_train_batch_start()
    │   │   │   ├── [Callbacks]
    │   │   │   ├── [LightningModule]
    │   │   │   └── [Strategy]
    │   │   │
    │   │   ├── [Forward Pass - training_step()]
    │   │   │   └── [Strategy only]
    │   │   │
    │   │   ├── on_before_zero_grad()
    │   │   │   ├── [Callbacks]
    │   │   │   └── [LightningModule]
    │   │   │
    │   │   ├── on_before_backward()
    │   │   │   ├── [Callbacks]
    │   │   │   └── [LightningModule]
    │   │   │
    │   │   ├── [Backward Pass]
    │   │   │   └── [Strategy only]
    │   │   │
    │   │   ├── on_after_backward()
    │   │   │   ├── [Callbacks]
    │   │   │   └── [LightningModule]
    │   │   │
    │   │   ├── on_before_optimizer_step()
    │   │   │   ├── [Callbacks]
    │   │   │   └── [LightningModule]
    │   │   │
    │   │   ├── [Optimizer Step]
    │   │   │   └── [LightningModule only - optimizer_step()]
    │   │   │
    │   │   └── on_train_batch_end()
    │   │       ├── [Callbacks]
    │   │       └── [LightningModule]
    │   │
    │   │   [Optional: Validation during training]
    │   │   ├── on_validation_start()
    │   │   │   ├── [Callbacks]
    │   │   │   ├── [LightningModule]
    │   │   │   └── [Strategy]
    │   │   ├── on_validation_epoch_start()
    │   │   │   ├── [Callbacks]
    │   │   │   ├── [LightningModule]
    │   │   │   └── [Strategy]
    │   │   │   ├── [for each validation batch]
    │   │   │   │   ├── on_validation_batch_start()
    │   │   │   │   │   ├── [Callbacks]
    │   │   │   │   │   ├── [LightningModule]
    │   │   │   │   │   └── [Strategy]
    │   │   │   │   └── on_validation_batch_end()
    │   │   │   │       ├── [Callbacks]
    │   │   │   │       ├── [LightningModule]
    │   │   │   │       └── [Strategy]
    │   │   │   └── [end validation batches]
    │   │   ├── on_validation_epoch_end()
    │   │   │   ├── [Callbacks]
    │   │   │   ├── [LightningModule]
    │   │   │   └── [Strategy]
    │   │   └── on_validation_end()
    │   │       ├── [Callbacks]
    │   │       ├── [LightningModule]
    │   │       └── [Strategy]
    │   │
    │   └── on_train_epoch_end() **SPECIAL CASE**
    │       ├── [Callbacks - Non-monitoring only]
    │       ├── [LightningModule]
    │       └── [Callbacks - Monitoring only]

    ├── [End Training Epochs]

    ├── on_train_end()
    │   ├── [Callbacks]
    │   ├── [LightningModule]
    │   └── [Strategy]

    ├── on_fit_end()
    │   ├── [Callbacks]
    │   ├── [LightningModule]
    │   └── [Strategy]

    └── teardown(stage="fit")
        └── [Callbacks only]

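
The nesting in the diagram (epochs contain batches, and optional validation runs before ``on_train_epoch_end``) can be summarized as a small simulation. This is a simplified sketch rather than Lightning code: ``simulate_fit`` is a hypothetical helper, and it deliberately leaves out setup/teardown, the sanity check, the per-batch optimizer-related hooks, and the per-batch validation hooks.

```python
from typing import List


def simulate_fit(num_epochs: int, num_batches: int, validate: bool = True) -> List[str]:
    """Return a simplified sequence of hook names for a fit run."""
    events = ["on_fit_start", "on_train_start"]
    for _ in range(num_epochs):
        events.append("on_train_epoch_start")
        for _ in range(num_batches):
            events.append("on_train_batch_start")
            # Forward, backward, and optimizer hooks are omitted in this sketch.
            events.append("on_train_batch_end")
        if validate:
            # Validation hooks run inside the epoch, before on_train_epoch_end.
            events += [
                "on_validation_start",
                "on_validation_epoch_start",
                "on_validation_epoch_end",
                "on_validation_end",
            ]
        events.append("on_train_epoch_end")
    events += ["on_train_end", "on_fit_end"]
    return events


events = simulate_fit(num_epochs=2, num_batches=3)
for name in events:
    print(name)
```

Passing ``validate=False`` removes the validation block, which mirrors the "Optional" label in the diagram.
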
***********************
Testing Loop Hook Order
***********************

When running tests with ``trainer.test()``:

.. code-block:: text

    trainer.test()

    ├── setup(stage="test")
    │   └── [Callbacks only]
    ├── on_test_start()
    │   ├── [Callbacks]
    │   ├── [LightningModule]
    │   └── [Strategy]

    ├── [Test Epochs Loop]
    │   │
    │   ├── on_test_epoch_start()
    │   │   ├── [Callbacks]
    │   │   ├── [LightningModule]
    │   │   └── [Strategy]
    │   │
    │   ├── [Test Batches Loop]
    │   │   │
    │   │   ├── on_test_batch_start()
    │   │   │   ├── [Callbacks]
    │   │   │   ├── [LightningModule]
    │   │   │   └── [Strategy]
    │   │   │
    │   │   └── on_test_batch_end()
    │   │       ├── [Callbacks]
    │   │       ├── [LightningModule]
    │   │       └── [Strategy]
    │   │
    │   └── on_test_epoch_end()
    │       ├── [Callbacks]
    │       ├── [LightningModule]
    │       └── [Strategy]

    ├── on_test_end()
    │   ├── [Callbacks]
    │   ├── [LightningModule]
    │   └── [Strategy]
    └── teardown(stage="test")
        └── [Callbacks only]

**************************
Prediction Loop Hook Order
**************************

When running predictions with ``trainer.predict()``:

.. code-block:: text

    trainer.predict()

    ├── setup(stage="predict")
    │   └── [Callbacks only]
    ├── on_predict_start()
    │   ├── [Callbacks]
    │   ├── [LightningModule]
    │   └── [Strategy]

    ├── [Prediction Epochs Loop]
    │   │
    │   ├── on_predict_epoch_start()
    │   │   ├── [Callbacks]
    │   │   └── [LightningModule]
    │   │
    │   ├── [Prediction Batches Loop]
    │   │   │
    │   │   ├── on_predict_batch_start()
    │   │   │   ├── [Callbacks]
    │   │   │   └── [LightningModule]
    │   │   │
    │   │   └── on_predict_batch_end()
    │   │       ├── [Callbacks]
    │   │       └── [LightningModule]
    │   │
    │   └── on_predict_epoch_end()
    │       ├── [Callbacks]
    │       └── [LightningModule]

    ├── on_predict_end()
    │   ├── [Callbacks]
    │   ├── [LightningModule]
    │   └── [Strategy]
    └── teardown(stage="predict")
        └── [Callbacks only]

docs/source-pytorch/glossary/index.rst

Lines changed: 8 additions & 0 deletions
@@ -20,6 +20,7 @@
    FSDP <../advanced/model_parallel/fsdp>
    GPU <../accelerators/gpu>
    Half precision <../common/precision>
+   Hooks <../common/hooks>
    HPU <../integrations/hpu/index>
    Inference <../deploy/production_intermediate>
    Lightning CLI <../cli/lightning_cli>
@@ -179,6 +180,13 @@ Glossary
    :button_link: ../common/precision.html
    :height: 100
 
+.. displayitem::
+   :header: Hooks
+   :description: How to customize the training, validation, and testing logic
+   :col_css: col-md-12
+   :button_link: ../common/hooks.html
+   :height: 100
+
 .. displayitem::
    :header: HPU
    :description: Habana Gaudi AI Processor Unit for faster training
