158 changes: 137 additions & 21 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -48,26 +48,56 @@


class ModelCheckpoint(Checkpoint):
r"""Save the model periodically by monitoring a quantity. Every metric logged with
:meth:`~lightning.pytorch.core.LightningModule.log` or :meth:`~lightning.pytorch.core.LightningModule.log_dict` is
a candidate for the monitor key. For more information, see :ref:`checkpointing`.
r"""Save the model after every epoch by monitoring a quantity. Every logged metrics are passed to the
:class:`~lightning.pytorch.loggers.logger.Logger` for the version it gets saved in the same directory as the
checkpoint.

After training finishes, use :attr:`best_model_path` to retrieve the path to the
best checkpoint file and :attr:`best_model_score` to retrieve its score.
best checkpoint file and :attr:`best_model_score` to get its score.
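
For example, after ``trainer.fit(...)`` returns (an illustrative sketch; ``MyModel``, ``train_dataloader``,
and the ``val_loss`` metric are placeholder names):

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint

    checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")
    trainer = Trainer(callbacks=[checkpoint_callback])
    trainer.fit(MyModel(), train_dataloader)
    print(checkpoint_callback.best_model_path)   # path to the best checkpoint file
    print(checkpoint_callback.best_model_score)  # monitored value for that checkpoint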

.. note::
When using manual optimization with ``every_n_train_steps``, you should save the model state
in your ``training_step`` before the optimizer step if you want the checkpoint to reflect
the pre-optimization state. Example:

.. code-block:: python

def training_step(self, batch, batch_idx):
# ... forward pass and loss calculation ...

# Save model state before optimization
if not hasattr(self, 'saved_models'):
self.saved_models = {}
self.saved_models[batch_idx] = {
k: v.detach().clone()
for k, v in self.layer.state_dict().items()
}

# Then retrieve the optimizer and perform the optimization step
optimizer = self.optimizers()
optimizer.zero_grad()
self.manual_backward(loss)
optimizer.step()

# Optional: Clean up old states to save memory
if batch_idx > 10: # Keep last 10 states
del self.saved_models[batch_idx - 10]
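
For illustration, a callback configuration that pairs with the training step above might look like the
following sketch (the ``"loss"`` monitor key and the ``dirpath`` value are assumptions that mirror the test
added in this PR):

.. code-block:: python

    from lightning.pytorch.callbacks import ModelCheckpoint

    # Checkpoint on every optimizer step and keep only the best one
    # according to the logged "loss" metric.
    checkpoint_callback = ModelCheckpoint(
        dirpath="my/path/",
        monitor="loss",
        mode="min",
        save_top_k=1,
        every_n_train_steps=1,
        save_weights_only=True,
    )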

Args:
dirpath: directory to save the model file.
dirpath: Directory to save the model file.
Example: ``dirpath='my/path/'``.

Example::
.. warning::
In a distributed environment like DDP, it's recommended to provide a `dirpath` to avoid race conditions.
When using manual optimization with ``every_n_train_steps``, make sure to save the model state
in your training loop as shown in the example above.

# custom path
# saves a file like: my/path/epoch=0-step=10.ckpt
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
Can be remote file paths such as ``s3://mybucket/path/`` or ``hdfs://path/``
(default: ``None``). If dirpath is ``None``, we only keep the ``k`` best checkpoints
in memory, and do not save anything to disk.

By default, dirpath is ``None`` and will be set at runtime to the location
specified by :class:`~lightning.pytorch.trainer.trainer.Trainer`'s
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.default_root_dir` argument,
and if the Trainer uses a logger, the path will also contain logger name and version.
filename: Checkpoint filename. Can contain named formatting options to be auto-filled.
If no name is provided, it will be ``None`` and the checkpoint will be saved to
``{epoch}``.

filename: checkpoint filename. Can contain named formatting options to be auto-filled.

@@ -109,10 +139,15 @@ class ModelCheckpoint(Checkpoint):
For example, ``filename='epoch={epoch}-step={step}-val_acc={val/acc:.2f}', auto_insert_metric_name=False``
save_weights_only: if ``True``, then only the model's weights will be
saved. Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too.
every_n_train_steps: Number of training steps between checkpoints.
If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
This must be mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
every_n_train_steps: How many training steps to wait before saving a checkpoint. This does not take into
account the steps of the current epoch. If ``every_n_train_steps == None or every_n_train_steps == 0``,
no checkpoints will be saved during training. Mutually exclusive with ``train_time_interval`` and
``every_n_epochs``.

.. note::
When using manual optimization, the checkpoint will be saved after the optimizer step by default.
To save the model state before the optimizer step, you need to save the model state in your
``training_step`` before calling ``optimizer.step()``. See the class docstring for an example.
train_time_interval: Checkpoints are monitored at the specified time interval.
For all practical purposes, this cannot be smaller than the amount
of time it takes to process a single training batch. This is not
@@ -311,9 +346,85 @@ def on_train_batch_end(
batch_idx: int,
) -> None:
"""Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
# Do not return early here because we may need to set deferral flags even
# if a save already happened at this global step. We'll enforce the skip
# just before actually saving below.
# For manual optimization, we need to handle saving differently
if not pl_module.automatic_optimization:
# Skip if we don't need to save at this step
if self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0):
return

# Check if we should skip due to trainer/callback state
if self._should_skip_saving_checkpoint(trainer):
return

# Get monitor candidates and check if we have the monitored metric
monitor_candidates = self._monitor_candidates(trainer)
if self.monitor is not None and self.monitor not in monitor_candidates:
self._defer_save_until_validation = True
return

# For manual optimization, we save the model state that was captured in training_step
# before the optimizer step. The user's training_step is expected to store this state in
# pl_module.saved_models (see the class docstring for an example).
if (
hasattr(pl_module, "saved_models")
and isinstance(pl_module.saved_models, dict)
and pl_module.saved_models
and hasattr(pl_module, "layer")
and isinstance(pl_module.layer, torch.nn.Module)
):
# Get the latest saved state (the check above guarantees this dict is non-empty)
saved_models = pl_module.saved_models

latest_step = max(saved_models.keys())
# Save the checkpoint with the pre-optimization state
with torch.no_grad():
# Save the current state
original_state = {k: v.detach().clone() for k, v in pl_module.layer.state_dict().items()}
try:
# Restore the pre-optimization state
saved_state = saved_models[latest_step]
if not isinstance(saved_state, dict):
raise TypeError("Saved model state must be a dictionary")

pl_module.layer.load_state_dict(saved_state)
# Save the checkpoint
self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)
self._last_time_checked = time.monotonic()
finally:
# Restore the original state
pl_module.layer.load_state_dict(original_state)
else:
# Fallback to default behavior if no saved state is available
if trainer.is_global_zero:  # manual optimization is already guaranteed in this branch
rank_zero_warn(
"Using ModelCheckpoint with manual optimization and every_n_train_steps, but no "
"pre-optimization state was saved. The checkpoint will contain the model state "
"AFTER optimization. To save the pre-optimization state, save the model state in "
"training_step before "
"optimizer.step(). "
"Example:\n"
"def training_step(self, batch, batch_idx):\n"
" # ... forward pass, loss calculation, backward pass ...\n"
" # Save model state before optimization\n"
" if not hasattr(self, 'saved_models'):\n"
" self.saved_models = {}\n"
" self.saved_models[batch_idx] = {\n"
" k: v.detach().clone() for k, v in self.layer.state_dict().items()\n"
" }\n"
" # Then perform optimization\n"
" optimizer.zero_grad()\n"
" self.manual_backward(loss)\n"
" optimizer.step()",
category=UserWarning,
)
self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)
self._last_time_checked = time.monotonic()
return

# Original logic for automatic optimization
skip_due_to_state = self._should_skip_saving_checkpoint(trainer)
skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)

@@ -472,8 +583,13 @@ def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: dict[
self._save_none_monitor_checkpoint(trainer, monitor_candidates)

def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
trainer.save_checkpoint(filepath, self.save_weights_only)
"""Save the checkpoint to the given filepath.

For manual optimization, we rely on the fact that the model's training_step method saves the model state before
the optimizer step, so we can use that state directly.

"""
trainer.save_checkpoint(filepath, self.save_weights_only)
self._last_global_step_saved = trainer.global_step
self._last_checkpoint_saved = filepath

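For review context, here is a small standalone sketch (not part of the diff) of the swap/save/restore pattern
that the manual-optimization branch of ``on_train_batch_end`` applies. The ``module.layer`` and
``module.saved_models`` attributes and the ``save_fn`` callable are illustrative assumptions taken from the
docstring example above.

import torch


def save_pre_optimization_checkpoint(module, save_fn) -> None:
    # `module.saved_models` is assumed to be populated in training_step
    # before optimizer.step(), as in the docstring example.
    saved = getattr(module, "saved_models", None)
    if not saved:
        save_fn()  # no snapshot available: the checkpoint reflects post-step weights
        return
    latest_state = saved[max(saved)]
    # Keep a copy of the current (post-step) weights so they can be restored afterwards.
    original_state = {k: v.detach().clone() for k, v in module.layer.state_dict().items()}
    try:
        module.layer.load_state_dict(latest_state)  # temporarily swap in the pre-step weights
        save_fn()  # e.g. a closure around trainer.save_checkpoint(...)
    finally:
        module.layer.load_state_dict(original_state)  # always restore the post-step weights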
182 changes: 182 additions & 0 deletions tests/tests_pytorch/callbacks/test_model_checkpoint_manual_opt.py
@@ -0,0 +1,182 @@
import shutil
import tempfile
import warnings
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path

import torch
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint


class FakeDataset(Dataset):
def __init__(self):
self.data = [torch.randn(3) for _ in range(4)]
self.labels = [torch.randint(0, 2, (1,)) for _ in range(4)]

def __len__(self):
return 4

def __getitem__(self, idx):
return self.data[idx], self.labels[idx]


def save_model(model: torch.nn.Module, step_idx: int, saved_models):
model_copy = deepcopy(model)
state_dict = model_copy.cpu().state_dict()
saved_models[step_idx] = state_dict


def load_model(step_idx: int, saved_models):
return saved_models[step_idx]


class SimpleModule(LightningModule):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(3, 1)
self.automatic_optimization = False
self.fake_losses = [
torch.tensor(1.0),
torch.tensor(1.0),
torch.tensor(0.0),
torch.tensor(1.0),
]
self.saved_models = {}

def training_step(self, batch, batch_idx):
out = self.layer(batch[0])
loss = torch.nn.functional.binary_cross_entropy_with_logits(out, batch[1].float())
self.log("loss", self.fake_losses[batch_idx], on_step=True, on_epoch=True, logger=True)
# Save model before optimization
save_model(self.layer, batch_idx, self.saved_models)
optimizer = self.optimizers()
optimizer.zero_grad()
self.manual_backward(loss)
optimizer.step()
return loss

def configure_optimizers(self):
return torch.optim.SGD(self.parameters(), lr=0.01)


@contextmanager
def cleanup_after_test():
"""Context manager to ensure all test artifacts are cleaned up."""
log_dir = Path("tests_pytorch/lightning_logs")
try:
yield
finally:
# Clean up any remaining log files
if log_dir.exists():
shutil.rmtree(log_dir, ignore_errors=True)


def test_model_checkpoint_manual_opt():
with cleanup_after_test(), tempfile.TemporaryDirectory() as tmpdir:
dataset = FakeDataset()
train_dataloader = DataLoader(dataset, batch_size=1)
model = SimpleModule()
trainer = Trainer(
max_epochs=1,
callbacks=[
ModelCheckpoint(
save_top_k=1,
monitor="loss",
dirpath=tmpdir,
mode="min",
save_last=False,
every_n_train_steps=1,
train_time_interval=None,
every_n_epochs=0,
save_on_train_epoch_end=True,
save_weights_only=True,
)
],
log_every_n_steps=1,
num_sanity_val_steps=0,
logger=False, # Disable logging to prevent creating lightning_logs
)
try:
trainer.fit(model, train_dataloader)
finally:
trainer._teardown() # Ensure trainer is properly closed

# The best loss is at batch_idx=2 (loss=0.0)
best_step = 2
model_before_opt = load_model(best_step, model.saved_models)
# Load the best checkpoint
best_ckpt_path = trainer.checkpoint_callback.best_model_path
best_ckpt = torch.load(best_ckpt_path, weights_only=True)["state_dict"]

# The checkpoint should match the model before opt.step(), not after
for layer_name, layer_value in best_ckpt.items():
assert torch.equal(model_before_opt[layer_name.removeprefix("layer.")], layer_value.cpu()), (
f"Mismatch in {layer_name}: checkpoint saved after optimization instead of before"
)


def test_model_checkpoint_manual_opt_warning():
"""Test that a warning is raised when using manual optimization without saving the state."""

class SimpleModuleNoSave(SimpleModule):
def training_step(self, batch, batch_idx):
out = self.layer(batch[0])
loss = torch.nn.functional.binary_cross_entropy_with_logits(out, batch[1].float())
self.log("loss", self.fake_losses[batch_idx], on_step=True, on_epoch=True, logger=True)

# Don't save the model state before optimization
optimizer = self.optimizers()
optimizer.zero_grad()
self.manual_backward(loss)
optimizer.step()
return loss

with cleanup_after_test(), tempfile.TemporaryDirectory() as tmpdir:
dataset = FakeDataset()
train_dataloader = DataLoader(dataset, batch_size=1, num_workers=0) # Avoid num_workers warning
model = SimpleModuleNoSave()

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")  # Record every warning raised inside this block
# Ignore the unrelated DataLoader num_workers warning so it does not clutter the record
warnings.filterwarnings("ignore", message=".*num_workers.*")
trainer = Trainer(
max_epochs=1,
callbacks=[
ModelCheckpoint(
save_top_k=1,
monitor="loss",
dirpath=tmpdir,
mode="min",
save_last=False,
every_n_train_steps=1,
train_time_interval=None,
every_n_epochs=0,
save_on_train_epoch_end=True,
save_weights_only=True,
)
],
log_every_n_steps=1,
num_sanity_val_steps=0,
logger=False, # Disable logging to prevent creating lightning_logs
)
try:
trainer.fit(model, train_dataloader)
finally:
trainer._teardown()

# Find our warning in the list of warnings
manual_opt_warnings = [
str(warning.message)
for warning in w
if "Using ModelCheckpoint with manual optimization and every_n_train_steps" in str(warning.message)
]

# Verify our warning was raised
assert len(manual_opt_warnings) > 0, "Expected warning about manual optimization not found"
assert "The checkpoint will contain the model state AFTER optimization" in manual_opt_warnings[0]