
Commit 0df63f7

Merge branch 'master' into update-pyproject-py310
2 parents 0e4ed01 + f7692a6 commit 0df63f7

File tree: 13 files changed, +249 / -39 lines


.github/CONTRIBUTING.md

Lines changed: 2 additions & 0 deletions
@@ -212,6 +212,8 @@ We welcome any useful contribution! For your convenience here's a recommended wo
    - [Test README](https://github.com/Lightning-AI/pytorch-lightning/blob/master/tests/README.md)
    - [CI/CD README](https://github.com/Lightning-AI/pytorch-lightning/tree/master/.github/workflows#readme)

+1. Once you have a PR opened (and thereby a PR number), please update the respective changelog for [fabric](https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/fabric/CHANGELOG.md) or [pytorch](https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/CHANGELOG.md) subpackage depending on where you made your changes.
+
 1. When you feel ready for integrating your work, mark your PR "Ready for review".

    - Your code should be readable and follow the project's design principles.

.github/workflows/ci-tests-fabric.yml

Lines changed: 5 additions & 8 deletions
@@ -40,23 +40,20 @@ jobs:
       matrix:
         os: [macOS-14, ubuntu-22.04, windows-2022]
         config:
-          # only run PyTorch latest
+          # Test unified "lightning" package with PyTorch 2.1-2.5
           - { pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
           - { pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
           - { pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
           - { pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
           - { pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }

-          # only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
+          # Test "fabric" package with PyTorch 2.6-2.9
           - { pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.6" }
-
-          # "fabric" installs the standalone package
-          - { pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.7" }
-
-          # adding recently cut Torch 2.7 - FUTURE
+          - { pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.7" }
           - { pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.8" }
+          - { pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.9" }

-          # "oldest" versions tests, only on minimum Python
+          # Test minimum supported versions (oldest)
           - { pkg-name: "fabric", pytorch-version: "2.1", requires: "oldest" }
       timeout-minutes: 25 # because of building grpcio on Mac
     env:

.github/workflows/ci-tests-pytorch.yml

Lines changed: 5 additions & 8 deletions
@@ -44,23 +44,20 @@ jobs:
       matrix:
         os: [macOS-14, ubuntu-22.04, windows-2022]
         config:
-          # only run PyTorch latest
+          # Test unified "lightning" package with PyTorch 2.1-2.5
          - { pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
          - { pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
          - { pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
          - { pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
          - { pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }

-          # only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
+          # Test "pytorch" package with PyTorch 2.6-2.9
          - { pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.6" }
-
-          # "pytorch" installs the standalone package
-          - { pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.7" }
-
-          # adding recently cut Torch 2.7 - FUTURE
+          - { pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.7" }
          - { pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.8" }
+          - { pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.9" }

-          # "oldest" versions tests, only on minimum Python
+          # Test minimum supported versions (oldest)
          - { pkg-name: "pytorch", pytorch-version: "2.1", requires: "oldest" }
       timeout-minutes: 50
     env:
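Both matrices pin an exact interpreter and PyTorch version per entry. CI jobs for such matrices often sanity-check that the resolved environment actually matches the requested entry; the snippet below is only a hedged illustration of that kind of check, and the `PYTORCH_VERSION` environment variable and comparison logic are assumptions, not taken from these workflows.

```python
import os

import torch
from packaging.version import Version

# Hypothetical guard: compare the installed torch against the matrix entry,
# ignoring local build suffixes such as "+cu121".
expected = os.environ.get("PYTORCH_VERSION", "2.9")
installed = Version(torch.__version__).base_version
assert installed.startswith(expected), f"expected torch {expected}.*, got {installed}"
```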

src/lightning/fabric/CHANGELOG.md

Lines changed: 5 additions & 1 deletion
@@ -25,7 +25,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

--
+- Fixed `EADDRINUSE` errors in distributed tests with port manager and retry logic ([#21309](https://github.com/Lightning-AI/pytorch-lightning/pull/21309))
+
+
+- Learning rate scheduler is stepped at the end of epoch when `on_train_batch_start` returns -1 ([#21296](https://github.com/Lightning-AI/pytorch-lightning/issues/21296)).
+


 ---
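The `EADDRINUSE` entry refers to distributed tests racing for an already-bound TCP port. The changelog does not include the port-manager code itself; the sketch below only illustrates the general bind-to-port-0-and-retry technique, and every name in it is illustrative rather than the PR's actual API.

```python
import socket


def find_free_port(max_retries: int = 5) -> int:
    """Illustrative helper: ask the OS for an unused TCP port.

    This is NOT the implementation from PR #21309; it only sketches the common
    bind-to-port-0 trick that avoids EADDRINUSE races in tests.
    """
    last_error: OSError | None = None
    for _ in range(max_retries):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.bind(("127.0.0.1", 0))  # port 0 -> the OS picks a free port
                return sock.getsockname()[1]
        except OSError as exc:  # unlikely here; kept to show the retry pattern
            last_error = exc
    raise RuntimeError("could not reserve a free port") from last_error
```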

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))


+- Added `EMAWeightAveraging` callback that wraps Lightning's `WeightAveraging` class ([#21260](https://github.com/Lightning-AI/pytorch-lightning/pull/21260))
+
+
 ### Changed

 - Expose `weights_only` argument for `Trainer.{fit,validate,test,predict}` and let `torch` handle default value ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072))

src/lightning/pytorch/callbacks/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -32,7 +32,7 @@
 from lightning.pytorch.callbacks.stochastic_weight_avg import StochasticWeightAveraging
 from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor
 from lightning.pytorch.callbacks.timer import Timer
-from lightning.pytorch.callbacks.weight_averaging import WeightAveraging
+from lightning.pytorch.callbacks.weight_averaging import EMAWeightAveraging, WeightAveraging

 __all__ = [
     "BackboneFinetuning",
@@ -59,5 +59,6 @@
     "ThroughputMonitor",
     "Timer",
     "TQDMProgressBar",
+    "EMAWeightAveraging",
     "WeightAveraging",
 ]
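With `EMAWeightAveraging` exported from `lightning.pytorch.callbacks`, it can be attached like any other callback. A minimal usage sketch, assuming a Lightning install that contains this commit; `TinyRegressor` and the random data are placeholders, not part of the diff.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import EMAWeightAveraging


class TinyRegressor(LightningModule):
    """Minimal placeholder module used only to demonstrate the callback wiring."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


data = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16)
# Keep an EMA copy of the weights, updated every optimizer step with decay 0.999.
trainer = Trainer(max_epochs=2, callbacks=[EMAWeightAveraging(decay=0.999)], logger=False)
trainer.fit(TinyRegressor(), train_dataloaders=data)
```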

src/lightning/pytorch/callbacks/weight_averaging.py

Lines changed: 53 additions & 1 deletion
@@ -21,7 +21,7 @@
 from typing import Any

 import torch
-from torch.optim.swa_utils import AveragedModel
+from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
 from typing_extensions import override

 import lightning.pytorch as pl
@@ -361,3 +361,55 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
         current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
         for average_param, current_param in zip(average_params, current_params):
             current_param.data.copy_(average_param.data)
+
+
+class EMAWeightAveraging(WeightAveraging):
+    """Exponential Moving Average (EMA) Weight Averaging callback."""
+
+    def __init__(
+        self,
+        device: Optional[Union[torch.device, str, int]] = None,
+        use_buffers: bool = True,
+        decay: float = 0.999,
+        update_every_n_steps: int = 1,
+        update_starting_at_step: Optional[int] = None,
+        update_starting_at_epoch: Optional[int] = None,
+        **kwargs: Any,
+    ):
+        super().__init__(
+            device=device,
+            use_buffers=use_buffers,
+            **kwargs,
+            avg_fn=get_ema_avg_fn(decay=decay),
+        )
+
+        self.update_every_n_steps = update_every_n_steps
+        self.update_starting_at_step = update_starting_at_step
+        self.update_starting_at_epoch = update_starting_at_epoch
+
+    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
+        """Decide when to update the model weights.
+
+        Args:
+            step_idx: The current step index.
+            epoch_idx: The current epoch index.
+        Returns:
+            bool: True if the model weights should be updated, False otherwise.
+
+        """
+        if step_idx is not None:
+            # Check step-based conditions only if we have a valid step_idx
+            meets_step_requirement = self.update_starting_at_step is None or step_idx >= self.update_starting_at_step
+            meets_step_frequency = self.update_every_n_steps > 0 and step_idx % self.update_every_n_steps == 0
+            if meets_step_requirement and meets_step_frequency:
+                return True
+
+        if epoch_idx is not None:
+            # Check epoch-based condition only if we specify one
+            meets_epoch_requirement = (
+                self.update_starting_at_epoch is not None and epoch_idx >= self.update_starting_at_epoch
+            )
+            if meets_epoch_requirement:
+                return True
+
+        return False
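`get_ema_avg_fn(decay)` from `torch.optim.swa_utils` supplies the averaging rule `avg = decay * avg + (1 - decay) * new`, which `AveragedModel` applies on every `update_parameters` call after the first (the first call simply copies the weights). A small sketch of that behaviour and of the step gating added by `should_update`; the concrete numbers are illustrative only.

```python
import torch
from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn

model = torch.nn.Linear(4, 2)
ema_model = AveragedModel(model, avg_fn=get_ema_avg_fn(decay=0.9))
ema_model.update_parameters(model)      # first update: plain copy of the current weights

with torch.no_grad():
    model.weight.add_(1.0)              # pretend an optimizer step moved the weights
ema_model.update_parameters(model)      # now: average = 0.9 * average + 0.1 * new weights

# Step gating as implemented by EMAWeightAveraging.should_update (illustrative values):
from lightning.pytorch.callbacks import EMAWeightAveraging

cb = EMAWeightAveraging(update_every_n_steps=4, update_starting_at_step=100)
assert cb.should_update(step_idx=104)       # multiple of 4 and past the start step
assert not cb.should_update(step_idx=50)    # before the start step
assert not cb.should_update(step_idx=102)   # past the start step but not a multiple of 4
```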

src/lightning/pytorch/core/hooks.py

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ def on_train_batch_start(self, batch: Any, batch_idx: int) -> int | None:
         """Called in the training loop before anything happens for that batch.

         If you return -1 here, you will skip training for the rest of the current epoch.
+        Learning rate scheduler will still be stepped at the end of epoch.

         Args:
             batch: The batched data as it is returned by the training DataLoader.
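The extended docstring states the contract: returning -1 from `on_train_batch_start` skips the rest of the current epoch, and epoch-level learning-rate schedulers still step at the end of that epoch. A minimal sketch of a module that uses this hook; the fixed batch cutoff is purely illustrative.

```python
import torch
from lightning.pytorch import LightningModule


class EarlyEpochExit(LightningModule):
    """Illustrative module: skip the rest of the epoch after a fixed number of batches."""

    def __init__(self, max_batches_per_epoch: int = 10) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)
        self.max_batches_per_epoch = max_batches_per_epoch

    def on_train_batch_start(self, batch, batch_idx):
        # Returning -1 skips training for the rest of the current epoch;
        # epoch-level LR schedulers are still stepped at the end of the epoch.
        if batch_idx >= self.max_batches_per_epoch:
            return -1
        return None

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]
```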

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 24 additions & 17 deletions
@@ -323,30 +323,33 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         trainer._logger_connector.on_batch_start(batch)

         batch_output: _OPTIMIZER_LOOP_OUTPUTS_TYPE | _MANUAL_LOOP_OUTPUTS_TYPE | None = None  # for mypy
+        should_skip_rest_of_epoch = False
+
         if batch is None and not using_dataloader_iter:
             self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
         else:
             # hook
             call._call_callback_hooks(trainer, "on_train_batch_start", batch, batch_idx)
             response = call._call_lightning_module_hook(trainer, "on_train_batch_start", batch, batch_idx)
             call._call_strategy_hook(trainer, "on_train_batch_start", batch, batch_idx)
-            if response == -1:
-                self.batch_progress.increment_processed()
-                raise StopIteration
-
-            self.batch_progress.increment_started()
-
-            kwargs = (
-                self._build_kwargs(OrderedDict(), batch, batch_idx)
-                if not using_dataloader_iter
-                else OrderedDict(any=dataloader_iter)
-            )
-            with trainer.profiler.profile("run_training_batch"):
-                if trainer.lightning_module.automatic_optimization:
-                    # in automatic optimization, there can only be one optimizer
-                    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
-                else:
-                    batch_output = self.manual_optimization.run(kwargs)
+            should_skip_rest_of_epoch = response == -1
+            # Signal this is the last batch for the current epoch
+            if should_skip_rest_of_epoch:
+                self.batch_progress.increment_by(0, is_last_batch=True)
+            else:
+                self.batch_progress.increment_started()
+
+                kwargs = (
+                    self._build_kwargs(OrderedDict(), batch, batch_idx)
+                    if not using_dataloader_iter
+                    else OrderedDict(any=dataloader_iter)
+                )
+                with trainer.profiler.profile("run_training_batch"):
+                    if trainer.lightning_module.automatic_optimization:
+                        # in automatic optimization, there can only be one optimizer
+                        batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
+                    else:
+                        batch_output = self.manual_optimization.run(kwargs)

         self.batch_progress.increment_processed()

@@ -356,6 +359,10 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         if self._num_ready_batches_reached():
             self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

+        if should_skip_rest_of_epoch:
+            # Only raise StopIteration now so that the training epoch loop can finish
+            raise StopIteration
+
         if using_dataloader_iter:
             # update the hook kwargs now that the step method might have consumed the iterator
             batch = data_fetcher._batch
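The refactor replaces the immediate `StopIteration` with a `should_skip_rest_of_epoch` flag: the batch is marked as the last one, the epoch-level scheduler update runs, and only then does the loop raise. The toy sketch below mirrors that control flow in isolation; it is a simplified illustration, not Lightning's actual loop code.

```python
# Simplified illustration of deferring the stop signal until end-of-epoch work has run.
def run_epoch(batches, on_train_batch_start, run_training_batch, step_epoch_schedulers):
    for batch_idx, batch in enumerate(batches):
        should_skip_rest_of_epoch = on_train_batch_start(batch, batch_idx) == -1
        if should_skip_rest_of_epoch:
            break  # previously this was an immediate "raise StopIteration"
        run_training_batch(batch, batch_idx)
    # With the deferred stop, the epoch-level scheduler step is no longer skipped.
    step_epoch_schedulers()
```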

src/lightning/pytorch/utilities/imports.py

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@
 _TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0")  # using new API with task
 _TORCHMETRICS_GREATER_EQUAL_1_0_0 = RequirementCache("torchmetrics>=1.0.0")
 _TORCH_EQUAL_2_8 = RequirementCache("torch>=2.8.0,<2.9.0")
+_TORCH_EQUAL_2_9 = RequirementCache("torch>=2.9.0,<2.10.0")

 _OMEGACONF_AVAILABLE = package_available("omegaconf")
 _TORCHVISION_AVAILABLE = RequirementCache("torchvision")
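`RequirementCache` (from `lightning_utilities`) checks the installed distribution against the requirement string and is truthy when it is satisfied, so flags such as the new `_TORCH_EQUAL_2_9` act as plain boolean guards. A hedged usage sketch; the guarded branch is an illustrative assumption, not code from this commit.

```python
from lightning_utilities.core.imports import RequirementCache

# Truthy only when the installed torch falls in the 2.9.x series.
_TORCH_EQUAL_2_9 = RequirementCache("torch>=2.9.0,<2.10.0")

if _TORCH_EQUAL_2_9:
    # enable a code path or test that only applies to torch 2.9.x
    print("running against torch 2.9.x")
else:
    print("torch 2.9.x not installed")
```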
