Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions nequip/train/ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,24 @@ def optimizer_step(
# The purpose is to take advantage of the compiled state of the base model (torchscript or torch.compile)
# the purpose of the assert is to safeguard against unexpected load scenarios where
# `self.model` holds the EMA weights and `self.ema` holds the raw weights (we don't expect this situation to happen, but just in case ...)
# we assume that checkpoints will not be saved in between the start and end of `val`, `test`, `predict`
# which should be true since Lightning's `ModelCheckpoint` only has a `on_validation_end` hook (no `on_validation_epoch_end` for example), i.e. the hooks we use for the switching should be sufficient for correct behavior

def on_save_checkpoint(self, checkpoint):
    """Ensure checkpoint is always saved with EMA module holding EMA weights.

    Lightning's ``ModelCheckpoint`` callback saves during ``on_validation_end``,
    which runs *before* ``LightningModule.on_validation_end``. This means the
    checkpoint can be captured while the EMA weights are swapped into the model
    (i.e. during validation). We correct for this by swapping back to the normal
    state, re-capturing the state dict, and then restoring the validation state.
    """
    if self.ema.is_holding_ema_weights:
        # Normal state already (EMA module holds EMA weights) -- nothing to fix.
        return
    # Swapped state: model params carry EMA weights, EMA module carries raw
    # weights. Undo the swap so the checkpoint is captured in canonical form,
    # then redo it so validation proceeds with the EMA weights as before.
    self.ema.swap_parameters(self.model)
    checkpoint["state_dict"] = self.state_dict()
    self.ema.swap_parameters(self.model)

def _assert_ema_status_and_switch(
self, expect_ema_module_holds_ema_weights: bool, evaluation_mode: str
):
Expand Down Expand Up @@ -97,6 +113,17 @@ def evaluation_model(self) -> torch.nn.Module:
# === load up EMA weights ===
# logging for sanity checking, especially useful for diamond inheritance subclasses involving EMA
logger.info("Loading EMA weights for evaluation model.")
if getattr(self.ema, "_needs_post_load_swap", False):
# Checkpoint was saved in swapped state (old nequip without on_save_checkpoint fix).
# Model params already hold EMA weights, EMA buffers hold raw weights.
# Swap to normalize the state, then swap again to load EMA weights into model.
logger.warning(
"Checkpoint was saved with swapped EMA state. "
"Correcting — consider re-training with the latest nequip version."
)
# Swap to normalize: model gets raw weights, EMA gets EMA weights
self.ema.swap_parameters(self.model)
self.ema._needs_post_load_swap = False
# we expect `self.model` to contain the raw weights
self._assert_ema_status_and_switch(True, "loading for evaluation")
return self.model
Expand Down Expand Up @@ -225,9 +252,14 @@ def set_extra_state(self, state):
""""""
self.num_updates = state["num_updates"]
self.is_holding_ema_weights = state["is_holding_ema_weights"]
assert self.is_holding_ema_weights, (
"EMA module loaded in a state where it does not contain EMA weights -- the checkpoint file is likely corrupted."
)
if not self.is_holding_ema_weights:
# This can happen with checkpoints saved by older versions of nequip
# where on_save_checkpoint did not correct the EMA state. The checkpoint
# data is valid — EMA weights are in model params and raw weights are in
# EMA buffers — so we flag this for the parent module to handle the swap.
self._needs_post_load_swap = True
else:
self._needs_post_load_swap = False

# handle possibility of restarts overwriting `decay`
state_dict_decay = state["decay"]
Expand Down
Loading