Skip to content

Commit 6415599

Browse files
authored
Add EMAWeightAveraging callback to weight_averaging.py
1 parent e21b172 commit 6415599

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

src/lightning/pytorch/callbacks/weight_averaging.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,3 +361,59 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
361361
current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
362362
for average_param, current_param in zip(average_params, current_params):
363363
current_param.data.copy_(average_param.data)
364+
365+
366+
class EMAWeightAveraging(WeightAveraging):
    """Exponential Moving Average (EMA) weight averaging callback.

    Configures the parent ``WeightAveraging`` callback with the averaging function returned by
    ``get_ema_avg_fn(decay=decay)`` and restricts, through :meth:`should_update`, the steps and epochs at which
    an update is requested.

    Args:
        device: Forwarded unchanged to ``WeightAveraging.__init__``.
        use_buffers: Forwarded unchanged to ``WeightAveraging.__init__``.
        decay: Decay rate handed to ``get_ema_avg_fn`` to build the averaging function.
        update_every_n_steps: A step triggers an update only when its index is a multiple of this value.
            A value of ``0`` or less disables step-triggered updates entirely.
        update_starting_at_step: First step index at which step-triggered updates are allowed. ``None`` means
            no lower bound on the step index.
        update_starting_at_epoch: First epoch index at which epoch-triggered updates are allowed. ``None``
            disables epoch-triggered updates.
        **kwargs: Additional keyword arguments forwarded to ``WeightAveraging.__init__``. ``avg_fn`` is always
            supplied by this class, so passing it here raises a ``TypeError``.

    """

    def __init__(
        self,
        device: Optional[Union[torch.device, str, int]] = None,
        use_buffers: bool = True,
        decay: float = 0.999,
        update_every_n_steps: int = 1,
        update_starting_at_step: Optional[int] = None,
        update_starting_at_epoch: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            device=device,
            use_buffers=use_buffers,
            avg_fn=get_ema_avg_fn(decay=decay),
            **kwargs,
        )

        self.update_every_n_steps = update_every_n_steps
        self.update_starting_at_step = update_starting_at_step
        self.update_starting_at_epoch = update_starting_at_epoch

    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
        """Decide when to update the model weights.

        The step-based rule is evaluated first and only when ``step_idx`` is given; the epoch-based rule is
        evaluated only when ``epoch_idx`` is given. Either rule being satisfied is sufficient.

        Args:
            step_idx: The current step index, or ``None`` to skip the step-based rule.
            epoch_idx: The current epoch index, or ``None`` to skip the epoch-based rule.

        Returns:
            ``True`` if the model weights should be updated, ``False`` otherwise.

        """
        if step_idx is not None:
            # No configured starting step means every step index is past the threshold.
            past_start_step = self.update_starting_at_step is None or step_idx >= self.update_starting_at_step
            # The ``> 0`` guard both disables the rule for non-positive values and avoids a modulo by zero.
            on_step_schedule = self.update_every_n_steps > 0 and step_idx % self.update_every_n_steps == 0
            if past_start_step and on_step_schedule:
                return True

        if epoch_idx is not None:
            # Epoch-triggered updates are opt-in: they stay off unless a starting epoch was configured.
            past_start_epoch = (
                self.update_starting_at_epoch is not None and epoch_idx >= self.update_starting_at_epoch
            )
            if past_start_epoch:
                return True

        return False

0 commit comments

Comments
 (0)