@@ -361,3 +361,59 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
361
361
current_params = itertools .chain (pl_module .parameters (), pl_module .buffers ())
362
362
for average_param , current_param in zip (average_params , current_params ):
363
363
current_param .data .copy_ (average_param .data )
364
+
365
+
366
class EMAWeightAveraging(WeightAveraging):
    """Weight averaging callback preconfigured for an exponential moving average (EMA).

    The parent ``WeightAveraging`` callback is initialised with the averaging function returned by
    ``get_ema_avg_fn(decay=decay)``. :meth:`should_update` decides whether an update happens, based on
    the step interval, the optional starting step and the optional starting epoch stored on the instance.

    Args:
        device: Forwarded unchanged to ``WeightAveraging.__init__``.
        use_buffers: Forwarded unchanged to ``WeightAveraging.__init__``.
        decay: Decay value handed to ``get_ema_avg_fn``.
        update_every_n_steps: Step interval between updates. A value of ``0`` or less means the step-based
            trigger never fires.
        update_starting_at_step: First step index at which the step-based trigger may fire. ``None`` places
            no lower bound on the step index.
        update_starting_at_epoch: First epoch index at which the epoch-based trigger fires. ``None``
            disables the epoch-based trigger.
        **kwargs: Extra keyword arguments forwarded to ``WeightAveraging.__init__``.

    """

    def __init__(
        self,
        device: Optional[Union[torch.device, str, int]] = None,
        use_buffers: bool = True,
        decay: float = 0.999,
        update_every_n_steps: int = 1,
        update_starting_at_step: Optional[int] = None,
        update_starting_at_epoch: Optional[int] = None,
        **kwargs: Any,
    ):
        # Build the averaging function first so the parent call below stays a plain pass-through.
        ema_avg_fn = get_ema_avg_fn(decay=decay)
        super().__init__(device=device, use_buffers=use_buffers, avg_fn=ema_avg_fn, **kwargs)

        # Scheduling knobs; these are only consulted by ``should_update``.
        self.update_every_n_steps = update_every_n_steps
        self.update_starting_at_step = update_starting_at_step
        self.update_starting_at_epoch = update_starting_at_epoch

    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None):
        """Tell whether the averaged weights should be updated now.

        The step-based and the epoch-based triggers are checked independently; either one firing is
        enough for an update.

        Args:
            step_idx: Index of the current step, or ``None`` to skip the step-based check.
            epoch_idx: Index of the current epoch, or ``None`` to skip the epoch-based check.

        Returns:
            bool: ``True`` when at least one trigger fires, ``False`` otherwise.

        """
        if step_idx is not None:
            first_step = self.update_starting_at_step
            interval = self.update_every_n_steps
            # Both conditions are computed up front, then combined.
            reached_first_step = first_step is None or step_idx >= first_step
            # The ``interval > 0`` guard keeps the modulo from dividing by zero.
            on_interval = interval > 0 and step_idx % interval == 0
            if reached_first_step and on_interval:
                return True

        # From here on only the epoch-based trigger can still fire.
        if epoch_idx is None:
            return False

        first_epoch = self.update_starting_at_epoch
        if first_epoch is None:
            # No starting epoch configured: the epoch-based trigger is disabled.
            return False

        return bool(epoch_idx >= first_epoch)
0 commit comments