
Commit d261e88

docs(LightningModule): update docs for .training mode in loops
Update the pseudocode of the validation loop according to #18951:

> when the validation loop ends, and before switching to training, it
> restores the `.training` mode on all submodules to what it was before.

and add a corresponding note to `{validate,test,predict}_step`.

Additional changes:

* Fix the incorrect comment in `lightning_module.rst` that claimed `trainer.test(model)` loads the best weights.
1 parent 03635d2

2 files changed: +22 −9 lines changed

docs/source-pytorch/common/lightning_module.rst

Lines changed: 8 additions & 3 deletions
@@ -286,6 +286,9 @@ Under the hood, Lightning does the following (pseudocode):
     # ...
 
     if validate_at_some_point:
+        # capture .training mode of every submodule
+        capture_training_mode()
+
         # disable grads + batchnorm + dropout
         torch.set_grad_enabled(False)
         model.eval()
@@ -295,9 +298,11 @@ Under the hood, Lightning does the following (pseudocode):
             val_out = model.validation_step(val_batch, val_batch_idx)
         # ----------------- VAL LOOP ---------------
 
-        # enable grads + batchnorm + dropout
+        # enable grads
         torch.set_grad_enabled(True)
-        model.train()
+
+        # restore .training mode of every submodule
+        restore_training_mode()
 
 You can also run just the validation loop on your validation dataloaders by overriding :meth:`~lightning.pytorch.core.LightningModule.validation_step`
 and calling :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`.
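
The `capture_training_mode()` / `restore_training_mode()` calls in the pseudocode above are placeholders, not public Lightning API. A minimal sketch of the idea for a plain `torch.nn.Module` (the helper names and the standalone-function shape are illustrative assumptions):

from torch import nn


def capture_training_mode(model: nn.Module) -> dict[str, bool]:
    # Record the .training flag of the root module and every submodule.
    return {name: module.training for name, module in model.named_modules()}


def restore_training_mode(model: nn.Module, modes: dict[str, bool]) -> None:
    # named_modules() yields parents before children, so restoring the root
    # first and then each child works even though Module.train() is recursive.
    for name, module in model.named_modules():
        module.train(modes[name])


model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
model[1].eval()  # e.g. a BatchNorm deliberately kept frozen during training

modes = capture_training_mode(model)
model.eval()  # the validation loop runs here
restore_training_mode(model, modes)
assert model[0].training and not model[1].training  # per-submodule state is back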
@@ -368,7 +373,7 @@ The only difference is that the test loop is only called when :meth:`~lightning.
     trainer = L.Trainer()
     trainer.fit(model=model, train_dataloaders=dataloader)
 
-    # automatically loads the best weights for you
+    # use the current weights
     trainer.test(model)
 
 There are two ways to call ``test()``:
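
For context on the corrected comment: ``trainer.test(model)`` evaluates whatever weights the in-memory ``model`` currently holds; loading a checkpoint is opt-in via the ``ckpt_path`` argument. A short sketch, reusing ``model`` and ``dataloader`` from the snippet above:

import lightning as L

trainer = L.Trainer()
trainer.fit(model=model, train_dataloaders=dataloader)

# evaluates the weights `model` holds right now (what the new comment says)
trainer.test(model)

# explicitly loads the best checkpoint tracked by ModelCheckpoint, then tests
trainer.test(ckpt_path="best")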

src/lightning/pytorch/core/module.py

Lines changed: 14 additions & 6 deletions
@@ -814,9 +814,10 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
         If you don't need to validate you don't need to implement this method.
 
         Note:
-            When the :meth:`validation_step` is called, the model has been put in eval mode
-            and PyTorch gradients have been disabled. At the end of validation,
-            the model goes back to training mode and gradients are enabled.
+            When the :meth:`validation_step` is called, the model has been put
+            in eval mode and PyTorch gradients have been disabled. At the end
+            of the validation epoch, the ``.training`` mode of every submodule
+            is restored to what it was before and gradients are enabled.
 
         """
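
A minimal ``validation_step`` consistent with this note, shown as a ``LightningModule`` method fragment (``self.model`` and the ``"val_loss"`` key are illustrative, not from this commit):

import torch.nn.functional as F


def validation_step(self, batch, batch_idx):
    x, y = batch
    # The trainer has already called .eval() and disabled grad tracking,
    # so no torch.no_grad() context is needed inside the step.
    logits = self.model(x)
    loss = F.cross_entropy(logits, y)
    self.log("val_loss", loss)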

@@ -881,9 +882,10 @@ def test_step(self, batch, batch_idx, dataloader_idx=0):
         If you don't need to test you don't need to implement this method.
 
         Note:
-            When the :meth:`test_step` is called, the model has been put in eval mode and
-            PyTorch gradients have been disabled. At the end of the test epoch, the model goes back
-            to training mode and gradients are enabled.
+            When the :meth:`test_step` is called, the model has been put in
+            eval mode and PyTorch gradients have been disabled. At the end of
+            the test epoch, the ``.training`` mode of every submodule is
+            restored to what it was before and gradients are enabled.
 
         """
@@ -922,6 +924,12 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
             trainer = Trainer(accelerator="gpu", devices=2)
             predictions = trainer.predict(model, dm)
 
+        Note:
+            When the :meth:`predict_step` is called, the model has been put in
+            eval mode and PyTorch gradients have been disabled. At the end of
+            the predict epoch, the ``.training`` mode of every submodule is
+            restored to what it was before and gradients are enabled.
+
         """
         # For backwards compatibility
         batch = kwargs.get("batch", args[0])
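
One scenario where the new notes matter (an illustrative sketch, not taken from this commit): a ``LightningModule`` that deliberately keeps a submodule in eval mode during training, such as a frozen backbone. Because the loops restore per-submodule ``.training`` state instead of calling ``model.train()`` globally, the frozen part stays frozen after validation:

import lightning as L
from torch import nn


class FrozenBackboneModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.backbone = nn.BatchNorm1d(8)
        self.head = nn.Linear(8, 2)

    def on_train_start(self):
        # Freeze the backbone's normalization statistics for the whole fit.
        self.backbone.eval()

    def training_step(self, batch, batch_idx):
        # With the restore behavior documented above, self.backbone stays in
        # eval mode here even right after a validation epoch has run.
        ...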
