
Commit 08266a9

update checkpointing (#20653)
1 parent 29ed24f commit 08266a9

8 files changed: +252, -16 lines changed


Makefile

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ docs-fabric: clean sphinx-theme
 	cd docs/source-fabric && $(MAKE) html --jobs $(nproc)
 
 docs-pytorch: clean sphinx-theme
-	pip install -e .[all] --quiet -r requirements/pytorch/docs.txt -r _notebooks/.actions/requires.txt
+	pip install -e .[all] --quiet -r requirements/pytorch/docs.txt
 	cd docs/source-pytorch && $(MAKE) html --jobs $(nproc)
 
 update:

requirements/pytorch/docs.txt

Lines changed: 1 addition & 1 deletion
@@ -4,4 +4,4 @@ nbformat # used for generate empty notebook
 ipython[notebook] <8.7.0
 setuptools<58.0 # workaround for `error in ipython setup command: use_2to3 is invalid.`
 
--r ../../_notebooks/.actions/requires.txt
+#-r ../../_notebooks/.actions/requires.txt

src/lightning/pytorch/trainer/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""""""
 
 from lightning.fabric.utilities.seed import seed_everything
 from lightning.pytorch.trainer.trainer import Trainer

src/lightning/pytorch/trainer/connectors/callback_connector.py

Lines changed: 21 additions & 2 deletions
@@ -11,13 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import inspect
 import logging
 import os
 from collections.abc import Sequence
 from datetime import timedelta
 from typing import Optional, Union
 
+from lightning_utilities import module_available
+
 import lightning.pytorch as pl
 from lightning.fabric.utilities.registry import _load_external_callbacks
 from lightning.pytorch.callbacks import (
@@ -91,7 +93,24 @@ def _configure_checkpoint_callbacks(self, enable_checkpointing: bool) -> None:
                 " but found `ModelCheckpoint` in callbacks list."
             )
         elif enable_checkpointing:
-            self.trainer.callbacks.append(ModelCheckpoint())
+            if module_available("litmodels") and self.trainer._model_registry:
+                trainer_source = inspect.getmodule(self.trainer)
+                if trainer_source is None or not isinstance(trainer_source.__package__, str):
+                    raise RuntimeError("Unable to determine the source of the trainer.")
+                # this needs to be imported based on the actual package: lightning / pytorch_lightning
+                if "pytorch_lightning" in trainer_source.__package__:
+                    from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint
+                else:
+                    from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint
+
+                model_checkpoint = LitModelCheckpoint(model_name=self.trainer._model_registry)
+            else:
+                rank_zero_info(
+                    "You are using the plain ModelCheckpoint callback."
+                    " Consider using LitModelCheckpoint, which adds seamless uploading to the Model registry."
+                )
+                model_checkpoint = ModelCheckpoint()
+            self.trainer.callbacks.append(model_checkpoint)
 
     def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
         if not enable_model_summary:
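In short, `_configure_checkpoint_callbacks` now swaps in a registry-aware checkpoint callback when `litmodels` is installed and a `model_registry` name was given. A minimal sketch of that selection, pulled out of the connector for illustration (the helper name `_pick_checkpoint_class` is hypothetical; the imports and branch mirror the diff above):

```python
# Hypothetical standalone sketch of the selection logic added above; not the Lightning source.
import inspect

from lightning_utilities import module_available


def _pick_checkpoint_class(trainer):
    """Return a LitModelCheckpoint class when litmodels and a registry name are available."""
    if module_available("litmodels") and trainer._model_registry:
        trainer_source = inspect.getmodule(trainer)
        if trainer_source is None or not isinstance(trainer_source.__package__, str):
            raise RuntimeError("Unable to determine the source of the trainer.")
        # which integration to import depends on whether the Trainer came from
        # `pytorch_lightning` or the unified `lightning` package
        if "pytorch_lightning" in trainer_source.__package__:
            from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint
        else:
            from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint
        return LitModelCheckpoint
    # fallback: the plain callback, exactly as before this commit
    from lightning.pytorch.callbacks import ModelCheckpoint

    return ModelCheckpoint
```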

src/lightning/pytorch/trainer/connectors/checkpoint_connector.py

Lines changed: 14 additions & 3 deletions
@@ -19,6 +19,7 @@
 import torch
 from fsspec.core import url_to_fs
 from fsspec.implementations.local import LocalFileSystem
+from lightning_utilities import module_available
 from torch import Tensor
 
 import lightning.pytorch as pl
@@ -33,6 +34,10 @@
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
 from lightning.pytorch.utilities.migration import pl_legacy_patch
 from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
+from lightning.pytorch.utilities.model_registry import (
+    _is_registry,
+    find_model_local_ckpt_path,
+)
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
 
 log = logging.getLogger(__name__)
@@ -48,8 +53,7 @@ def __init__(self, trainer: "pl.Trainer") -> None:
 
     @property
     def _hpc_resume_path(self) -> Optional[str]:
-        dir_path_hpc = self.trainer.default_root_dir
-        dir_path_hpc = str(dir_path_hpc)
+        dir_path_hpc = str(self.trainer.default_root_dir)
         fs, path = url_to_fs(dir_path_hpc)
         if not _is_dir(fs, path):
             return None
@@ -194,10 +198,17 @@ def _parse_ckpt_path(
             if not self._hpc_resume_path:
                 raise ValueError(
                     f'`.{fn}(ckpt_path="hpc")` is set but no HPC checkpoint was found.'
-                    " Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`"
+                    f" Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`"
                 )
             ckpt_path = self._hpc_resume_path
 
+        elif _is_registry(ckpt_path) and module_available("litmodels"):
+            ckpt_path = find_model_local_ckpt_path(
+                ckpt_path,
+                default_model_registry=self.trainer._model_registry,
+                default_root_dir=self.trainer.default_root_dir,
+            )
+
         if not ckpt_path:
             raise ValueError(
                 f"`.{fn}()` found no path for the best weights: {ckpt_path!r}. Please"

src/lightning/pytorch/trainer/trainer.py

Lines changed: 36 additions & 7 deletions
@@ -30,6 +30,7 @@
 from weakref import proxy
 
 import torch
+from lightning_utilities import module_available
 from torch.optim import Optimizer
 
 import lightning.pytorch as pl
@@ -70,6 +71,7 @@
 from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized, _verify_strategy_supports_compile
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.model_helpers import is_overridden
+from lightning.pytorch.utilities.model_registry import _is_registry, download_model_from_registry
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.pytorch.utilities.seed import isolate_rng
 from lightning.pytorch.utilities.types import (
@@ -129,6 +131,7 @@ def __init__(
         reload_dataloaders_every_n_epochs: int = 0,
         default_root_dir: Optional[_PATH] = None,
         enable_autolog_hparams: bool = True,
+        model_registry: Optional[str] = None,
     ) -> None:
         r"""Customize every aspect of training via flags.
 
@@ -294,6 +297,8 @@ def __init__(
             enable_autolog_hparams: Whether to log hyperparameters at the start of a run.
                 Default: ``True``.
 
+            model_registry: The name of the model being uploaded to the Model hub.
+
         Raises:
             TypeError:
                 If ``gradient_clip_val`` is not an int or float.
@@ -308,6 +313,9 @@ def __init__(
         if default_root_dir is not None:
             default_root_dir = os.fspath(default_root_dir)
 
+        # remove the version if it was accidentally passed along with the name
+        self._model_registry = model_registry.split(":")[0] if model_registry else None
+
         self.barebones = barebones
         if barebones:
             # opt-outs
@@ -525,7 +533,20 @@ def fit(
                 the :class:`~lightning.pytorch.core.hooks.DataHooks.train_dataloader` hook.
 
             ckpt_path: Path/URL of the checkpoint from which training is resumed. Could also be one of three special
-                keywords ``"last"`` and ``"hpc"``. If there is no checkpoint file at the path, an exception is raised.
+                keywords ``"last"``, ``"hpc"`` and ``"registry"``.
+                Otherwise, if there is no checkpoint file at the path, an exception is raised.
+
+                - best: the best model checkpoint from the previous ``trainer.fit`` call will be loaded
+                - last: the last model checkpoint from the previous ``trainer.fit`` call will be loaded
+                - registry: the model will be downloaded from the Lightning Model Registry with the following
+                  notations:
+
+                  - ``'registry'``: uses the latest/default version of the default model set
+                    with ``Trainer(..., model_registry="my-model")``
+                  - ``'registry:model-name'``: uses the latest/default version of the model `model-name`
+                  - ``'registry:model-name:version:v2'``: uses the specific version 'v2' of the model `model-name`
+                  - ``'registry:version:v2'``: uses the default model set
+                    with ``Trainer(..., model_registry="my-model")`` and version 'v2'
 
         Raises:
             TypeError:
@@ -573,6 +594,8 @@ def _fit_impl(
         )
 
         assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn,
             ckpt_path,
@@ -602,8 +625,8 @@
                 Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
                 the :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook.
 
-            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to validate.
-                If ``None`` and the model instance was passed, use the current weights.
+            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
+                to validate. If ``None`` and the model instance was passed, use the current weights.
                 Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
                 if a checkpoint callback is configured.
 
@@ -681,6 +704,8 @@ def _validate_impl(
         self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)
 
         assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
@@ -711,8 +736,8 @@
                 Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
                 the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook.
 
-            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to test.
-                If ``None`` and the model instance was passed, use the current weights.
+            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
+                to test. If ``None`` and the model instance was passed, use the current weights.
                 Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
                 if a checkpoint callback is configured.
 
@@ -790,6 +815,8 @@ def _test_impl(
         self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)
 
         assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
@@ -826,8 +853,8 @@
             return_predictions: Whether to return predictions.
                 ``True`` by default except when an accelerator that spawns processes is used (not supported).
 
-            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to predict.
-                If ``None`` and the model instance was passed, use the current weights.
+            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
+                to predict. If ``None`` and the model instance was passed, use the current weights.
                 Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
                 if a checkpoint callback is configured.
 
@@ -899,6 +926,8 @@ def _predict_impl(
         self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)
 
         assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
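Taken together, the new `model_registry` flag and the `"registry"` ckpt_path keyword enable an end-to-end flow like the one below (a sketch under assumptions: `litmodels` is installed, and `model`/`train_loader` stand in for a real LightningModule and dataloader):

```python
# Illustrative usage only; model and dataloader names are placeholders.
import lightning.pytorch as pl

trainer = pl.Trainer(
    max_epochs=2,
    enable_checkpointing=True,
    # a ":version" suffix would be stripped on init, per the __init__ change above
    model_registry="my-model",
)
# with litmodels installed, checkpointing goes through LitModelCheckpoint and
# checkpoints are uploaded to the Lightning Model Registry under "my-model"
trainer.fit(model, train_loader)

# later, resume from a specific registry version instead of a local path
trainer.fit(model, train_loader, ckpt_path="registry:my-model:version:v2")
```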
