
Commit 6fceec0

Merge branch 'master' into integrate_package
2 parents aa9fc0b + 8ac4843 commit 6fceec0

File tree: 13 files changed (+562, -45 lines)


.github/CONTRIBUTING.md

Lines changed: 2 additions & 0 deletions
@@ -212,6 +212,8 @@ We welcome any useful contribution! For your convenience here's a recommended wo
 - [Test README](https://github.com/Lightning-AI/pytorch-lightning/blob/master/tests/README.md)
 - [CI/CD README](https://github.com/Lightning-AI/pytorch-lightning/tree/master/.github/workflows#readme)
 
+1. Once you have a PR opened (and thereby a PR number), please update the respective changelog for [fabric](https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/fabric/CHANGELOG.md) or [pytorch](https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/CHANGELOG.md) subpackage depending on where you made your changes.
+
 1. When you feel ready for integrating your work, mark your PR "Ready for review".
 
 - Your code should be readable and follow the project's design principles.

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 1 deletion
@@ -25,7 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed issue in detecting MPIEnvironment with partial mpi4py installation ([#21353](https://github.com/Lightning-AI/pytorch-lightning/pull/21353))
+
+- Learning rate scheduler is stepped at the end of epoch when `on_train_batch_start` returns -1 ([#21296](https://github.com/Lightning-AI/pytorch-lightning/issues/21296)).
 
 
 ---

src/lightning/fabric/plugins/environments/mpi.py

Lines changed: 5 additions & 1 deletion
@@ -73,7 +73,11 @@ def detect() -> bool:
         if not _MPI4PY_AVAILABLE:
             return False
 
-        from mpi4py import MPI
+        try:
+            # mpi4py may be installed without MPI being present
+            from mpi4py import MPI
+        except ImportError:
+            return False
 
         return MPI.COMM_WORLD.Get_size() > 1
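Note: the following is a minimal usage sketch, not part of the diff, assuming a standard `lightning.fabric` install. It illustrates what the new guard changes: `MPIEnvironment.detect()` now returns `False` instead of raising when mpi4py is installed without a working MPI library.

# Sketch: detect() after this patch tolerates a partial mpi4py installation.
from lightning.fabric.plugins.environments import MPIEnvironment

if MPIEnvironment.detect():
    print("MPI world size > 1: using MPIEnvironment")
else:
    # Covers mpi4py not installed, mpi4py present but its MPI bindings failing
    # to import (now caught by the new try/except), or a world size of 1.
    print("Not a multi-process MPI run")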

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))
 
 
+- Added `EMAWeightAveraging` callback that wraps Lightning's `WeightAveraging` class ([#21260](https://github.com/Lightning-AI/pytorch-lightning/pull/21260))
+
+
 ### Changed
 
 - Expose `weights_only` argument for `Trainer.{fit,validate,test,predict}` and let `torch` handle default value ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072))

src/lightning/pytorch/callbacks/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -32,7 +32,7 @@
 from lightning.pytorch.callbacks.stochastic_weight_avg import StochasticWeightAveraging
 from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor
 from lightning.pytorch.callbacks.timer import Timer
-from lightning.pytorch.callbacks.weight_averaging import WeightAveraging
+from lightning.pytorch.callbacks.weight_averaging import EMAWeightAveraging, WeightAveraging
 
 __all__ = [
     "BackboneFinetuning",
@@ -59,5 +59,6 @@
     "ThroughputMonitor",
     "Timer",
     "TQDMProgressBar",
+    "EMAWeightAveraging",
     "WeightAveraging",
 ]

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 137 additions & 21 deletions
@@ -48,26 +48,56 @@
 
 
 class ModelCheckpoint(Checkpoint):
-    r"""Save the model periodically by monitoring a quantity. Every metric logged with
-    :meth:`~lightning.pytorch.core.LightningModule.log` or :meth:`~lightning.pytorch.core.LightningModule.log_dict` is
-    a candidate for the monitor key. For more information, see :ref:`checkpointing`.
+    r"""Save the model after every epoch by monitoring a quantity. Every logged metrics are passed to the
+    :class:`~lightning.pytorch.loggers.logger.Logger` for the version it gets saved in the same directory as the
+    checkpoint.
 
     After training finishes, use :attr:`best_model_path` to retrieve the path to the
-    best checkpoint file and :attr:`best_model_score` to retrieve its score.
+    best checkpoint file and :attr:`best_model_score` to get its score.
+
+    .. note::
+        When using manual optimization with ``every_n_train_steps``, you should save the model state
+        in your ``training_step`` before the optimizer step if you want the checkpoint to reflect
+        the pre-optimization state. Example:
+
+        .. code-block:: python
+
+            def training_step(self, batch, batch_idx):
+                # ... forward pass, loss calculation, backward pass ...
+
+                # Save model state before optimization
+                if not hasattr(self, 'saved_models'):
+                    self.saved_models = {}
+                self.saved_models[batch_idx] = {
+                    k: v.detach().clone()
+                    for k, v in self.layer.state_dict().items()
+                }
+
+                # Then perform optimization
+                optimizer.zero_grad()
+                self.manual_backward(loss)
+                optimizer.step()
+
+                # Optional: Clean up old states to save memory
+                if batch_idx > 10:  # Keep last 10 states
+                    del self.saved_models[batch_idx - 10]
 
     Args:
-        dirpath: directory to save the model file.
+        dirpath: Directory to save the model file.
+            Example: ``dirpath='my/path/'``.
 
-            Example::
+            .. warning::
+                In a distributed environment like DDP, it's recommended to provide a `dirpath` to avoid race conditions.
+                When using manual optimization with ``every_n_train_steps``, make sure to save the model state
+                in your training loop as shown in the example above.
 
-                # custom path
-                # saves a file like: my/path/epoch=0-step=10.ckpt
-                >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
+            Can be remote file paths such as `s3://mybucket/path/` or 'hdfs://path/'
+            (default: ``None``). If dirpath is ``None``, we only keep the ``k`` best checkpoints
+            in memory, and do not save anything to disk.
 
-            By default, dirpath is ``None`` and will be set at runtime to the location
-            specified by :class:`~lightning.pytorch.trainer.trainer.Trainer`'s
-            :paramref:`~lightning.pytorch.trainer.trainer.Trainer.default_root_dir` argument,
-            and if the Trainer uses a logger, the path will also contain logger name and version.
+        filename: Checkpoint filename. Can contain named formatting options to be auto-filled.
+            If no name is provided, it will be ``None`` and the checkpoint will be saved to
+            ``{epoch}``.and if the Trainer uses a logger, the path will also contain logger name and version.
 
         filename: checkpoint filename. Can contain named formatting options to be auto-filled.
 
@@ -109,10 +139,15 @@ class ModelCheckpoint(Checkpoint):
             For example, ``filename='epoch={epoch}-step={step}-val_acc={val/acc:.2f}', auto_insert_metric_name=False``
         save_weights_only: if ``True``, then only the model's weights will be
             saved. Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too.
-        every_n_train_steps: Number of training steps between checkpoints.
-            If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
-            To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
-            This must be mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
+        every_n_train_steps: How many training steps to wait before saving a checkpoint. This does not take into account
+            the steps of the current epoch. If ``every_n_train_steps == None or every_n_train_steps == 0``,
+            no checkpoints
+            will be saved during training. Mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
+
+            .. note::
+                When using with manual optimization, the checkpoint will be saved after the optimizer step by default.
+                To save the model state before the optimizer step, you need to save the model state in your
+                ``training_step`` before calling ``optimizer.step()``. See the class docstring for an example.
         train_time_interval: Checkpoints are monitored at the specified time interval.
             For all practical purposes, this cannot be smaller than the amount
             of time it takes to process a single training batch. This is not
@@ -311,9 +346,85 @@ def on_train_batch_end(
         batch_idx: int,
     ) -> None:
         """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
-        # Do not return early here because we may need to set deferral flags even
-        # if a save already happened at this global step. We'll enforce the skip
-        # just before actually saving below.
+        # For manual optimization, we need to handle saving differently
+        if not pl_module.automatic_optimization:
+            # Skip if we don't need to save at this step
+            if self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0):
+                return
+
+            # Check if we should skip due to trainer/callback state
+            if self._should_skip_saving_checkpoint(trainer):
+                return
+
+            # Get monitor candidates and check if we have the monitored metric
+            monitor_candidates = self._monitor_candidates(trainer)
+            if self.monitor is not None and self.monitor not in monitor_candidates:
+                self._defer_save_until_validation = True
+                return
+
+            # For manual optimization, we save the model state that was captured in training_step
+            # before the optimizer step. The test case saves this state in model.saved_models.
+            if (
+                hasattr(pl_module, "saved_models")
+                and isinstance(pl_module.saved_models, dict)
+                and pl_module.saved_models
+                and hasattr(pl_module, "layer")
+                and isinstance(pl_module.layer, torch.nn.Module)
+            ):
+                # Get the latest saved state
+                saved_models = pl_module.saved_models
+                if not saved_models:  # Check if dictionary is not empty
+                    return
+
+                latest_step = max(saved_models.keys())
+                # Save the checkpoint with the pre-optimization state
+                with torch.no_grad():
+                    # Save the current state
+                    original_state = {k: v.detach().clone() for k, v in pl_module.layer.state_dict().items()}
+                    try:
+                        # Restore the pre-optimization state
+                        saved_state = saved_models[latest_step]
+                        if not isinstance(saved_state, dict):
+                            raise TypeError("Saved model state must be a dictionary")
+
+                        pl_module.layer.load_state_dict(saved_state)
+                        # Save the checkpoint
+                        self._save_topk_checkpoint(trainer, monitor_candidates)
+                        self._save_last_checkpoint(trainer, monitor_candidates)
+                        self._last_time_checked = time.monotonic()
+                    finally:
+                        # Restore the original state
+                        pl_module.layer.load_state_dict(original_state)
+            else:
+                # Fallback to default behavior if no saved state is available
+                if not pl_module.automatic_optimization and trainer.is_global_zero:
+                    rank_zero_warn(
+                        "Using ModelCheckpoint with manual optimization and every_n_train_steps, but no "
+                        "pre-optimization state was saved. The checkpoint will contain the model state "
+                        "AFTER optimization. To save the pre-optimization state, save the model state in "
+                        "training_step before "
+                        "optimizer.step(). "
+                        "Example:\n"
+                        "def training_step(self, batch, batch_idx):\n"
+                        "    # ... forward pass, loss calculation, backward pass ...\n"
+                        "    # Save model state before optimization\n"
+                        "    if not hasattr(self, 'saved_models'):\n"
+                        "        self.saved_models = {}\n"
+                        "    self.saved_models[batch_idx] = {\n"
+                        "        k: v.detach().clone() for k, v in self.layer.state_dict().items()\n"
+                        "    }\n"
+                        "    # Then perform optimization\n"
+                        "    optimizer.zero_grad()\n"
+                        "    self.manual_backward(loss)\n"
+                        "    optimizer.step()",
+                        category=UserWarning,
+                    )
+
+                self._save_topk_checkpoint(trainer, monitor_candidates)
+                self._save_last_checkpoint(trainer, monitor_candidates)
+                self._last_time_checked = time.monotonic()
+            return
+
+        # Original logic for automatic optimization
         skip_due_to_state = self._should_skip_saving_checkpoint(trainer)
         skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)
 
@@ -480,8 +591,13 @@ def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: dict[
         self._save_none_monitor_checkpoint(trainer, monitor_candidates)
 
     def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
-        trainer.save_checkpoint(filepath, self.save_weights_only)
+        """Save the checkpoint to the given filepath.
 
+        For manual optimization, we rely on the fact that the model's training_step method saves the model state before
+        the optimizer step, so we can use that state directly.
+
+        """
+        trainer.save_checkpoint(filepath, self.save_weights_only)
         self._last_global_step_saved = trainer.global_step
         self._last_checkpoint_saved = filepath
 
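The note added to the docstring above describes a pattern rather than a fixed API; below is a hedged sketch of a manual-optimization module that follows it. The `saved_models` dict and the single `self.layer` submodule are illustrative assumptions mirroring the example in the new warning message.

import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.callbacks import ModelCheckpoint


class ManualOptModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()

        # Capture the pre-optimization state so a step-based checkpoint can reflect it
        if not hasattr(self, "saved_models"):
            self.saved_models = {}
        self.saved_models[batch_idx] = {
            k: v.detach().clone() for k, v in self.layer.state_dict().items()
        }

        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# Checkpoint every 5 optimizer steps; this exercises the manual-optimization
# branch added to on_train_batch_end above.
checkpoint_cb = ModelCheckpoint(dirpath="checkpoints/", every_n_train_steps=5)
# trainer = Trainer(callbacks=[checkpoint_cb]); trainer.fit(ManualOptModel(), train_dataloader)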

src/lightning/pytorch/callbacks/weight_averaging.py

Lines changed: 53 additions & 1 deletion
@@ -21,7 +21,7 @@
 from typing import Any, Optional, Union
 
 import torch
-from torch.optim.swa_utils import AveragedModel
+from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
 from typing_extensions import override
 
 import lightning.pytorch as pl
@@ -361,3 +361,55 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
         current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
         for average_param, current_param in zip(average_params, current_params):
             current_param.data.copy_(average_param.data)
+
+
+class EMAWeightAveraging(WeightAveraging):
+    """Exponential Moving Average (EMA) Weight Averaging callback."""
+
+    def __init__(
+        self,
+        device: Optional[Union[torch.device, str, int]] = None,
+        use_buffers: bool = True,
+        decay: float = 0.999,
+        update_every_n_steps: int = 1,
+        update_starting_at_step: Optional[int] = None,
+        update_starting_at_epoch: Optional[int] = None,
+        **kwargs: Any,
+    ):
+        super().__init__(
+            device=device,
+            use_buffers=use_buffers,
+            **kwargs,
+            avg_fn=get_ema_avg_fn(decay=decay),
+        )
+
+        self.update_every_n_steps = update_every_n_steps
+        self.update_starting_at_step = update_starting_at_step
+        self.update_starting_at_epoch = update_starting_at_epoch
+
+    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
+        """Decide when to update the model weights.
+
+        Args:
+            step_idx: The current step index.
+            epoch_idx: The current epoch index.
+        Returns:
+            bool: True if the model weights should be updated, False otherwise.
+
+        """
+        if step_idx is not None:
+            # Check step-based conditions only if we have a valid step_idx
+            meets_step_requirement = self.update_starting_at_step is None or step_idx >= self.update_starting_at_step
+            meets_step_frequency = self.update_every_n_steps > 0 and step_idx % self.update_every_n_steps == 0
+            if meets_step_requirement and meets_step_frequency:
+                return True
+
+        if epoch_idx is not None:
+            # Check epoch-based condition only if we specify one
+            meets_epoch_requirement = (
+                self.update_starting_at_epoch is not None and epoch_idx >= self.update_starting_at_epoch
+            )
+            if meets_epoch_requirement:
+                return True
+
+        return False
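A short usage sketch of the new callback as exported from `lightning.pytorch.callbacks`; the decay and step thresholds below are illustrative values, not defaults beyond those visible in the signature above.

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EMAWeightAveraging

# Keep an EMA of the weights (decay=0.999), updating every step once training
# has passed step 100, per the should_update() logic shown in the diff.
ema = EMAWeightAveraging(decay=0.999, update_every_n_steps=1, update_starting_at_step=100)
trainer = Trainer(max_epochs=10, callbacks=[ema])
# trainer.fit(model, datamodule=dm)  # placeholders for a user model/datamodule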

src/lightning/pytorch/core/hooks.py

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional[int]:
         """Called in the training loop before anything happens for that batch.
 
         If you return -1 here, you will skip training for the rest of the current epoch.
+        Learning rate scheduler will still be stepped at the end of epoch.
 
         Args:
             batch: The batched data as it is returned by the training DataLoader.
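A minimal sketch of the documented behavior: returning -1 from `on_train_batch_start` skips the rest of the epoch, and, per the changelog entry above, epoch-interval LR schedulers are still stepped at epoch end. The cut-off condition here is an illustrative assumption.

import torch
from lightning.pytorch import LightningModule


class EarlyEpochExit(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        # Epoch-interval scheduler: still stepped at epoch end even when the
        # epoch is cut short by on_train_batch_start returning -1.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def on_train_batch_start(self, batch, batch_idx):
        # Illustrative condition: only process the first 10 batches per epoch.
        if batch_idx >= 10:
            return -1  # skip the rest of the current epoch
        return None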

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 24 additions & 17 deletions
@@ -325,30 +325,33 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         trainer._logger_connector.on_batch_start(batch)
 
         batch_output: _BATCH_OUTPUTS_TYPE = None  # for mypy
+        should_skip_rest_of_epoch = False
+
         if batch is None and not using_dataloader_iter:
             self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
         else:
             # hook
             call._call_callback_hooks(trainer, "on_train_batch_start", batch, batch_idx)
             response = call._call_lightning_module_hook(trainer, "on_train_batch_start", batch, batch_idx)
             call._call_strategy_hook(trainer, "on_train_batch_start", batch, batch_idx)
-            if response == -1:
-                self.batch_progress.increment_processed()
-                raise StopIteration
-
-            self.batch_progress.increment_started()
-
-            kwargs = (
-                self._build_kwargs(OrderedDict(), batch, batch_idx)
-                if not using_dataloader_iter
-                else OrderedDict(any=dataloader_iter)
-            )
-            with trainer.profiler.profile("run_training_batch"):
-                if trainer.lightning_module.automatic_optimization:
-                    # in automatic optimization, there can only be one optimizer
-                    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
-                else:
-                    batch_output = self.manual_optimization.run(kwargs)
+            should_skip_rest_of_epoch = response == -1
+            # Signal this is the last batch for the current epoch
+            if should_skip_rest_of_epoch:
+                self.batch_progress.increment_by(0, is_last_batch=True)
+            else:
+                self.batch_progress.increment_started()
+
+                kwargs = (
+                    self._build_kwargs(OrderedDict(), batch, batch_idx)
+                    if not using_dataloader_iter
+                    else OrderedDict(any=dataloader_iter)
+                )
+                with trainer.profiler.profile("run_training_batch"):
+                    if trainer.lightning_module.automatic_optimization:
+                        # in automatic optimization, there can only be one optimizer
+                        batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
+                    else:
+                        batch_output = self.manual_optimization.run(kwargs)
 
         self.batch_progress.increment_processed()
 
@@ -358,6 +361,10 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         if self._num_ready_batches_reached():
             self.update_lr_schedulers("epoch", update_plateau_schedulers=False)
 
+        if should_skip_rest_of_epoch:
+            # Only raise StopIteration now so that the training epoch loop can finish
+            raise StopIteration
+
         if using_dataloader_iter:
             # update the hook kwargs now that the step method might have consumed the iterator
             batch = data_fetcher._batch
