diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 1bba5e4ca0da7..30072517a2e9a 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -25,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))
 
+- Added `EMAWeightAveraging` callback that wraps Lightning's `WeightAveraging` class ([#21260](https://github.com/Lightning-AI/pytorch-lightning/pull/21260))
+
+
 ### Changed
 
 - Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise ([#20896](https://github.com/Lightning-AI/pytorch-lightning/pull/20896))
diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py
index f9b8d64eae6a5..c6f95adaedc1a 100644
--- a/src/lightning/pytorch/callbacks/weight_averaging.py
+++ b/src/lightning/pytorch/callbacks/weight_averaging.py
@@ -21,7 +21,7 @@
 from typing import Any, Optional, Union
 
 import torch
-from torch.optim.swa_utils import AveragedModel
+from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
 from typing_extensions import override
 
 import lightning.pytorch as pl
@@ -361,3 +361,55 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
         current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
         for average_param, current_param in zip(average_params, current_params):
             current_param.data.copy_(average_param.data)
+
+
+class EMAWeightAveraging(WeightAveraging):
+    """Exponential Moving Average (EMA) Weight Averaging callback."""
+
+    def __init__(
+        self,
+        device: Optional[Union[torch.device, str, int]] = None,
+        use_buffers: bool = True,
+        decay: float = 0.999,
+        update_every_n_steps: int = 1,
+        update_starting_at_step: Optional[int] = None,
+        update_starting_at_epoch: Optional[int] = None,
+        **kwargs: Any,
+    ):
+        super().__init__(
+            device=device,
+            use_buffers=use_buffers,
+            **kwargs,
+            avg_fn=get_ema_avg_fn(decay=decay),
+        )
+
+        self.update_every_n_steps = update_every_n_steps
+        self.update_starting_at_step = update_starting_at_step
+        self.update_starting_at_epoch = update_starting_at_epoch
+
+    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None):
+        """Decide when to update the model weights.
+
+        Args:
+            step_idx: The current step index.
+            epoch_idx: The current epoch index.
+        Returns:
+            bool: True if the model weights should be updated, False otherwise.
+
+        """
+        if step_idx is not None:
+            # Check step-based conditions only if we have a valid step_idx
+            meets_step_requirement = self.update_starting_at_step is None or step_idx >= self.update_starting_at_step
+            meets_step_frequency = self.update_every_n_steps > 0 and step_idx % self.update_every_n_steps == 0
+            if meets_step_requirement and meets_step_frequency:
+                return True
+
+        if epoch_idx is not None:
+            # Check epoch-based condition only if we specify one
+            meets_epoch_requirement = (
+                self.update_starting_at_epoch is not None and epoch_idx >= self.update_starting_at_epoch
+            )
+            if meets_epoch_requirement:
+                return True
+
+        return False
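
For context, a minimal usage sketch of the new callback is shown below. The import path follows the module touched in this diff (the top-level re-export is not part of the shown changes); `MyModel`, `train_loader`, and the hyperparameter values are hypothetical stand-ins, not part of the change itself.

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks.weight_averaging import EMAWeightAveraging

# Keep an EMA copy of the weights with decay=0.999, refreshed on every
# optimizer step but only once training has reached step 100
# (both values are illustrative).
ema_callback = EMAWeightAveraging(
    decay=0.999,
    update_every_n_steps=1,
    update_starting_at_step=100,
)

trainer = Trainer(max_epochs=10, callbacks=[ema_callback])
trainer.fit(MyModel(), train_loader)  # MyModel / train_loader are user-defined stand-ins
```

Because `should_update` returns `False` until the configured step or epoch threshold is reached, the EMA copy starts tracking the online weights only after the warm-up period; the rest of the averaging machinery is inherited unchanged from `WeightAveraging`.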