You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Generic weight averaging callback that supports EMA (#20545)
* Weight averaging callback
* A callback that updates a torch.optim.swa_utils.AveragedModel after specific steps or epochs.
* The user can provide a callback that defines after which steps or epochs the average model is updated.
* More generic customization of the WeightAveraging callback
- The user can specify when to update the average model by overriding the should_update() method
- Any keyword arguments will be passed to the AveragedModel constructor
* Training tricks mentions WeightAveraging and EMA
* Removed logging from WeightAveraging
* Fixed the documentation
* Fixed checkpoint loading with WeightAveraging
* WeightAveraging calls the configure_model hook but issues a warning
* Fixed a reference in a docstring.
* Removed two unit tests to avoid running out of memory in the CI pipeline.
* The default device for the averaged model is the device of the original model
* Added seealso to WeightAveraging and StochasticWeightAveraging
* More verbose description of WeightAveraging
* Describe the magic number 7 in a comment
* Update src/lightning/pytorch/CHANGELOG.md
---------
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Luca Antiga <[email protected]>
Co-authored-by: Seppo Enarvi <[email protected]>
Co-authored-by: Seppo Enarvi <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Copy file name to clipboardExpand all lines: src/lightning/pytorch/CHANGELOG.md
+4-1Lines changed: 4 additions & 1 deletion
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -10,7 +10,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
10
10
11
11
### Added
12
12
13
-
- Added Torch-Tensorrt Integration with `LightningModule` ([#20808](https://github.com/Lightning-AI/pytorch-lightning/pull/20808))
13
+
- Added `WeightAveraging` callback that wraps the PyTorch `AveragedModel` class ([#20545](https://github.com/Lightning-AI/pytorch-lightning/pull/20545))
14
+
15
+
16
+
- Added Torch-Tensorrt integration with `LightningModule` ([#20808](https://github.com/Lightning-AI/pytorch-lightning/pull/20808))
0 commit comments