Skip to content

Commit 80ee80c

Browse files
authored
Merge branch 'master' into fix/fsdp-mixed-precision
2 parents aa9f6cb + 8ac4843 commit 80ee80c

File tree

4 files changed

+326
-24
lines changed

4 files changed

+326
-24
lines changed

src/lightning/fabric/CHANGELOG.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2525

2626
### Fixed
2727

28-
- Fixed `EADDRINUSE` errors in distributed tests with port manager and retry logic ([#21309](https://github.com/Lightning-AI/pytorch-lightning/pull/21309))
29-
28+
- Fixed issue in detecting MPIEnvironment with partial mpi4py installation ([#21353](https://github.com/Lightning-AI/pytorch-lightning/pull/21353))
3029

3130
- Learning rate scheduler is stepped at the end of epoch when `on_train_batch_start` returns -1 ([#21296](https://github.com/Lightning-AI/pytorch-lightning/issues/21296)).
3231

3332

33+
3434
- Fixed FSDP mixed precision semantics and added user warning ([#21361](https://github.com/Lightning-AI/pytorch-lightning/pull/21361))
3535

3636

src/lightning/fabric/plugins/environments/mpi.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,11 @@ def detect() -> bool:
7373
if not _MPI4PY_AVAILABLE:
7474
return False
7575

76-
from mpi4py import MPI
76+
try:
77+
# mpi4py may be installed without MPI being present
78+
from mpi4py import MPI
79+
except ImportError:
80+
return False
7781

7882
return MPI.COMM_WORLD.Get_size() > 1
7983

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 137 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -48,26 +48,56 @@
4848

4949

5050
class ModelCheckpoint(Checkpoint):
51-
r"""Save the model periodically by monitoring a quantity. Every metric logged with
52-
:meth:`~lightning.pytorch.core.LightningModule.log` or :meth:`~lightning.pytorch.core.LightningModule.log_dict` is
53-
a candidate for the monitor key. For more information, see :ref:`checkpointing`.
51+
r"""Save the model after every epoch by monitoring a quantity. Every logged metrics are passed to the
52+
:class:`~lightning.pytorch.loggers.logger.Logger` for the version it gets saved in the same directory as the
53+
checkpoint.
5454
5555
After training finishes, use :attr:`best_model_path` to retrieve the path to the
56-
best checkpoint file and :attr:`best_model_score` to retrieve its score.
56+
best checkpoint file and :attr:`best_model_score` to get its score.
57+
58+
.. note::
59+
When using manual optimization with ``every_n_train_steps``, you should save the model state
60+
in your ``training_step`` before the optimizer step if you want the checkpoint to reflect
61+
the pre-optimization state. Example:
62+
63+
.. code-block:: python
64+
65+
def training_step(self, batch, batch_idx):
66+
# ... forward pass, loss calculation, backward pass ...
67+
68+
# Save model state before optimization
69+
if not hasattr(self, 'saved_models'):
70+
self.saved_models = {}
71+
self.saved_models[batch_idx] = {
72+
k: v.detach().clone()
73+
for k, v in self.layer.state_dict().items()
74+
}
75+
76+
# Then perform optimization
77+
optimizer.zero_grad()
78+
self.manual_backward(loss)
79+
optimizer.step()
80+
81+
# Optional: Clean up old states to save memory
82+
if batch_idx > 10: # Keep last 10 states
83+
del self.saved_models[batch_idx - 10]
5784
5885
Args:
59-
dirpath: directory to save the model file.
86+
dirpath: Directory to save the model file.
87+
Example: ``dirpath='my/path/'``.
6088
61-
Example::
89+
.. warning::
90+
In a distributed environment like DDP, it's recommended to provide a `dirpath` to avoid race conditions.
91+
When using manual optimization with ``every_n_train_steps``, make sure to save the model state
92+
in your training loop as shown in the example above.
6293
63-
# custom path
64-
# saves a file like: my/path/epoch=0-step=10.ckpt
65-
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
94+
Can be remote file paths such as `s3://mybucket/path/` or 'hdfs://path/'
95+
(default: ``None``). If dirpath is ``None``, we only keep the ``k`` best checkpoints
96+
in memory, and do not save anything to disk.
6697
67-
By default, dirpath is ``None`` and will be set at runtime to the location
68-
specified by :class:`~lightning.pytorch.trainer.trainer.Trainer`'s
69-
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.default_root_dir` argument,
70-
and if the Trainer uses a logger, the path will also contain logger name and version.
98+
filename: Checkpoint filename. Can contain named formatting options to be auto-filled.
99+
If no name is provided, it will be ``None`` and the checkpoint will be saved to
100+
``{epoch}``.and if the Trainer uses a logger, the path will also contain logger name and version.
71101
72102
filename: checkpoint filename. Can contain named formatting options to be auto-filled.
73103
@@ -109,10 +139,15 @@ class ModelCheckpoint(Checkpoint):
109139
For example, ``filename='epoch={epoch}-step={step}-val_acc={val/acc:.2f}', auto_insert_metric_name=False``
110140
save_weights_only: if ``True``, then only the model's weights will be
111141
saved. Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too.
112-
every_n_train_steps: Number of training steps between checkpoints.
113-
If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
114-
To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
115-
This must be mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
142+
every_n_train_steps: How many training steps to wait before saving a checkpoint. This does not take into account
143+
the steps of the current epoch. If ``every_n_train_steps == None or every_n_train_steps == 0``,
144+
no checkpoints
145+
will be saved during training. Mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
146+
147+
.. note::
148+
When using with manual optimization, the checkpoint will be saved after the optimizer step by default.
149+
To save the model state before the optimizer step, you need to save the model state in your
150+
``training_step`` before calling ``optimizer.step()``. See the class docstring for an example.
116151
train_time_interval: Checkpoints are monitored at the specified time interval.
117152
For all practical purposes, this cannot be smaller than the amount
118153
of time it takes to process a single training batch. This is not
@@ -311,9 +346,85 @@ def on_train_batch_end(
311346
batch_idx: int,
312347
) -> None:
313348
"""Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
314-
# Do not return early here because we may need to set deferral flags even
315-
# if a save already happened at this global step. We'll enforce the skip
316-
# just before actually saving below.
349+
# For manual optimization, we need to handle saving differently
350+
if not pl_module.automatic_optimization:
351+
# Skip if we don't need to save at this step
352+
if self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0):
353+
return
354+
355+
# Check if we should skip due to trainer/callback state
356+
if self._should_skip_saving_checkpoint(trainer):
357+
return
358+
359+
# Get monitor candidates and check if we have the monitored metric
360+
monitor_candidates = self._monitor_candidates(trainer)
361+
if self.monitor is not None and self.monitor not in monitor_candidates:
362+
self._defer_save_until_validation = True
363+
return
364+
365+
# For manual optimization, we save the model state that was captured in training_step
366+
# before the optimizer step. The test case saves this state in model.saved_models.
367+
if (
368+
hasattr(pl_module, "saved_models")
369+
and isinstance(pl_module.saved_models, dict)
370+
and pl_module.saved_models
371+
and hasattr(pl_module, "layer")
372+
and isinstance(pl_module.layer, torch.nn.Module)
373+
):
374+
# Get the latest saved state
375+
saved_models = pl_module.saved_models
376+
if not saved_models: # Check if dictionary is not empty
377+
return
378+
379+
latest_step = max(saved_models.keys())
380+
# Save the checkpoint with the pre-optimization state
381+
with torch.no_grad():
382+
# Save the current state
383+
original_state = {k: v.detach().clone() for k, v in pl_module.layer.state_dict().items()}
384+
try:
385+
# Restore the pre-optimization state
386+
saved_state = saved_models[latest_step]
387+
if not isinstance(saved_state, dict):
388+
raise TypeError("Saved model state must be a dictionary")
389+
390+
pl_module.layer.load_state_dict(saved_state)
391+
# Save the checkpoint
392+
self._save_topk_checkpoint(trainer, monitor_candidates)
393+
self._save_last_checkpoint(trainer, monitor_candidates)
394+
self._last_time_checked = time.monotonic()
395+
finally:
396+
# Restore the original state
397+
pl_module.layer.load_state_dict(original_state)
398+
else:
399+
# Fallback to default behavior if no saved state is available
400+
if not pl_module.automatic_optimization and trainer.is_global_zero:
401+
rank_zero_warn(
402+
"Using ModelCheckpoint with manual optimization and every_n_train_steps, but no "
403+
"pre-optimization state was saved. The checkpoint will contain the model state "
404+
"AFTER optimization. To save the pre-optimization state, save the model state in "
405+
"training_step before "
406+
"optimizer.step(). "
407+
"Example:\n"
408+
"def training_step(self, batch, batch_idx):\n"
409+
" # ... forward pass, loss calculation, backward pass ...\n"
410+
" # Save model state before optimization\n"
411+
" if not hasattr(self, 'saved_models'):\n"
412+
" self.saved_models = {}\n"
413+
" self.saved_models[batch_idx] = {\n"
414+
" k: v.detach().clone() for k, v in self.layer.state_dict().items()\n"
415+
" }\n"
416+
" # Then perform optimization\n"
417+
" optimizer.zero_grad()\n"
418+
" self.manual_backward(loss)\n"
419+
" optimizer.step()",
420+
category=UserWarning,
421+
)
422+
self._save_topk_checkpoint(trainer, monitor_candidates)
423+
self._save_last_checkpoint(trainer, monitor_candidates)
424+
self._last_time_checked = time.monotonic()
425+
return
426+
427+
# Original logic for automatic optimization
317428
skip_due_to_state = self._should_skip_saving_checkpoint(trainer)
318429
skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)
319430

@@ -480,8 +591,13 @@ def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: dict[
480591
self._save_none_monitor_checkpoint(trainer, monitor_candidates)
481592

482593
def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
483-
trainer.save_checkpoint(filepath, self.save_weights_only)
594+
"""Save the checkpoint to the given filepath.
484595
596+
For manual optimization, we rely on the fact that the model's training_step method saves the model state before
597+
the optimizer step, so we can use that state directly.
598+
599+
"""
600+
trainer.save_checkpoint(filepath, self.save_weights_only)
485601
self._last_global_step_saved = trainer.global_step
486602
self._last_checkpoint_saved = filepath
487603

0 commit comments

Comments
 (0)