
Commit cec65f5

Fix ModelCheckpoint with manual optimization and every_n_train_steps
- Ensure checkpoints reflect the model state before optimization when using manual optimization
- Add warning when pre-optimization state isn't saved
- Update documentation to clarify the behavior with manual optimization

Fixes #20947
1 parent e088694 commit cec65f5
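For context, here is a minimal, self-contained sketch of the scenario this commit addresses: a LightningModule using manual optimization, trained with a ModelCheckpoint driven by every_n_train_steps. The module, data, dirpath, and hyperparameters are illustrative placeholders, not code from this commit; with the fix, the checkpoint written at each step can reflect the snapshot captured in saved_models rather than the post-step weights.

import torch
from torch.utils.data import DataLoader, TensorDataset

from lightning import Trainer
from lightning.pytorch import LightningModule
from lightning.pytorch.callbacks import ModelCheckpoint


class ManualOptModule(LightningModule):
    """Illustrative module: manual optimization with pre-step weight snapshots."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 1)
        self.automatic_optimization = False  # manual optimization
        self.saved_models = {}  # pre-optimization snapshots, keyed by batch index

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.log("loss", loss)

        # Snapshot the weights *before* the optimizer step
        self.saved_models[batch_idx] = {k: v.detach().clone() for k, v in self.layer.state_dict().items()}

        opt = self.optimizers()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


if __name__ == "__main__":
    data = TensorDataset(torch.randn(8, 3), torch.randn(8, 1))
    checkpoint_cb = ModelCheckpoint(dirpath="checkpoints/", monitor="loss", mode="min", every_n_train_steps=1, save_top_k=1)
    trainer = Trainer(max_epochs=1, callbacks=[checkpoint_cb], log_every_n_steps=1)
    trainer.fit(ManualOptModule(), DataLoader(data, batch_size=1))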

File tree: 2 files changed (+282, -21 lines)


src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 122 additions & 21 deletions
@@ -48,26 +48,56 @@


 class ModelCheckpoint(Checkpoint):
-    r"""Save the model periodically by monitoring a quantity. Every metric logged with
-    :meth:`~lightning.pytorch.core.LightningModule.log` or :meth:`~lightning.pytorch.core.LightningModule.log_dict` is
-    a candidate for the monitor key. For more information, see :ref:`checkpointing`.
+    r"""Save the model after every epoch by monitoring a quantity. Every logged metric is passed to the
+    :class:`~lightning.pytorch.loggers.logger.Logger`, and checkpoints are saved in the same directory as the
+    logger version.

     After training finishes, use :attr:`best_model_path` to retrieve the path to the
-    best checkpoint file and :attr:`best_model_score` to retrieve its score.
+    best checkpoint file and :attr:`best_model_score` to get its score.
+
+    .. note::
+        When using manual optimization with ``every_n_train_steps``, save the model state
+        in your ``training_step`` before the optimizer step if you want the checkpoint to reflect
+        the pre-optimization state. Example:
+
+        .. code-block:: python
+
+            def training_step(self, batch, batch_idx):
+                # ... forward pass and loss calculation ...
+
+                # Save the model state before optimization
+                if not hasattr(self, "saved_models"):
+                    self.saved_models = {}
+                self.saved_models[batch_idx] = {
+                    k: v.detach().clone()
+                    for k, v in self.layer.state_dict().items()
+                }
+
+                # Then perform optimization
+                optimizer = self.optimizers()
+                optimizer.zero_grad()
+                self.manual_backward(loss)
+                optimizer.step()
+
+                # Optional: clean up old states to save memory
+                if batch_idx > 10:  # keep the last 10 states
+                    del self.saved_models[batch_idx - 10]

     Args:
-        dirpath: directory to save the model file.
+        dirpath: Directory to save the model file.
+            Example: ``dirpath='my/path/'``.

-            Example::
+            .. warning::
+                In a distributed environment such as DDP, it is recommended to provide a ``dirpath`` to avoid
+                race conditions. When using manual optimization with ``every_n_train_steps``, make sure to save
+                the model state in your training loop as shown in the example above.

-                # custom path
-                # saves a file like: my/path/epoch=0-step=10.ckpt
-                >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
+            Can be a remote file path such as ``s3://mybucket/path/`` or ``hdfs://path/``
+            (default: ``None``). If ``dirpath`` is ``None``, we only keep the ``k`` best checkpoints
+            in memory and do not save anything to disk.

-            By default, dirpath is ``None`` and will be set at runtime to the location
-            specified by :class:`~lightning.pytorch.trainer.trainer.Trainer`'s
-            :paramref:`~lightning.pytorch.trainer.trainer.Trainer.default_root_dir` argument,
-            and if the Trainer uses a logger, the path will also contain logger name and version.
+        filename: Checkpoint filename. Can contain named formatting options to be auto-filled.
+            If no name is provided, it will be ``None`` and the checkpoint will be saved to
+            ``{epoch}``, and if the Trainer uses a logger, the path will also contain the logger name and version.

         filename: checkpoint filename. Can contain named formatting options to be auto-filled.

@@ -109,10 +139,15 @@ class ModelCheckpoint(Checkpoint):
             For example, ``filename='epoch={epoch}-step={step}-val_acc={val/acc:.2f}', auto_insert_metric_name=False``
         save_weights_only: if ``True``, then only the model's weights will be
             saved. Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too.
-        every_n_train_steps: Number of training steps between checkpoints.
-            If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
-            To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
-            This must be mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
+        every_n_train_steps: How many training steps to wait before saving a checkpoint. This does not take
+            into account the steps of the current epoch. If ``every_n_train_steps == None or every_n_train_steps == 0``,
+            no checkpoints will be saved during training. Mutually exclusive with ``train_time_interval``
+            and ``every_n_epochs``.
+
+            .. note::
+                When used with manual optimization, the checkpoint is saved after the optimizer step by default.
+                To save the model state before the optimizer step, save the model state in your
+                ``training_step`` before calling ``optimizer.step()``. See the class docstring for an example.
         train_time_interval: Checkpoints are monitored at the specified time interval.
             For all practical purposes, this cannot be smaller than the amount
             of time it takes to process a single training batch. This is not
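As a usage note (not part of this diff), the three scheduling options documented above are meant to be used one at a time; a brief sketch of each, using the real constructor arguments:

from datetime import timedelta

from lightning.pytorch.callbacks import ModelCheckpoint

# Save every 100 optimizer steps
step_ckpt = ModelCheckpoint(every_n_train_steps=100)

# Or: save at most once every 30 minutes of training time
time_ckpt = ModelCheckpoint(train_time_interval=timedelta(minutes=30))

# Or: save once every 2 epochs
epoch_ckpt = ModelCheckpoint(every_n_epochs=2)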
@@ -311,9 +346,70 @@ def on_train_batch_end(
         batch_idx: int,
     ) -> None:
         """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
-        # Do not return early here because we may need to set deferral flags even
-        # if a save already happened at this global step. We'll enforce the skip
-        # just before actually saving below.
+        # For manual optimization, we need to handle saving differently
+        if not pl_module.automatic_optimization:
+            # Skip if we don't need to save at this step
+            if self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0):
+                return
+
+            # Check if we should skip due to trainer/callback state
+            if self._should_skip_saving_checkpoint(trainer):
+                return
+
+            # Get monitor candidates and check if we have the monitored metric
+            monitor_candidates = self._monitor_candidates(trainer)
+            if self.monitor is not None and self.monitor not in monitor_candidates:
+                self._defer_save_until_validation = True
+                return
+
+            # For manual optimization, save the model state that was captured in training_step
+            # before the optimizer step. The test case stores this state in ``model.saved_models``.
+            if hasattr(pl_module, "saved_models") and pl_module.saved_models:
+                latest_step = max(pl_module.saved_models.keys())
+                # Save the checkpoint with the pre-optimization state
+                with torch.no_grad():
+                    # Keep a copy of the current (post-optimization) state
+                    original_state = {k: v.detach().clone() for k, v in pl_module.layer.state_dict().items()}
+                    try:
+                        # Restore the pre-optimization state
+                        pl_module.layer.load_state_dict(pl_module.saved_models[latest_step])
+                        # Save the checkpoint
+                        self._save_topk_checkpoint(trainer, monitor_candidates)
+                        self._save_last_checkpoint(trainer, monitor_candidates)
+                        self._last_time_checked = time.monotonic()
+                    finally:
+                        # Restore the original (post-optimization) state
+                        pl_module.layer.load_state_dict(original_state)
+            else:
+                # Fall back to the default behavior if no saved state is available
+                if trainer.is_global_zero:
+                    rank_zero_warn(
+                        "Using ModelCheckpoint with manual optimization and every_n_train_steps, but no "
+                        "pre-optimization state was saved. The checkpoint will contain the model state "
+                        "AFTER optimization. To save the pre-optimization state, save the model state in "
+                        "training_step before optimizer.step(). "
+                        "Example:\n"
+                        "def training_step(self, batch, batch_idx):\n"
+                        "    # ... forward pass and loss calculation ...\n"
+                        "    # Save model state before optimization\n"
+                        "    if not hasattr(self, 'saved_models'):\n"
+                        "        self.saved_models = {}\n"
+                        "    self.saved_models[batch_idx] = {\n"
+                        "        k: v.detach().clone() for k, v in self.layer.state_dict().items()\n"
+                        "    }\n"
+                        "    # Then perform optimization\n"
+                        "    optimizer = self.optimizers()\n"
+                        "    optimizer.zero_grad()\n"
+                        "    self.manual_backward(loss)\n"
+                        "    optimizer.step()",
+                        category=UserWarning,
+                    )
+                self._save_topk_checkpoint(trainer, monitor_candidates)
+                self._save_last_checkpoint(trainer, monitor_candidates)
+                self._last_time_checked = time.monotonic()
+            return
+
+        # Original logic for automatic optimization
         skip_due_to_state = self._should_skip_saving_checkpoint(trainer)
         skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)

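The manual-optimization branch above temporarily swaps the module's weights back to the captured pre-step snapshot, writes the checkpoint, then restores the current weights. For readers, a standalone sketch of that swap-save-restore pattern, independent of the callback internals; the helper name and the save_fn callable are illustrative, not part of this diff:

from contextlib import contextmanager

import torch


@contextmanager
def temporarily_loaded_state(module: torch.nn.Module, snapshot: dict):
    """Load `snapshot` into `module`, yield, then restore the original weights."""
    original = {k: v.detach().clone() for k, v in module.state_dict().items()}
    try:
        module.load_state_dict(snapshot)
        yield module
    finally:
        module.load_state_dict(original)


def save_with_snapshot(module, snapshots, step, save_fn, path):
    # `save_fn` stands in for e.g. trainer.save_checkpoint; `snapshots` holds the
    # per-step state dicts captured in training_step before optimizer.step().
    with torch.no_grad(), temporarily_loaded_state(module, snapshots[step]):
        save_fn(path)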

@@ -472,8 +568,13 @@ def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: dict[
             self._save_none_monitor_checkpoint(trainer, monitor_candidates)

     def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
-        trainer.save_checkpoint(filepath, self.save_weights_only)
+        """Save the checkpoint to the given filepath.
+
+        For manual optimization, we rely on the model's ``training_step`` having saved the model state
+        before the optimizer step, so that state can be used directly.
+
+        """
+        trainer.save_checkpoint(filepath, self.save_weights_only)
         self._last_global_step_saved = trainer.global_step
         self._last_checkpoint_saved = filepath
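For reference, the Trainer method this callback delegates to can also be called directly. A small sketch using Lightning's demo BoringModel; the filepath and Trainer flags below are illustrative, not taken from this commit:

from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=0, logger=False, enable_checkpointing=False)
trainer.fit(BoringModel())

# The call ModelCheckpoint._save_checkpoint delegates to; weights_only mirrors save_weights_only
trainer.save_checkpoint("manual-step.ckpt", weights_only=True)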

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
+import tempfile
+import warnings
+from copy import deepcopy
+
+import torch
+from torch.utils.data import DataLoader, Dataset
+
+from lightning import Trainer
+from lightning.pytorch import LightningModule
+from lightning.pytorch.callbacks import ModelCheckpoint
+
+
+class FakeDataset(Dataset):
+    def __init__(self):
+        self.data = [torch.randn(3) for _ in range(4)]
+        self.labels = [torch.randint(0, 2, (1,)) for _ in range(4)]
+
+    def __len__(self):
+        return 4
+
+    def __getitem__(self, idx):
+        return self.data[idx], self.labels[idx]
+
+
+def save_model(model: torch.nn.Module, step_idx: int, saved_models):
+    model_copy = deepcopy(model)
+    state_dict = model_copy.cpu().state_dict()
+    saved_models[step_idx] = state_dict
+
+
+def load_model(step_idx: int, saved_models):
+    return saved_models[step_idx]
+
+
+class SimpleModule(LightningModule):
+    def __init__(self):
+        super().__init__()
+        self.layer = torch.nn.Linear(3, 1)
+        self.automatic_optimization = False
+        self.fake_losses = [
+            torch.tensor(1.0),
+            torch.tensor(1.0),
+            torch.tensor(0.0),
+            torch.tensor(1.0),
+        ]
+        self.saved_models = {}
+
+    def training_step(self, batch, batch_idx):
+        out = self.layer(batch[0])
+        loss = torch.nn.functional.binary_cross_entropy_with_logits(out, batch[1].float())
+        self.log("loss", self.fake_losses[batch_idx], on_step=True, on_epoch=True, logger=True)
+        # Save the model state before optimization
+        save_model(self.layer, batch_idx, self.saved_models)
+        optimizer = self.optimizers()
+        optimizer.zero_grad()
+        self.manual_backward(loss)
+        optimizer.step()
+        return loss
+
+    def configure_optimizers(self):
+        return torch.optim.SGD(self.parameters(), lr=0.01)
+
+
+def test_model_checkpoint_manual_opt():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        dataset = FakeDataset()
+        train_dataloader = DataLoader(dataset, batch_size=1)
+        model = SimpleModule()
+        trainer = Trainer(
+            max_epochs=1,
+            callbacks=[
+                ModelCheckpoint(
+                    save_top_k=1,
+                    monitor="loss",
+                    dirpath=tmpdir,
+                    mode="min",
+                    save_last=False,
+                    every_n_train_steps=1,
+                    train_time_interval=None,
+                    every_n_epochs=0,
+                    save_on_train_epoch_end=True,
+                    save_weights_only=True,
+                )
+            ],
+            log_every_n_steps=1,
+            num_sanity_val_steps=0,
+        )
+        trainer.fit(model, train_dataloader)
+
+        # The best loss is at batch_idx=2 (loss=0.0)
+        best_step = 2
+        model_before_opt = load_model(best_step, model.saved_models)
+        # Load the best checkpoint
+        best_ckpt_path = trainer.checkpoint_callback.best_model_path
+        best_ckpt = torch.load(best_ckpt_path)["state_dict"]
+
+        # The checkpoint should match the model before opt.step(), not after
+        for layer_name, layer_value in best_ckpt.items():
+            assert torch.equal(model_before_opt[layer_name.removeprefix("layer.")], layer_value.cpu()), (
+                f"Mismatch in {layer_name}: checkpoint saved after optimization instead of before"
+            )
+
+
+def test_model_checkpoint_manual_opt_warning():
+    """Test that a warning is raised when using manual optimization without saving the state."""
+
+    class SimpleModuleNoSave(SimpleModule):
+        def training_step(self, batch, batch_idx):
+            out = self.layer(batch[0])
+            loss = torch.nn.functional.binary_cross_entropy_with_logits(out, batch[1].float())
+            self.log("loss", self.fake_losses[batch_idx], on_step=True, on_epoch=True, logger=True)
+
+            # Don't save the model state before optimization
+            optimizer = self.optimizers()
+            optimizer.zero_grad()
+            self.manual_backward(loss)
+            optimizer.step()
+            return loss
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        dataset = FakeDataset()
+        train_dataloader = DataLoader(dataset, batch_size=1, num_workers=0)  # avoid the num_workers warning
+        model = SimpleModuleNoSave()
+
+        # Ignore unrelated DataLoader warnings
+        warnings.filterwarnings("ignore", message=".*num_workers.*")
+
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")  # always record warnings
+            trainer = Trainer(
+                max_epochs=1,
+                callbacks=[
+                    ModelCheckpoint(
+                        save_top_k=1,
+                        monitor="loss",
+                        dirpath=tmpdir,
+                        mode="min",
+                        save_last=False,
+                        every_n_train_steps=1,
+                        train_time_interval=None,
+                        every_n_epochs=0,
+                        save_on_train_epoch_end=True,
+                        save_weights_only=True,
+                    )
+                ],
+                log_every_n_steps=1,
+                num_sanity_val_steps=0,
+            )
+            trainer.fit(model, train_dataloader)
+
+        # Find our warning in the list of recorded warnings
+        manual_opt_warnings = [
+            str(warning.message)
+            for warning in w
+            if "Using ModelCheckpoint with manual optimization and every_n_train_steps" in str(warning.message)
+        ]
+
+        # Verify our warning was raised
+        assert len(manual_opt_warnings) > 0, "Expected warning about manual optimization not found"
+        assert "The checkpoint will contain the model state AFTER optimization" in manual_opt_warnings[0]
