
Commit 2a827f3

Docs fixes (#19529)
1 parent 2e512d4 commit 2a827f3

5 files changed (+22 additions, -23 deletions)


docs/source-pytorch/common/checkpointing_intermediate.rst

Lines changed: 5 additions & 3 deletions
@@ -167,9 +167,11 @@ In distributed training cases where a model is running across many machines, Lig
     trainer = Trainer(strategy="ddp")
     model = MyLightningModule(hparams)
     trainer.fit(model)
+
     # Saves only on the main process
+    # Handles strategy-specific saving logic like XLA, FSDP, DeepSpeed etc.
     trainer.save_checkpoint("example.ckpt")
 
-Not using :meth:`~lightning.pytorch.trainer.trainer.Trainer.save_checkpoint` can lead to unexpected behavior and potential deadlock. Using other saving functions will result in all devices attempting to save the checkpoint. As a result, we highly recommend using the Trainer's save functionality.
-If using custom saving functions cannot be avoided, we recommend using the :func:`~lightning.pytorch.utilities.rank_zero.rank_zero_only` decorator to ensure saving occurs only on the main process. Note that this will only work if all ranks hold the exact same state and won't work when using
-model parallel distributed strategies such as deepspeed or sharded training.
+
+By using :meth:`~lightning.pytorch.trainer.trainer.Trainer.save_checkpoint` instead of ``torch.save``, you make your code agnostic to the distributed training strategy being used.
+It will ensure that checkpoints are saved correctly in a multi-process setting, avoiding race conditions, deadlocks and other common issues that normally require boilerplate code to handle properly.
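For context, a minimal sketch of the strategy-agnostic pattern the reworded paragraph describes. The module, dataset and Trainer flags below are made up for illustration and are not part of this commit:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import lightning.pytorch as pl


    class TinyModule(pl.LightningModule):
        """Illustrative stand-in for MyLightningModule from the docs example."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)


    if __name__ == "__main__":
        data = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16)
        model = TinyModule()
        trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
        trainer.fit(model, data)

        # Safe to call from every process: Lightning routes the write through the
        # active strategy, so only the rank that should write the file actually does.
        trainer.save_checkpoint("example.ckpt")

        # Restoring is equally strategy-agnostic.
        restored = TinyModule.load_from_checkpoint("example.ckpt")

Running the same script under ``Trainer(strategy="ddp", devices=2)`` would not require any change to the saving call, which is the point of the docs edit.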

docs/source-pytorch/starter/installation.rst

Lines changed: 6 additions & 4 deletions
@@ -16,7 +16,9 @@ Install lightning inside a virtual env or conda environment with pip
 
     python -m pip install lightning
 
---------------
+
+----
+
 
 ******************
 Install with Conda
@@ -66,17 +68,17 @@ Install future patch releases from the source. Note that the patch release conta
 ^^^^^^^^^^^^^^^^^^^^^^
 Custom PyTorch Version
 ^^^^^^^^^^^^^^^^^^^^^^
-To use any PyTorch version visit the `PyTorch Installation Page <https://pytorch.org/get-started/locally/#start-locally>`_.
 
+To use any PyTorch version visit the `PyTorch Installation Page <https://pytorch.org/get-started/locally/#start-locally>`_.
 You can find the list of supported PyTorch versions in our :ref:`compatibility matrix <versioning:Compatibility matrix>`.
 
 ----
 
 
 *******************************************
-Optimized for ML workflows (lightning Apps)
+Optimized for ML workflows (Lightning Apps)
 *******************************************
-If you are deploying workflows built with Lightning in production and require fewer dependencies, try using the optimized `lightning[apps]` package:
+If you are deploying workflows built with Lightning in production and require fewer dependencies, try using the optimized ``lightning[apps]`` package:
 
 .. code-block:: bash
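As a quick sanity check after installing (not part of the commit), the versions actually picked up by the environment can be printed and compared against the compatibility matrix mentioned above:

    # Print the installed versions to compare against the compatibility matrix.
    import lightning
    import torch

    print("lightning:", lightning.__version__)
    print("torch:", torch.__version__)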

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 5 additions & 6 deletions
@@ -89,13 +89,12 @@ class ModelCheckpoint(Checkpoint):
     in a deterministic manner. Default: ``None``.
 save_top_k: if ``save_top_k == k``,
     the best k models according to the quantity monitored will be saved.
-    if ``save_top_k == 0``, no models are saved.
-    if ``save_top_k == -1``, all models are saved.
+    If ``save_top_k == 0``, no models are saved.
+    If ``save_top_k == -1``, all models are saved.
     Please note that the monitors are checked every ``every_n_epochs`` epochs.
-    if ``save_top_k >= 2`` and the callback is called multiple
-    times inside an epoch, the name of the saved file will be
-    appended with a version count starting with ``v1``
-    unless ``enable_version_counter`` is set to False.
+    If ``save_top_k >= 2`` and the callback is called multiple times inside an epoch, and the filename remains
+    unchanged, the name of the saved file will be appended with a version count starting with ``v1`` to avoid
+    collisions unless ``enable_version_counter`` is set to False.
 mode: one of {min, max}.
     If ``save_top_k != 0``, the decision to overwrite the current save file is made
     based on either the maximization or the minimization of the monitored quantity.
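A small usage sketch of the ``save_top_k`` behaviour described in the reworded docstring; the monitored key and filename are illustrative assumptions, not from the commit:

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint

    # Keep the two best checkpoints by validation loss. Because save_top_k >= 2 and
    # the filename is static, saves that would collide inside one epoch get a version
    # suffix starting at v1, unless enable_version_counter is set to False.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",  # assumes the LightningModule logs "val_loss"
        mode="min",
        save_top_k=2,
        filename="best",
    )

    trainer = Trainer(callbacks=[checkpoint_callback])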

src/lightning/pytorch/core/hooks.py

Lines changed: 4 additions & 0 deletions
@@ -85,6 +85,10 @@ def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -
     batch: The batched data as it is returned by the training DataLoader.
     batch_idx: the index of the batch
 
+Note:
+    The value ``outputs["loss"]`` here will be the normalized value w.r.t ``accumulate_grad_batches`` of the
+    loss returned from ``training_step``.
+
 """
 
 def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
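A hedged sketch of what the new note means in practice, reusing ``TinyModule`` from the earlier checkpointing sketch; the logged key is made up:

    class AccumulationAwareModule(TinyModule):  # TinyModule defined in the first sketch
        def on_train_batch_end(self, outputs, batch, batch_idx):
            # outputs["loss"] has already been divided by accumulate_grad_batches,
            # so multiply it back to recover the value training_step returned.
            raw_loss = outputs["loss"] * self.trainer.accumulate_grad_batches
            self.log("train/unnormalized_loss", raw_loss)

With ``Trainer(accumulate_grad_batches=4)``, ``outputs["loss"]`` at this hook is one quarter of the returned loss, so the multiplication restores it.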

src/lightning/pytorch/core/module.py

Lines changed: 2 additions & 10 deletions
@@ -1285,20 +1285,12 @@ def optimizer_step(
 
 Examples::
 
-    # DEFAULT
     def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
-        optimizer.step(closure=optimizer_closure)
+        # Add your custom logic to run directly before `optimizer.step()`
 
-    # Learning rate warm-up
-    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
-        # update params
         optimizer.step(closure=optimizer_closure)
 
-        # manually warm up lr without a scheduler
-        if self.trainer.global_step < 500:
-            lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
-            for pg in optimizer.param_groups:
-                pg["lr"] = lr_scale * self.learning_rate
+        # Add your custom logic to run directly after `optimizer.step()`
 
 """
 optimizer.step(closure=optimizer_closure)
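The warm-up logic this commit drops from the docstring is still a reasonable illustration of "custom logic before ``optimizer.step()``". A hedged sketch, again reusing ``TinyModule`` from the first example; the 500-step horizon and the base learning rate are illustrative:

    class WarmupModule(TinyModule):  # TinyModule defined in the first sketch
        def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
            # Before the step: manually warm up the learning rate over the first
            # 500 optimizer steps (0.01 is TinyModule's base LR).
            if self.trainer.global_step < 500:
                lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
                for pg in optimizer.param_groups:
                    pg["lr"] = lr_scale * 0.01

            optimizer.step(closure=optimizer_closure)

            # After the step: any post-step logic (e.g. inspecting gradients
            # or the current learning rate) could go here.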
