
Commit 9664830

Merge branch 'master' into fix-trainer-fit-docstring-render-21356
2 parents 07d689f + f7692a6

File tree: 9 files changed (+236, -21 lines)

.github/CONTRIBUTING.md
Lines changed: 2 additions & 0 deletions

@@ -212,6 +212,8 @@ We welcome any useful contribution! For your convenience here's a recommended wo
    - [Test README](https://github.com/Lightning-AI/pytorch-lightning/blob/master/tests/README.md)
    - [CI/CD README](https://github.com/Lightning-AI/pytorch-lightning/tree/master/.github/workflows#readme)

+1. Once you have a PR opened (and thereby a PR number), please update the respective changelog for [fabric](https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/fabric/CHANGELOG.md) or [pytorch](https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/CHANGELOG.md) subpackage depending on where you made your changes.
+
 1. When you feel ready for integrating your work, mark your PR "Ready for review".

    - Your code should be readable and follow the project's design principles.

src/lightning/fabric/CHANGELOG.md
Lines changed: 5 additions & 1 deletion

@@ -25,7 +25,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

--
+- Fixed `EADDRINUSE` errors in distributed tests with port manager and retry logic ([#21309](https://github.com/Lightning-AI/pytorch-lightning/pull/21309))
+
+
+- Learning rate scheduler is stepped at the end of epoch when `on_train_batch_start` returns -1 ([#21296](https://github.com/Lightning-AI/pytorch-lightning/issues/21296)).
+


 ---

src/lightning/pytorch/CHANGELOG.md
Lines changed: 3 additions & 0 deletions

@@ -25,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))


+- Added `EMAWeightAveraging` callback that wraps Lightning's `WeightAveraging` class ([#21260](https://github.com/Lightning-AI/pytorch-lightning/pull/21260))
+
+
 ### Changed

 - Expose `weights_only` argument for `Trainer.{fit,validate,test,predict}` and let `torch` handle default value ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072))

src/lightning/pytorch/callbacks/__init__.py
Lines changed: 2 additions & 1 deletion

@@ -32,7 +32,7 @@
 from lightning.pytorch.callbacks.stochastic_weight_avg import StochasticWeightAveraging
 from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor
 from lightning.pytorch.callbacks.timer import Timer
-from lightning.pytorch.callbacks.weight_averaging import WeightAveraging
+from lightning.pytorch.callbacks.weight_averaging import EMAWeightAveraging, WeightAveraging

 __all__ = [
     "BackboneFinetuning",
@@ -59,5 +59,6 @@
     "ThroughputMonitor",
     "Timer",
     "TQDMProgressBar",
+    "EMAWeightAveraging",
     "WeightAveraging",
 ]

src/lightning/pytorch/callbacks/weight_averaging.py
Lines changed: 53 additions & 1 deletion

@@ -21,7 +21,7 @@
 from typing import Any, Optional, Union

 import torch
-from torch.optim.swa_utils import AveragedModel
+from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
 from typing_extensions import override

 import lightning.pytorch as pl
@@ -361,3 +361,55 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
         current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
         for average_param, current_param in zip(average_params, current_params):
             current_param.data.copy_(average_param.data)
+
+
+class EMAWeightAveraging(WeightAveraging):
+    """Exponential Moving Average (EMA) Weight Averaging callback."""
+
+    def __init__(
+        self,
+        device: Optional[Union[torch.device, str, int]] = None,
+        use_buffers: bool = True,
+        decay: float = 0.999,
+        update_every_n_steps: int = 1,
+        update_starting_at_step: Optional[int] = None,
+        update_starting_at_epoch: Optional[int] = None,
+        **kwargs: Any,
+    ):
+        super().__init__(
+            device=device,
+            use_buffers=use_buffers,
+            **kwargs,
+            avg_fn=get_ema_avg_fn(decay=decay),
+        )
+
+        self.update_every_n_steps = update_every_n_steps
+        self.update_starting_at_step = update_starting_at_step
+        self.update_starting_at_epoch = update_starting_at_epoch
+
+    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
+        """Decide when to update the model weights.
+
+        Args:
+            step_idx: The current step index.
+            epoch_idx: The current epoch index.
+        Returns:
+            bool: True if the model weights should be updated, False otherwise.
+
+        """
+        if step_idx is not None:
+            # Check step-based conditions only if we have a valid step_idx
+            meets_step_requirement = self.update_starting_at_step is None or step_idx >= self.update_starting_at_step
+            meets_step_frequency = self.update_every_n_steps > 0 and step_idx % self.update_every_n_steps == 0
+            if meets_step_requirement and meets_step_frequency:
+                return True
+
+        if epoch_idx is not None:
+            # Check epoch-based condition only if we specify one
+            meets_epoch_requirement = (
+                self.update_starting_at_epoch is not None and epoch_idx >= self.update_starting_at_epoch
+            )
+            if meets_epoch_requirement:
+                return True
+
+        return False
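
For orientation, a minimal usage sketch of the new callback (not part of the commit); it assumes only the public `Trainer` API and the `BoringModel` demo class that the tests below also use, and the argument values are illustrative:

    # Sketch: keep an EMA copy of the weights during training.
    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import EMAWeightAveraging
    from lightning.pytorch.demos.boring_classes import BoringModel

    # decay and the update_* arguments mirror the constructor added above.
    ema = EMAWeightAveraging(decay=0.999, update_every_n_steps=1, update_starting_at_step=100)
    trainer = Trainer(max_epochs=2, callbacks=[ema])
    trainer.fit(BoringModel())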

src/lightning/pytorch/core/hooks.py
Lines changed: 1 addition & 0 deletions

@@ -69,6 +69,7 @@ def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional[int]:
         """Called in the training loop before anything happens for that batch.

         If you return -1 here, you will skip training for the rest of the current epoch.
+        Learning rate scheduler will still be stepped at the end of epoch.

         Args:
             batch: The batched data as it is returned by the training DataLoader.
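
As a minimal illustration of the documented behaviour (a sketch, not from the commit; the class name and cut-off step are made up), a module can end an epoch early from this hook, and per the change above any epoch-level LR scheduler still steps at the end of that epoch:

    from lightning.pytorch import Trainer
    from lightning.pytorch.demos.boring_classes import BoringModel

    class SkipRestOfEpoch(BoringModel):  # illustrative subclass
        def on_train_batch_start(self, batch, batch_idx):
            # Returning -1 skips the remaining batches of the current epoch.
            if batch_idx >= 10:
                return -1
            return super().on_train_batch_start(batch, batch_idx)

    Trainer(max_epochs=3, limit_train_batches=50).fit(SkipRestOfEpoch())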

src/lightning/pytorch/loops/training_epoch_loop.py
Lines changed: 24 additions & 17 deletions

@@ -325,30 +325,33 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         trainer._logger_connector.on_batch_start(batch)

         batch_output: _BATCH_OUTPUTS_TYPE = None  # for mypy
+        should_skip_rest_of_epoch = False
+
         if batch is None and not using_dataloader_iter:
             self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
         else:
             # hook
             call._call_callback_hooks(trainer, "on_train_batch_start", batch, batch_idx)
             response = call._call_lightning_module_hook(trainer, "on_train_batch_start", batch, batch_idx)
             call._call_strategy_hook(trainer, "on_train_batch_start", batch, batch_idx)
-            if response == -1:
-                self.batch_progress.increment_processed()
-                raise StopIteration
-
-            self.batch_progress.increment_started()
-
-            kwargs = (
-                self._build_kwargs(OrderedDict(), batch, batch_idx)
-                if not using_dataloader_iter
-                else OrderedDict(any=dataloader_iter)
-            )
-            with trainer.profiler.profile("run_training_batch"):
-                if trainer.lightning_module.automatic_optimization:
-                    # in automatic optimization, there can only be one optimizer
-                    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
-                else:
-                    batch_output = self.manual_optimization.run(kwargs)
+            should_skip_rest_of_epoch = response == -1
+            # Signal this is the last batch for the current epoch
+            if should_skip_rest_of_epoch:
+                self.batch_progress.increment_by(0, is_last_batch=True)
+            else:
+                self.batch_progress.increment_started()
+
+                kwargs = (
+                    self._build_kwargs(OrderedDict(), batch, batch_idx)
+                    if not using_dataloader_iter
+                    else OrderedDict(any=dataloader_iter)
+                )
+                with trainer.profiler.profile("run_training_batch"):
+                    if trainer.lightning_module.automatic_optimization:
+                        # in automatic optimization, there can only be one optimizer
+                        batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
+                    else:
+                        batch_output = self.manual_optimization.run(kwargs)

         self.batch_progress.increment_processed()

@@ -358,6 +361,10 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         if self._num_ready_batches_reached():
             self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

+        if should_skip_rest_of_epoch:
+            # Only raise StopIteration now so that the training epoch loop can finish
+            raise StopIteration
+
         if using_dataloader_iter:
             # update the hook kwargs now that the step method might have consumed the iterator
             batch = data_fetcher._batch

tests/tests_pytorch/callbacks/test_weight_averaging.py
Lines changed: 121 additions & 1 deletion

@@ -23,7 +23,7 @@
 from torch.utils.data import DataLoader, Dataset

 from lightning.pytorch import LightningModule, Trainer
-from lightning.pytorch.callbacks import WeightAveraging
+from lightning.pytorch.callbacks import EMAWeightAveraging, WeightAveraging
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
 from tests_pytorch.helpers.runif import RunIf

@@ -329,3 +329,123 @@ def _train_and_resume(model: TestModel, dataset: Dataset, tmp_path: str, devices
     callback = EMATestCallback(devices=devices)
     _train(model, dataset, tmp_path, callback, devices=devices, checkpoint_path=checkpoint_path, **kwargs)
     return model
+
+
+@pytest.mark.parametrize(
+    ("strategy", "accelerator", "devices"),
+    [
+        ("auto", "cpu", 1),
+        pytest.param("auto", "gpu", 1, marks=RunIf(min_cuda_gpus=1)),
+    ],
+)
+def test_ema_weight_averaging(tmp_path, strategy, accelerator, devices):
+    """Test EMAWeightAveraging callback with various update configurations."""
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+
+    # Test with default settings (update every step)
+    callback = EMAWeightAveraging(decay=0.999, update_every_n_steps=1)
+    _train(model, dataset, tmp_path, callback, strategy=strategy, accelerator=accelerator, devices=devices)
+
+    # Verify the average model was created and updated
+    assert callback._average_model is not None
+    assert callback._average_model.n_averaged > 0
+
+
+def test_ema_weight_averaging_step_frequency(tmp_path):
+    """Test EMAWeightAveraging with custom step update frequency."""
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+
+    # Update every 5 steps
+    callback = EMAWeightAveraging(decay=0.95, update_every_n_steps=5)
+    _train(model, dataset, tmp_path, callback)
+
+    assert callback._average_model is not None
+
+
+def test_ema_weight_averaging_starting_step(tmp_path):
+    """Test EMAWeightAveraging with delayed start based on steps."""
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+
+    # Start updating after step 10
+    callback = EMAWeightAveraging(decay=0.999, update_every_n_steps=1, update_starting_at_step=10)
+    _train(model, dataset, tmp_path, callback)
+
+    assert callback._average_model is not None
+
+
+def test_ema_weight_averaging_starting_epoch(tmp_path):
+    """Test EMAWeightAveraging with delayed start based on epochs."""
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+
+    # Start updating after epoch 3
+    callback = EMAWeightAveraging(decay=0.999, update_every_n_steps=1, update_starting_at_epoch=3)
+    _train(model, dataset, tmp_path, callback)
+
+    assert callback._average_model is not None
+
+
+def test_ema_weight_averaging_should_update(tmp_path):
+    """Test the should_update logic of EMAWeightAveraging."""
+    # Test with step-based updates
+    callback = EMAWeightAveraging(update_every_n_steps=5, update_starting_at_step=10)
+
+    # Before starting step
+    assert not callback.should_update(step_idx=5)
+    assert not callback.should_update(step_idx=9)
+
+    # At and after starting step, but not on update frequency
+    assert callback.should_update(step_idx=10)  # First update
+    assert not callback.should_update(step_idx=11)
+    assert not callback.should_update(step_idx=14)
+    assert callback.should_update(step_idx=15)  # Second update
+
+    # Test with epoch-based updates
+    callback = EMAWeightAveraging(update_starting_at_epoch=2)
+
+    assert not callback.should_update(epoch_idx=0)
+    assert not callback.should_update(epoch_idx=1)
+    assert callback.should_update(epoch_idx=2)
+    assert callback.should_update(epoch_idx=3)
+
+
+def test_ema_weight_averaging_checkpoint_save_load(tmp_path):
+    """Test that EMAWeightAveraging correctly saves and loads checkpoints."""
+    model = TestModel()
+    model.crash_on_epoch = 2
+    dataset = RandomDataset(32, 32)
+
+    callback = EMAWeightAveraging(decay=0.99, update_every_n_steps=2)
+
+    # Train and create checkpoint
+    _train(model, dataset, tmp_path, callback, will_crash=True)
+
+    # Resume from checkpoint
+    model2 = TestModel()
+    callback2 = EMAWeightAveraging(decay=0.99, update_every_n_steps=2)
+    import glob  # should be at the top
+
+    _train(
+        model2,
+        dataset,
+        tmp_path,
+        callback2,
+        checkpoint_path=glob.glob((tmp_path / "checkpoints" / "*.ckpt").as_posix())[0],
+    )
+
+    assert callback2._average_model is not None
+
+
+@pytest.mark.parametrize("decay", [0.9, 0.99, 0.999, 0.9999])
+def test_ema_weight_averaging_decay_values(tmp_path, decay):
+    """Test EMAWeightAveraging with different decay values."""
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+
+    callback = EMAWeightAveraging(decay=decay, update_every_n_steps=1)
+    _train(model, dataset, tmp_path, callback)
+
+    assert callback._average_model is not None

tests/tests_pytorch/loops/test_training_loop.py
Lines changed: 25 additions & 0 deletions

@@ -111,6 +111,8 @@ def on_train_batch_start(self, batch, batch_idx):
         assert trainer.fit_loop.batch_idx == batch_idx_
         assert trainer.global_step == batch_idx_ * max_epochs

+    assert trainer.is_last_batch
+

 def test_should_stop_mid_epoch(tmp_path):
     """Test that training correctly stops mid epoch and that validation is still called at the right time."""
@@ -305,3 +307,26 @@ def test_eval_mode_warning(tmp_path, warn):
         w for w in warning_list if issubclass(w.category, PossibleUserWarning) and "eval mode" in str(w.message)
     ]
     assert len(eval_warnings) == 0, "Expected no eval mode warnings"
+
+
+@pytest.mark.parametrize(("max_epochs", "batch_idx_"), [(2, 5), (3, 8)])
+def test_lr_updated_on_train_batch_start_returns_minus_one(tmp_path, max_epochs, batch_idx_):
+    """Test that when the rest of the epoch is skipped, due to on_train_batch_start returning -1, the learning rate is
+    still updated when it should, at the end of the epoch."""
+
+    class TestModel(BoringModel):
+        def on_train_batch_start(self, batch, batch_idx):
+            if batch_idx == batch_idx_:
+                return -1
+            return super().on_train_batch_start(batch, batch_idx)
+
+    model = TestModel()
+    init_lr = 0.1
+    trainer = Trainer(default_root_dir=tmp_path, limit_train_batches=10, max_epochs=max_epochs)
+    trainer.fit(model)
+
+    adjusted_lr = [pg["lr"] for pg in trainer.optimizers[0].param_groups]
+
+    assert len(trainer.lr_scheduler_configs) == 1
+    assert all(a == adjusted_lr[0] for a in adjusted_lr)
+    assert init_lr * 0.1**max_epochs == adjusted_lr[0]
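
A worked check on the final assertion, assuming `BoringModel`'s default optimizer setup of `SGD(lr=0.1)` with `StepLR(step_size=1)` and its default `gamma=0.1`: the scheduler steps once per epoch, so for `max_epochs=2` the expected learning rate is 0.1 * 0.1**2 = 0.001, which holds only if the epoch-end scheduler step still runs when the epoch is cut short at `batch_idx_`.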
