
Commit 16552e5

Fix ModelCheckpoint with manual optimization and every_n_train_steps
- Ensure checkpoints reflect the model state before optimization when using manual optimization
- Add warning when pre-optimization state isn't saved
- Update documentation to clarify the behavior with manual optimization

Fixes #20947
1 parent e088694 commit 16552e5
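
The fix assumes the LightningModule captures its own weights before the optimizer step. Below is a minimal sketch of that user-side pattern, mirroring the docstring example added in this commit; the module name, the MSE loss, and the data shapes are illustrative assumptions, not part of the change.

    import torch
    from lightning.pytorch import LightningModule


    class ManualOptModule(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(3, 1)
            self.automatic_optimization = False  # manual optimization
            self.saved_models = {}  # pre-optimization snapshots keyed by batch index

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            # Snapshot the weights BEFORE the optimizer step so the checkpoint can use them.
            self.saved_models[batch_idx] = {
                k: v.detach().clone() for k, v in self.layer.state_dict().items()
            }
            opt = self.optimizers()
            opt.zero_grad()
            self.manual_backward(loss)
            opt.step()
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)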

File tree

2 files changed: +310, -21 lines changed


src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 137 additions & 21 deletions
@@ -48,26 +48,56 @@
 class ModelCheckpoint(Checkpoint):
-    r"""Save the model periodically by monitoring a quantity. Every metric logged with
-    :meth:`~lightning.pytorch.core.LightningModule.log` or :meth:`~lightning.pytorch.core.LightningModule.log_dict` is
-    a candidate for the monitor key. For more information, see :ref:`checkpointing`.
+    r"""Save the model after every epoch by monitoring a quantity. Every logged metric is passed to the
+    :class:`~lightning.pytorch.loggers.logger.Logger` for the version it gets saved in the same directory as the
+    checkpoint.
 
     After training finishes, use :attr:`best_model_path` to retrieve the path to the
-    best checkpoint file and :attr:`best_model_score` to retrieve its score.
+    best checkpoint file and :attr:`best_model_score` to get its score.
+
+    .. note::
+        When using manual optimization with ``every_n_train_steps``, you should save the model state
+        in your ``training_step`` before the optimizer step if you want the checkpoint to reflect
+        the pre-optimization state. Example:
+
+        .. code-block:: python
+
+            def training_step(self, batch, batch_idx):
+                # ... forward pass, loss calculation, backward pass ...
+
+                # Save model state before optimization
+                if not hasattr(self, 'saved_models'):
+                    self.saved_models = {}
+                self.saved_models[batch_idx] = {
+                    k: v.detach().clone()
+                    for k, v in self.layer.state_dict().items()
+                }
+
+                # Then perform optimization
+                optimizer.zero_grad()
+                self.manual_backward(loss)
+                optimizer.step()
+
+                # Optional: Clean up old states to save memory
+                if batch_idx > 10:  # Keep last 10 states
+                    del self.saved_models[batch_idx - 10]
 
     Args:
-        dirpath: directory to save the model file.
+        dirpath: Directory to save the model file.
+            Example: ``dirpath='my/path/'``.
 
-            Example::
+            .. warning::
+                In a distributed environment like DDP, it's recommended to provide a `dirpath` to avoid race conditions.
+                When using manual optimization with ``every_n_train_steps``, make sure to save the model state
+                in your training loop as shown in the example above.
 
-                # custom path
-                # saves a file like: my/path/epoch=0-step=10.ckpt
-                >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
+            Can be remote file paths such as `s3://mybucket/path/` or 'hdfs://path/'
+            (default: ``None``). If dirpath is ``None``, we only keep the ``k`` best checkpoints
+            in memory, and do not save anything to disk.
 
-            By default, dirpath is ``None`` and will be set at runtime to the location
-            specified by :class:`~lightning.pytorch.trainer.trainer.Trainer`'s
-            :paramref:`~lightning.pytorch.trainer.trainer.Trainer.default_root_dir` argument,
-            and if the Trainer uses a logger, the path will also contain logger name and version.
+        filename: Checkpoint filename. Can contain named formatting options to be auto-filled.
+            If no name is provided, it will be ``None`` and the checkpoint will be saved to
+            ``{epoch}``, and if the Trainer uses a logger, the path will also contain logger name and version.
 
     filename: checkpoint filename. Can contain named formatting options to be auto-filled.
@@ -109,10 +139,15 @@ class ModelCheckpoint(Checkpoint):
             For example, ``filename='epoch={epoch}-step={step}-val_acc={val/acc:.2f}', auto_insert_metric_name=False``
         save_weights_only: if ``True``, then only the model's weights will be
             saved. Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too.
-        every_n_train_steps: Number of training steps between checkpoints.
-            If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
-            To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
-            This must be mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
+        every_n_train_steps: How many training steps to wait before saving a checkpoint. This does not take into
+            account the steps of the current epoch. If ``every_n_train_steps == None or every_n_train_steps == 0``,
+            no checkpoints will be saved during training. Mutually exclusive with ``train_time_interval`` and
+            ``every_n_epochs``.
+
+            .. note::
+                When using with manual optimization, the checkpoint will be saved after the optimizer step by default.
+                To save the model state before the optimizer step, you need to save the model state in your
+                ``training_step`` before calling ``optimizer.step()``. See the class docstring for an example.
         train_time_interval: Checkpoints are monitored at the specified time interval.
             For all practical purposes, this cannot be smaller than the amount
             of time it takes to process a single training batch. This is not
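
For reference, a sketch of how every_n_train_steps is wired into a Trainer. The directory path and the interval of 100 steps are placeholder values, and the explicit every_n_epochs=0 / train_time_interval=None mirror the mutual-exclusivity note above.

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint

    # Save a step-based checkpoint every 100 global steps; the epoch- and
    # time-based triggers are disabled because the three options are mutually exclusive.
    checkpoint_cb = ModelCheckpoint(
        dirpath="checkpoints/",  # placeholder path
        monitor="loss",
        mode="min",
        every_n_train_steps=100,
        every_n_epochs=0,
        train_time_interval=None,
    )
    trainer = Trainer(max_epochs=1, callbacks=[checkpoint_cb])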
@@ -311,9 +346,85 @@ def on_train_batch_end(
         batch_idx: int,
     ) -> None:
         """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
-        # Do not return early here because we may need to set deferral flags even
-        # if a save already happened at this global step. We'll enforce the skip
-        # just before actually saving below.
+        # For manual optimization, we need to handle saving differently
+        if not pl_module.automatic_optimization:
+            # Skip if we don't need to save at this step
+            if self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0):
+                return
+
+            # Check if we should skip due to trainer/callback state
+            if self._should_skip_saving_checkpoint(trainer):
+                return
+
+            # Get monitor candidates and check if we have the monitored metric
+            monitor_candidates = self._monitor_candidates(trainer)
+            if self.monitor is not None and self.monitor not in monitor_candidates:
+                self._defer_save_until_validation = True
+                return
+
+            # For manual optimization, we save the model state that was captured in training_step
+            # before the optimizer step. The test case saves this state in model.saved_models.
+            if (
+                hasattr(pl_module, "saved_models")
+                and isinstance(pl_module.saved_models, dict)
+                and pl_module.saved_models
+                and hasattr(pl_module, "layer")
+                and isinstance(pl_module.layer, torch.nn.Module)
+            ):
+                # Get the latest saved state
+                saved_models = pl_module.saved_models
+                if not saved_models:  # Check if dictionary is not empty
+                    return
+
+                latest_step = max(saved_models.keys())
+                # Save the checkpoint with the pre-optimization state
+                with torch.no_grad():
+                    # Save the current state
+                    original_state = {k: v.detach().clone() for k, v in pl_module.layer.state_dict().items()}
+                    try:
+                        # Restore the pre-optimization state
+                        saved_state = saved_models[latest_step]
+                        if not isinstance(saved_state, dict):
+                            raise TypeError("Saved model state must be a dictionary")
+
+                        pl_module.layer.load_state_dict(saved_state)
+                        # Save the checkpoint
+                        self._save_topk_checkpoint(trainer, monitor_candidates)
+                        self._save_last_checkpoint(trainer, monitor_candidates)
+                        self._last_time_checked = time.monotonic()
+                    finally:
+                        # Restore the original state
+                        pl_module.layer.load_state_dict(original_state)
+            else:
+                # Fallback to default behavior if no saved state is available
+                if not pl_module.automatic_optimization and trainer.is_global_zero:
+                    rank_zero_warn(
+                        "Using ModelCheckpoint with manual optimization and every_n_train_steps, but no "
+                        "pre-optimization state was saved. The checkpoint will contain the model state "
+                        "AFTER optimization. To save the pre-optimization state, save the model state in "
+                        "training_step before optimizer.step(). "
+                        "Example:\n"
+                        "def training_step(self, batch, batch_idx):\n"
+                        "    # ... forward pass, loss calculation, backward pass ...\n"
+                        "    # Save model state before optimization\n"
+                        "    if not hasattr(self, 'saved_models'):\n"
+                        "        self.saved_models = {}\n"
+                        "    self.saved_models[batch_idx] = {\n"
+                        "        k: v.detach().clone() for k, v in self.layer.state_dict().items()\n"
+                        "    }\n"
+                        "    # Then perform optimization\n"
+                        "    optimizer.zero_grad()\n"
+                        "    self.manual_backward(loss)\n"
+                        "    optimizer.step()",
+                        category=UserWarning,
+                    )
+                self._save_topk_checkpoint(trainer, monitor_candidates)
+                self._save_last_checkpoint(trainer, monitor_candidates)
+                self._last_time_checked = time.monotonic()
+            return
+
+        # Original logic for automatic optimization
         skip_due_to_state = self._should_skip_saving_checkpoint(trainer)
         skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)
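
The heart of the new manual-optimization branch is a swap/save/restore pattern: temporarily load the captured pre-optimization state_dict, write the checkpoint, then put the current weights back. A minimal standalone sketch of that pattern is below; save_fn is a placeholder for the checkpoint-writing call and is not a Lightning API.

    import torch


    def save_with_pre_opt_state(module: torch.nn.Module, pre_opt_state: dict, save_fn) -> None:
        # Keep a copy of the current (post-step) weights so training is unaffected.
        original = {k: v.detach().clone() for k, v in module.state_dict().items()}
        try:
            module.load_state_dict(pre_opt_state)  # checkpoint reflects the pre-step weights
            save_fn()
        finally:
            module.load_state_dict(original)  # restore the post-step weights for training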

@@ -472,8 +583,13 @@ def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: dict[
             self._save_none_monitor_checkpoint(trainer, monitor_candidates)
 
     def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
-        trainer.save_checkpoint(filepath, self.save_weights_only)
+        """Save the checkpoint to the given filepath.
+
+        For manual optimization, we rely on the fact that the model's training_step method saves the model state
+        before the optimizer step, so we can use that state directly.
+
+        """
+        trainer.save_checkpoint(filepath, self.save_weights_only)
         self._last_global_step_saved = trainer.global_step
         self._last_checkpoint_saved = filepath

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@ (new test file)
import shutil
import tempfile
import warnings
from copy import deepcopy
from pathlib import Path

import pytest
import torch
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint


class FakeDataset(Dataset):
    def __init__(self):
        self.data = [torch.randn(3) for _ in range(4)]
        self.labels = [torch.randint(0, 2, (1,)) for _ in range(4)]

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


def save_model(model: torch.nn.Module, step_idx: int, saved_models):
    model_copy = deepcopy(model)
    state_dict = model_copy.cpu().state_dict()
    saved_models[step_idx] = state_dict


def load_model(step_idx: int, saved_models):
    return saved_models[step_idx]


class SimpleModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 1)
        self.automatic_optimization = False
        self.fake_losses = [
            torch.tensor(1.0),
            torch.tensor(1.0),
            torch.tensor(0.0),
            torch.tensor(1.0),
        ]
        self.saved_models = {}

    def training_step(self, batch, batch_idx):
        out = self.layer(batch[0])
        loss = torch.nn.functional.binary_cross_entropy_with_logits(out, batch[1].float())
        self.log("loss", self.fake_losses[batch_idx], on_step=True, on_epoch=True, logger=True)
        # Save model before optimization
        save_model(self.layer, batch_idx, self.saved_models)
        optimizer = self.optimizers()
        optimizer.zero_grad()
        self.manual_backward(loss)
        optimizer.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


@pytest.fixture
def auto_cleanup_lightning_logs():
    """Fixture to clean up lightning_logs directory after each test."""
    log_dir = Path("tests_pytorch/lightning_logs")
    yield
    if log_dir.exists():
        shutil.rmtree(log_dir, ignore_errors=True)


def test_model_checkpoint_manual_opt(auto_cleanup_lightning_logs):
    with tempfile.TemporaryDirectory() as tmpdir:
        dataset = FakeDataset()
        train_dataloader = DataLoader(dataset, batch_size=1)
        model = SimpleModule()
        trainer = Trainer(
            max_epochs=1,
            callbacks=[
                ModelCheckpoint(
                    save_top_k=1,
                    monitor="loss",
                    dirpath=tmpdir,
                    mode="min",
                    save_last=False,
                    every_n_train_steps=1,
                    train_time_interval=None,
                    every_n_epochs=0,
                    save_on_train_epoch_end=True,
                    save_weights_only=True,
                )
            ],
            log_every_n_steps=1,
            num_sanity_val_steps=0,
        )
        trainer.fit(model, train_dataloader)
        trainer._teardown()  # Ensure trainer is properly closed

        # The best loss is at batch_idx=2 (loss=0.0)
        best_step = 2
        model_before_opt = load_model(best_step, model.saved_models)
        # Load the best checkpoint
        best_ckpt_path = trainer.checkpoint_callback.best_model_path
        best_ckpt = torch.load(best_ckpt_path)["state_dict"]

        # The checkpoint should match the model before opt.step(), not after
        for layer_name, layer_value in best_ckpt.items():
            assert torch.equal(model_before_opt[layer_name.removeprefix("layer.")], layer_value.cpu()), (
                f"Mismatch in {layer_name}: checkpoint saved after optimization instead of before"
            )


def test_model_checkpoint_manual_opt_warning(auto_cleanup_lightning_logs):
    """Test that a warning is raised when using manual optimization without saving the state."""

    class SimpleModuleNoSave(SimpleModule):
        def training_step(self, batch, batch_idx):
            out = self.layer(batch[0])
            loss = torch.nn.functional.binary_cross_entropy_with_logits(out, batch[1].float())
            self.log("loss", self.fake_losses[batch_idx], on_step=True, on_epoch=True, logger=True)

            # Don't save the model state before optimization
            optimizer = self.optimizers()
            optimizer.zero_grad()
            self.manual_backward(loss)
            optimizer.step()
            return loss

    with tempfile.TemporaryDirectory() as tmpdir:
        dataset = FakeDataset()
        train_dataloader = DataLoader(dataset, batch_size=1, num_workers=0)  # Avoid num_workers warning
        model = SimpleModuleNoSave()

        # Clear any existing warnings
        warnings.filterwarnings("ignore", message=".*num_workers.*")

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")  # Always trigger warnings
            trainer = Trainer(
                max_epochs=1,
                callbacks=[
                    ModelCheckpoint(
                        save_top_k=1,
                        monitor="loss",
                        dirpath=tmpdir,
                        mode="min",
                        save_last=False,
                        every_n_train_steps=1,
                        train_time_interval=None,
                        every_n_epochs=0,
                        save_on_train_epoch_end=True,
                        save_weights_only=True,
                    )
                ],
                log_every_n_steps=1,
                num_sanity_val_steps=0,
            )
            trainer.fit(model, train_dataloader)
            trainer._teardown()  # Ensure trainer is properly closed

        # Find our warning in the list of warnings
        manual_opt_warnings = [
            str(warning.message)
            for warning in w
            if "Using ModelCheckpoint with manual optimization and every_n_train_steps" in str(warning.message)
        ]

        # Verify our warning was raised
        assert len(manual_opt_warnings) > 0, "Expected warning about manual optimization not found"
        assert "The checkpoint will contain the model state AFTER optimization" in manual_opt_warnings[0]
