 
 
 class ModelCheckpoint(Checkpoint):
-    r"""Save the model periodically by monitoring a quantity. Every metric logged with
-    :meth:`~lightning.pytorch.core.LightningModule.log` or :meth:`~lightning.pytorch.core.LightningModule.log_dict` is
-    a candidate for the monitor key. For more information, see :ref:`checkpointing`.
+    r"""Save the model periodically by monitoring a quantity. Every metric logged with
+    :meth:`~lightning.pytorch.core.LightningModule.log` or :meth:`~lightning.pytorch.core.LightningModule.log_dict` is
+    a candidate for the monitor key, and if the Trainer uses a :class:`~lightning.pytorch.loggers.logger.Logger`,
+    checkpoints are saved in a directory that contains the logger name and version.
 
     After training finishes, use :attr:`best_model_path` to retrieve the path to the
-    best checkpoint file and :attr:`best_model_score` to retrieve its score.
+    best checkpoint file and :attr:`best_model_score` to get its score.
+
+    .. note::
+        When using manual optimization with ``every_n_train_steps``, you should save the model state
+        in your ``training_step`` before the optimizer step if you want the checkpoint to reflect
+        the pre-optimization state. Example:
+
+        .. code-block:: python
+
+            def training_step(self, batch, batch_idx):
+                opt = self.optimizers()
+
+                # ... forward pass and loss calculation ...
+
+                # Save the model state before optimization
+                if not hasattr(self, "saved_models"):
+                    self.saved_models = {}
+                self.saved_models[batch_idx] = {
+                    k: v.detach().clone()
+                    for k, v in self.layer.state_dict().items()
+                }
+
+                # Then perform optimization
+                opt.zero_grad()
+                self.manual_backward(loss)
+                opt.step()
+
+                # Optional: drop old snapshots to save memory, keeping only the last 10
+                self.saved_models.pop(batch_idx - 10, None)
 
     Args:
-        dirpath: directory to save the model file.
+        dirpath: Directory to save the model file.
+            Example: ``dirpath='my/path/'``.
 
-        Example::
+            .. warning::
+                In a distributed environment like DDP, it is recommended to provide a ``dirpath`` to avoid race
+                conditions. When using manual optimization with ``every_n_train_steps``, make sure to save the model
+                state in your training loop as shown in the example above.
 
-            # custom path
-            # saves a file like: my/path/epoch=0-step=10.ckpt
-            >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
+            Can be remote file paths such as ``s3://mybucket/path/`` or ``hdfs://path/``.
+            By default, dirpath is ``None`` and will be set at runtime to the location specified by the Trainer's
+            :paramref:`~lightning.pytorch.trainer.trainer.Trainer.default_root_dir` argument, and if the Trainer
+            uses a logger, the path will also contain the logger name and version.
 
-        By default, dirpath is ``None`` and will be set at runtime to the location
-        specified by :class:`~lightning.pytorch.trainer.trainer.Trainer`'s
-        :paramref:`~lightning.pytorch.trainer.trainer.Trainer.default_root_dir` argument,
-        and if the Trainer uses a logger, the path will also contain logger name and version.
 
         filename: checkpoint filename. Can contain named formatting options to be auto-filled.
 
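For orientation, here is a minimal usage sketch of the ``dirpath``, ``filename``, and ``every_n_train_steps`` options touched by this hunk; the local ``./checkpoints`` path and the commented-out ``MyModel`` module are illustrative placeholders, not part of this diff:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# Save a checkpoint every 100 global steps into a local directory.
# A remote path such as "s3://mybucket/checkpoints/" is also expected to work via fsspec.
checkpoint_callback = ModelCheckpoint(
    dirpath="./checkpoints",    # placeholder path
    filename="{epoch}-{step}",  # named formatting options are auto-filled
    every_n_train_steps=100,
    save_top_k=-1,              # keep every checkpoint written this way
)

trainer = Trainer(callbacks=[checkpoint_callback], max_epochs=1)
# trainer.fit(MyModel())  # MyModel is a user-defined LightningModule (placeholder)
```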
@@ -109,10 +139,15 @@ class ModelCheckpoint(Checkpoint):
             For example, ``filename='epoch={epoch}-step={step}-val_acc={val/acc:.2f}', auto_insert_metric_name=False``
         save_weights_only: if ``True``, then only the model's weights will be
             saved. Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too.
-        every_n_train_steps: Number of training steps between checkpoints.
-            If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
-            To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
-            This must be mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
+        every_n_train_steps: How many training steps to wait before saving a checkpoint. The interval is measured
+            in global steps and does not reset at the start of each epoch. If ``every_n_train_steps == None`` or
+            ``every_n_train_steps == 0``, no checkpoints will be saved during training. Mutually exclusive with
+            ``train_time_interval`` and ``every_n_epochs``.
+
+            .. note::
+                With manual optimization, the checkpoint is saved after the optimizer step by default. To capture
+                the pre-optimization state instead, save the model state in your ``training_step`` before calling
+                ``optimizer.step()``. See the class docstring for an example.
         train_time_interval: Checkpoints are monitored at the specified time interval.
             For all practical purposes, this cannot be smaller than the amount
             of time it takes to process a single training batch. This is not
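To make the manual-optimization note in this hunk concrete, here is a minimal, self-contained sketch of a manually optimized module that records its pre-step state as suggested; the ``BoringModule`` name, the single ``layer`` attribute, and the toy loss are assumptions for illustration, not part of this diff:

```python
import torch
from torch import nn
from lightning.pytorch import LightningModule


class BoringModule(LightningModule):  # placeholder module, for illustration only
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # enable manual optimization
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()  # toy loss so the sketch runs end to end

        # Record the pre-optimization weights so a checkpoint can reflect them later.
        if not hasattr(self, "saved_models"):
            self.saved_models = {}
        self.saved_models[batch_idx] = {k: v.detach().clone() for k, v in self.layer.state_dict().items()}

        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```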
@@ -311,9 +346,70 @@ def on_train_batch_end(
         batch_idx: int,
     ) -> None:
         """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
-        # Do not return early here because we may need to set deferral flags even
-        # if a save already happened at this global step. We'll enforce the skip
-        # just before actually saving below.
+        # For manual optimization, we need to handle saving differently
+        if not pl_module.automatic_optimization:
+            # Skip if we don't need to save at this step
+            if self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0):
+                return
+
+            # Check if we should skip due to trainer/callback state
+            if self._should_skip_saving_checkpoint(trainer):
+                return
+
+            # Get monitor candidates and check if we have the monitored metric
+            monitor_candidates = self._monitor_candidates(trainer)
+            if self.monitor is not None and self.monitor not in monitor_candidates:
+                self._defer_save_until_validation = True
+                return
+
+            # Save the checkpoint with the model state captured in training_step before the
+            # optimizer step, if the module stored one in ``saved_models``.
+            if hasattr(pl_module, "saved_models") and pl_module.saved_models:
+                latest_step = max(pl_module.saved_models.keys())
+                with torch.no_grad():
+                    # Keep a copy of the current (post-step) state so it can be restored afterwards
+                    original_state = {k: v.detach().clone() for k, v in pl_module.layer.state_dict().items()}
+                    try:
+                        # Temporarily restore the pre-optimization state
+                        pl_module.layer.load_state_dict(pl_module.saved_models[latest_step])
+                        # Save the checkpoint
+                        self._save_topk_checkpoint(trainer, monitor_candidates)
+                        self._save_last_checkpoint(trainer, monitor_candidates)
+                        self._last_time_checked = time.monotonic()
+                    finally:
+                        # Restore the post-step state
+                        pl_module.layer.load_state_dict(original_state)
+            else:
+                # Fall back to the default behavior if no saved state is available
+                if trainer.is_global_zero:
+                    rank_zero_warn(
+                        "Using ModelCheckpoint with manual optimization and every_n_train_steps, but no "
+                        "pre-optimization state was saved. The checkpoint will contain the model state "
+                        "AFTER optimization. To save the pre-optimization state, save the model state in "
+                        "training_step before optimizer.step(). Example:\n"
+                        "def training_step(self, batch, batch_idx):\n"
+                        "    # ... forward pass and loss calculation ...\n"
+                        "    # Save model state before optimization\n"
+                        "    if not hasattr(self, 'saved_models'):\n"
+                        "        self.saved_models = {}\n"
+                        "    self.saved_models[batch_idx] = {\n"
+                        "        k: v.detach().clone() for k, v in self.layer.state_dict().items()\n"
+                        "    }\n"
+                        "    # Then perform optimization\n"
+                        "    optimizer.zero_grad()\n"
+                        "    self.manual_backward(loss)\n"
+                        "    optimizer.step()",
+                        category=UserWarning,
+                    )
+                self._save_topk_checkpoint(trainer, monitor_candidates)
+                self._save_last_checkpoint(trainer, monitor_candidates)
+                self._last_time_checked = time.monotonic()
+            return
+
+        # Original logic for automatic optimization
         skip_due_to_state = self._should_skip_saving_checkpoint(trainer)
         skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)
 
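As a rough end-to-end illustration of the manual-optimization branch added above, reusing the placeholder ``BoringModule`` from the earlier sketch; the toy dataloader, the ``checkpoints/`` directory, and the small step counts are assumptions for the example only:

```python
import torch
from torch.utils.data import DataLoader
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# A plain tensor acts as a tiny map-style dataset, so each batch is an (8, 32) tensor.
loader = DataLoader(torch.randn(64, 32), batch_size=8)

ckpt = ModelCheckpoint(dirpath="checkpoints/", every_n_train_steps=2, save_top_k=-1)
trainer = Trainer(callbacks=[ckpt], max_epochs=1, limit_train_batches=8, logger=False)

model = BoringModule()  # placeholder module defined in the earlier sketch
trainer.fit(model, loader)

# Because training_step populated ``saved_models``, each checkpoint written every 2 steps
# should hold the weights as they were before the corresponding optimizer.step() call.
# If ``saved_models`` were absent, the branch above would fall back to the post-step state
# and emit the UserWarning shown in the diff.
```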
@@ -472,8 +568,13 @@ def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: dict[
             self._save_none_monitor_checkpoint(trainer, monitor_candidates)
 
     def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
-        trainer.save_checkpoint(filepath, self.save_weights_only)
+        """Save the checkpoint to the given filepath.
 
+        For manual optimization, ``on_train_batch_end`` temporarily swaps in the pre-optimization state recorded by
+        the model's ``training_step`` before invoking this method, so the written checkpoint reflects that state.
+
+        """
+        trainer.save_checkpoint(filepath, self.save_weights_only)
         self._last_global_step_saved = trainer.global_step
         self._last_checkpoint_saved = filepath
 
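A quick way to inspect what actually landed on disk after the run sketched earlier; the ``checkpoints/`` glob and the ``layer.`` prefix follow the placeholder module above, while ``state_dict`` is the standard key Lightning uses inside a checkpoint file:

```python
import glob
import torch

# Pick one of the checkpoint files written by the example run.
ckpt_path = sorted(glob.glob("checkpoints/*.ckpt"))[0]
# Full Lightning checkpoints contain more than tensors, so disable weights-only loading.
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)

# Module weights live under "state_dict", prefixed by the attribute name ("layer." here).
print(sorted(checkpoint["state_dict"]))
print(checkpoint["state_dict"]["layer.weight"].shape)

# The pre-step snapshots recorded in training_step remain on the module and can be
# compared against the checkpointed weights step by step if needed.
print(sorted(model.saved_models)[-3:])
```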