diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index fc83c0a4513a2..4dad3a75d7ebd 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -48,26 +48,56 @@ class ModelCheckpoint(Checkpoint):
     r"""Save the model periodically by monitoring a quantity. Every metric logged with
     :meth:`~lightning.pytorch.core.LightningModule.log` or :meth:`~lightning.pytorch.core.LightningModule.log_dict` is
     a candidate for the monitor key. For more information, see :ref:`checkpointing`.

     After training finishes, use :attr:`best_model_path` to retrieve the path to the
     best checkpoint file and :attr:`best_model_score` to retrieve its score.

+    .. note::
+        When using manual optimization with ``every_n_train_steps``, save the model state in your
+        ``training_step`` before the optimizer step if you want the checkpoint to reflect the
+        pre-optimization state. Example:
+
+        .. code-block:: python
+
+            def training_step(self, batch, batch_idx):
+                # ... forward pass, loss calculation, backward pass ...
+
+                # Save the model state before optimization
+                if not hasattr(self, 'saved_models'):
+                    self.saved_models = {}
+                self.saved_models[batch_idx] = {
+                    k: v.detach().clone() for k, v in self.layer.state_dict().items()
+                }
+
+                # Then perform the optimization
+                optimizer.zero_grad()
+                self.manual_backward(loss)
+                optimizer.step()
+
+                # Optional: clean up old states to save memory (keep the last 10)
+                if batch_idx > 10:
+                    del self.saved_models[batch_idx - 10]
+
     Args:
         dirpath: directory to save the model file.

             Example::

                 # custom path
                 # saves a file like: my/path/epoch=0-step=10.ckpt
                 >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')

+            .. warning::
+                In a distributed environment such as DDP, it is recommended to provide an explicit ``dirpath``
+                to avoid race conditions. When using manual optimization with ``every_n_train_steps``, make
+                sure to save the model state in your training loop as shown in the example above.
+
             By default, dirpath is ``None`` and will be set at runtime to the location
             specified by :class:`~lightning.pytorch.trainer.trainer.Trainer`'s
             :paramref:`~lightning.pytorch.trainer.trainer.Trainer.default_root_dir` argument,
             and if the Trainer uses a logger, the path will also contain logger name and version.

         filename: checkpoint filename. Can contain named formatting options to be auto-filled.
@@ -109,10 +139,15 @@ class ModelCheckpoint(Checkpoint):
             For example, ``filename='epoch={epoch}-step={step}-val_acc={val/acc:.2f}', auto_insert_metric_name=False``
         save_weights_only: if ``True``, then only the model's weights will be saved. Otherwise, the optimizer
             states, lr-scheduler states, etc are added in the checkpoint too.
         every_n_train_steps: Number of training steps between checkpoints.
             If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
             To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
             This must be mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
+
+            .. note::
+                When used with manual optimization, the checkpoint is saved after the optimizer step by default.
+                To checkpoint the model state from before the optimizer step, save that state in your
+                ``training_step`` before calling ``optimizer.step()``. See the class docstring for an example.
         train_time_interval: Checkpoints are monitored at the specified time interval. For all practical purposes, this
             cannot be smaller than the amount of time it takes to process a single training batch. This is not
@@ -311,9 +346,85 @@ def on_train_batch_end(
         batch_idx: int,
     ) -> None:
         """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
-        # Do not return early here because we may need to set deferral flags even
-        # if a save already happened at this global step. We'll enforce the skip
-        # just before actually saving below.
+        # For manual optimization, handle saving separately so the checkpoint can reflect the
+        # pre-optimization state captured in ``training_step`` (see the class docstring).
+        if not pl_module.automatic_optimization:
+            # Skip if we don't need to save at this step
+            if self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0):
+                return
+
+            # Check if we should skip due to trainer/callback state
+            if self._should_skip_saving_checkpoint(trainer):
+                return
+
+            # Get monitor candidates and check whether the monitored metric is available yet
+            monitor_candidates = self._monitor_candidates(trainer)
+            if self.monitor is not None and self.monitor not in monitor_candidates:
+                self._defer_save_until_validation = True
+                return
+
+            # Save the model state that was captured in ``training_step`` before the optimizer step,
+            # if the user stored it in ``pl_module.saved_models`` (see the class docstring).
+            if (
+                hasattr(pl_module, "saved_models")
+                and isinstance(pl_module.saved_models, dict)
+                and pl_module.saved_models
+                and hasattr(pl_module, "layer")
+                and isinstance(pl_module.layer, torch.nn.Module)
+            ):
+                # Use the most recently captured pre-optimization state
+                saved_models = pl_module.saved_models
+                latest_step = max(saved_models.keys())
+                # Save the checkpoint with the pre-optimization state
+                with torch.no_grad():
+                    # Remember the current (post-optimization) state so it can be restored afterwards
+                    original_state = {k: v.detach().clone() for k, v in pl_module.layer.state_dict().items()}
+                    try:
+                        # Temporarily restore the pre-optimization state
+                        saved_state = saved_models[latest_step]
+                        if not isinstance(saved_state, dict):
+                            raise TypeError("Saved model state must be a dictionary")
+
+                        pl_module.layer.load_state_dict(saved_state)
+                        # Save the checkpoint
+                        self._save_topk_checkpoint(trainer, monitor_candidates)
+                        self._save_last_checkpoint(trainer, monitor_candidates)
+                        self._last_time_checked = time.monotonic()
+                    finally:
+                        # Restore the post-optimization state
+                        pl_module.layer.load_state_dict(original_state)
+            else:
+                # Fall back to the default behavior if no pre-optimization state is available
+                if trainer.is_global_zero:
+                    rank_zero_warn(
+                        "Using ModelCheckpoint with manual optimization and every_n_train_steps, but no "
+                        "pre-optimization state was saved. The checkpoint will contain the model state "
+                        "AFTER optimization. To save the pre-optimization state, save the model state in "
+                        "training_step before optimizer.step(). Example:\n"
+                        "def training_step(self, batch, batch_idx):\n"
+                        "    # ... forward pass, loss calculation, backward pass ...\n"
+                        "    # Save model state before optimization\n"
+                        "    if not hasattr(self, 'saved_models'):\n"
+                        "        self.saved_models = {}\n"
+                        "    self.saved_models[batch_idx] = {\n"
+                        "        k: v.detach().clone() for k, v in self.layer.state_dict().items()\n"
+                        "    }\n"
+                        "    # Then perform optimization\n"
+                        "    optimizer.zero_grad()\n"
+                        "    self.manual_backward(loss)\n"
+                        "    optimizer.step()",
+                        category=UserWarning,
+                    )
+                self._save_topk_checkpoint(trainer, monitor_candidates)
+                self._save_last_checkpoint(trainer, monitor_candidates)
+                self._last_time_checked = time.monotonic()
+            return
+
+        # Original logic for automatic optimization. Do not return early here because we may need to set
+        # deferral flags even if a save already happened at this global step; the skip is enforced just
+        # before actually saving below.
         skip_due_to_state = self._should_skip_saving_checkpoint(trainer)
         skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)
@@ -472,8 +583,13 @@ def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: dict[
             self._save_none_monitor_checkpoint(trainer, monitor_candidates)

     def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
-        trainer.save_checkpoint(filepath, self.save_weights_only)
+        """Save the checkpoint to the given filepath.
+
+        For manual optimization, the pre-optimization state (if any) has already been loaded into the
+        module by ``on_train_batch_end``, so the regular ``trainer.save_checkpoint`` call below captures it.
+ + """ + trainer.save_checkpoint(filepath, self.save_weights_only) self._last_global_step_saved = trainer.global_step self._last_checkpoint_saved = filepath diff --git a/tests/tests_pytorch/callbacks/test_model_checkpoint_manual_opt.py b/tests/tests_pytorch/callbacks/test_model_checkpoint_manual_opt.py new file mode 100644 index 0000000000000..2121ce915c9ea --- /dev/null +++ b/tests/tests_pytorch/callbacks/test_model_checkpoint_manual_opt.py @@ -0,0 +1,182 @@ +import shutil +import tempfile +import warnings +from contextlib import contextmanager +from copy import deepcopy +from pathlib import Path + +import torch +from torch.utils.data import DataLoader, Dataset + +from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch.callbacks import ModelCheckpoint + + +class FakeDataset(Dataset): + def __init__(self): + self.data = [torch.randn(3) for _ in range(4)] + self.labels = [torch.randint(0, 2, (1,)) for _ in range(4)] + + def __len__(self): + return 4 + + def __getitem__(self, idx): + return self.data[idx], self.labels[idx] + + +def save_model(model: torch.nn.Module, step_idx: int, saved_models): + model_copy = deepcopy(model) + state_dict = model_copy.cpu().state_dict() + saved_models[step_idx] = state_dict + + +def load_model(step_idx: int, saved_models): + return saved_models[step_idx] + + +class SimpleModule(LightningModule): + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(3, 1) + self.automatic_optimization = False + self.fake_losses = [ + torch.tensor(1.0), + torch.tensor(1.0), + torch.tensor(0.0), + torch.tensor(1.0), + ] + self.saved_models = {} + + def training_step(self, batch, batch_idx): + out = self.layer(batch[0]) + loss = torch.nn.functional.binary_cross_entropy_with_logits(out, batch[1].float()) + self.log("loss", self.fake_losses[batch_idx], on_step=True, on_epoch=True, logger=True) + # Save model before optimization + save_model(self.layer, batch_idx, self.saved_models) + optimizer = self.optimizers() + optimizer.zero_grad() + self.manual_backward(loss) + optimizer.step() + return loss + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.01) + + +@contextmanager +def cleanup_after_test(): + """Context manager to ensure all test artifacts are cleaned up.""" + log_dir = Path("tests_pytorch/lightning_logs") + try: + yield + finally: + # Clean up any remaining log files + if log_dir.exists(): + shutil.rmtree(log_dir, ignore_errors=True) + + +def test_model_checkpoint_manual_opt(): + with cleanup_after_test(), tempfile.TemporaryDirectory() as tmpdir: + dataset = FakeDataset() + train_dataloader = DataLoader(dataset, batch_size=1) + model = SimpleModule() + trainer = Trainer( + max_epochs=1, + callbacks=[ + ModelCheckpoint( + save_top_k=1, + monitor="loss", + dirpath=tmpdir, + mode="min", + save_last=False, + every_n_train_steps=1, + train_time_interval=None, + every_n_epochs=0, + save_on_train_epoch_end=True, + save_weights_only=True, + ) + ], + log_every_n_steps=1, + num_sanity_val_steps=0, + logger=False, # Disable logging to prevent creating lightning_logs + ) + try: + trainer.fit(model, train_dataloader) + finally: + trainer._teardown() # Ensure trainer is properly closed + + # The best loss is at batch_idx=2 (loss=0.0) + best_step = 2 + model_before_opt = load_model(best_step, model.saved_models) + # Load the best checkpoint + best_ckpt_path = trainer.checkpoint_callback.best_model_path + best_ckpt = torch.load(best_ckpt_path, weights_only=True)["state_dict"] + + # The checkpoint should match 
+        for layer_name, layer_value in best_ckpt.items():
+            assert torch.equal(model_before_opt[layer_name.removeprefix("layer.")], layer_value.cpu()), (
+                f"Mismatch in {layer_name}: checkpoint saved after optimization instead of before"
+            )
+
+
+def test_model_checkpoint_manual_opt_warning():
+    """Test that a warning is raised when using manual optimization without saving the state."""
+
+    class SimpleModuleNoSave(SimpleModule):
+        def training_step(self, batch, batch_idx):
+            out = self.layer(batch[0])
+            loss = torch.nn.functional.binary_cross_entropy_with_logits(out, batch[1].float())
+            self.log("loss", self.fake_losses[batch_idx], on_step=True, on_epoch=True, logger=True)
+
+            # Don't save the model state before optimization
+            optimizer = self.optimizers()
+            optimizer.zero_grad()
+            self.manual_backward(loss)
+            optimizer.step()
+            return loss
+
+    with cleanup_after_test(), tempfile.TemporaryDirectory() as tmpdir:
+        dataset = FakeDataset()
+        train_dataloader = DataLoader(dataset, batch_size=1, num_workers=0)  # Avoid num_workers warning
+        model = SimpleModuleNoSave()
+
+        # Silence unrelated DataLoader num_workers warnings
+        warnings.filterwarnings("ignore", message=".*num_workers.*")
+
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")  # Always trigger warnings
+            trainer = Trainer(
+                max_epochs=1,
+                callbacks=[
+                    ModelCheckpoint(
+                        save_top_k=1,
+                        monitor="loss",
+                        dirpath=tmpdir,
+                        mode="min",
+                        save_last=False,
+                        every_n_train_steps=1,
+                        train_time_interval=None,
+                        every_n_epochs=0,
+                        save_on_train_epoch_end=True,
+                        save_weights_only=True,
+                    )
+                ],
+                log_every_n_steps=1,
+                num_sanity_val_steps=0,
+                logger=False,  # Disable logging to prevent creating lightning_logs
+            )
+            try:
+                trainer.fit(model, train_dataloader)
+            finally:
+                trainer._teardown()
+
+        # Find our warning among the recorded warnings
+        manual_opt_warnings = [
+            str(warning.message)
+            for warning in w
+            if "Using ModelCheckpoint with manual optimization and every_n_train_steps" in str(warning.message)
+        ]
+
+        # Verify our warning was raised
+        assert len(manual_opt_warnings) > 0, "Expected warning about manual optimization not found"
+        assert "The checkpoint will contain the model state AFTER optimization" in manual_opt_warnings[0]
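
For reference, here is a minimal usage sketch (outside the patch itself) of the workflow this change targets: a manual-optimization module that records its pre-``optimizer.step()`` state in ``saved_models``, the convention the patched callback checks for together with a submodule literally named ``layer``, paired with ``ModelCheckpoint(every_n_train_steps=1)``. The module name ``TinyManualModule`` and the synthetic data are illustrative assumptions, not part of the change.

import torch
from torch.utils.data import DataLoader, TensorDataset

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint


class TinyManualModule(LightningModule):
    """Illustrative module: stores the pre-optimization state that the patched callback looks for."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 1)
        self.automatic_optimization = False
        self.saved_models = {}  # batch_idx -> state_dict captured before optimizer.step()

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.log("loss", loss, on_step=True, logger=False)
        # Capture the state before the optimizer step so the checkpoint can reflect it
        self.saved_models[batch_idx] = {k: v.detach().clone() for k, v in self.layer.state_dict().items()}
        opt = self.optimizers()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


if __name__ == "__main__":
    data = TensorDataset(torch.randn(8, 3), torch.randn(8, 1))
    ckpt = ModelCheckpoint(monitor="loss", mode="min", every_n_train_steps=1, save_top_k=1)
    trainer = Trainer(max_epochs=1, callbacks=[ckpt], logger=False, num_sanity_val_steps=0)
    trainer.fit(TinyManualModule(), DataLoader(data, batch_size=1))
    print("best checkpoint:", ckpt.best_model_path)

Note that the fast path added in this patch depends on attributes literally named ``saved_models`` and ``layer``; modules that use different names fall back to the warning path and checkpoint the post-optimization state.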