
Commit 91b147a

Bordalexierule authored and committed
update checkpointing (#20653)
1 parent de7f462 commit 91b147a

File tree

6 files changed: +250 additions, -14 deletions

src/lightning/pytorch/trainer/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""""""

 from lightning.fabric.utilities.seed import seed_everything
 from lightning.pytorch.trainer.trainer import Trainer

src/lightning/pytorch/trainer/connectors/callback_connector.py

Lines changed: 21 additions & 2 deletions
@@ -11,13 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import inspect
 import logging
 import os
 from collections.abc import Sequence
 from datetime import timedelta
 from typing import Optional, Union

+from lightning_utilities import module_available
+
 import lightning.pytorch as pl
 from lightning.fabric.utilities.registry import _load_external_callbacks
 from lightning.pytorch.callbacks import (
@@ -91,7 +93,24 @@ def _configure_checkpoint_callbacks(self, enable_checkpointing: bool) -> None:
                 " but found `ModelCheckpoint` in callbacks list."
             )
         elif enable_checkpointing:
-            self.trainer.callbacks.append(ModelCheckpoint())
+            if module_available("litmodels") and self.trainer._model_registry:
+                trainer_source = inspect.getmodule(self.trainer)
+                if trainer_source is None or not isinstance(trainer_source.__package__, str):
+                    raise RuntimeError("Unable to determine the source of the trainer.")
+                # this needs to be imported based on the actual package: lightning / pytorch_lightning
+                if "pytorch_lightning" in trainer_source.__package__:
+                    from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint
+                else:
+                    from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint
+
+                model_checkpoint = LitModelCheckpoint(model_name=self.trainer._model_registry)
+            else:
+                rank_zero_info(
+                    "You are using the plain ModelCheckpoint callback."
+                    " Consider using LitModelCheckpoint, which offers seamless uploading to the Model registry."
+                )
+                model_checkpoint = ModelCheckpoint()
+            self.trainer.callbacks.append(model_checkpoint)

     def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
         if not enable_model_summary:
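For context, a minimal sketch (not part of the commit) of how this branch is exercised through the public API: with litmodels installed and model_registry set, the connector appends a LitModelCheckpoint; otherwise it falls back to the plain ModelCheckpoint. The registry name below is an illustrative placeholder.

import lightning.pytorch as pl

# The callback connector runs during Trainer construction, so the effect of
# _configure_checkpoint_callbacks is visible right after __init__.
trainer = pl.Trainer(enable_checkpointing=True, model_registry="my-model")

# LitModelCheckpoint if litmodels is installed, plain ModelCheckpoint otherwise.
print(type(trainer.checkpoint_callback).__name__)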

src/lightning/pytorch/trainer/connectors/checkpoint_connector.py

Lines changed: 14 additions & 3 deletions
@@ -19,6 +19,7 @@
 import torch
 from fsspec.core import url_to_fs
 from fsspec.implementations.local import LocalFileSystem
+from lightning_utilities import module_available
 from torch import Tensor

 import lightning.pytorch as pl
@@ -33,6 +34,10 @@
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
 from lightning.pytorch.utilities.migration import pl_legacy_patch
 from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
+from lightning.pytorch.utilities.model_registry import (
+    _is_registry,
+    find_model_local_ckpt_path,
+)
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn

 log = logging.getLogger(__name__)
@@ -48,8 +53,7 @@ def __init__(self, trainer: "pl.Trainer") -> None:

     @property
     def _hpc_resume_path(self) -> Optional[str]:
-        dir_path_hpc = self.trainer.default_root_dir
-        dir_path_hpc = str(dir_path_hpc)
+        dir_path_hpc = str(self.trainer.default_root_dir)
         fs, path = url_to_fs(dir_path_hpc)
         if not _is_dir(fs, path):
             return None
@@ -194,10 +198,17 @@ def _parse_ckpt_path(
             if not self._hpc_resume_path:
                 raise ValueError(
                     f'`.{fn}(ckpt_path="hpc")` is set but no HPC checkpoint was found.'
-                    " Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`"
+                    f" Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`"
                 )
            ckpt_path = self._hpc_resume_path

+        elif _is_registry(ckpt_path) and module_available("litmodels"):
+            ckpt_path = find_model_local_ckpt_path(
+                ckpt_path,
+                default_model_registry=self.trainer._model_registry,
+                default_root_dir=self.trainer.default_root_dir,
+            )
+
         if not ckpt_path:
             raise ValueError(
                 f"`.{fn}()` found no path for the best weights: {ckpt_path!r}. Please"

src/lightning/pytorch/trainer/trainer.py

Lines changed: 36 additions & 7 deletions
@@ -30,6 +30,7 @@
 from weakref import proxy

 import torch
+from lightning_utilities import module_available
 from torch.optim import Optimizer

 import lightning.pytorch as pl
@@ -70,6 +71,7 @@
 from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized, _verify_strategy_supports_compile
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.model_helpers import is_overridden
+from lightning.pytorch.utilities.model_registry import _is_registry, download_model_from_registry
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.pytorch.utilities.seed import isolate_rng
 from lightning.pytorch.utilities.types import (
@@ -128,6 +130,7 @@ def __init__(
         sync_batchnorm: bool = False,
         reload_dataloaders_every_n_epochs: int = 0,
         default_root_dir: Optional[_PATH] = None,
+        model_registry: Optional[str] = None,
     ) -> None:
         r"""Customize every aspect of training via flags.

@@ -290,6 +293,8 @@ def __init__(
                 Default: ``os.getcwd()``.
                 Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'

+            model_registry: The name of the model to be uploaded to the Model hub.
+
         Raises:
             TypeError:
                 If ``gradient_clip_val`` is not an int or float.
@@ -304,6 +309,9 @@ def __init__(
         if default_root_dir is not None:
             default_root_dir = os.fspath(default_root_dir)

+        # remove version if accidentally passed
+        self._model_registry = model_registry.split(":")[0] if model_registry else None
+
         self.barebones = barebones
         if barebones:
             # opt-outs
@@ -519,7 +527,20 @@ def fit(
                 the :class:`~lightning.pytorch.core.hooks.DataHooks.train_dataloader` hook.

             ckpt_path: Path/URL of the checkpoint from which training is resumed. Could also be one of two special
-                keywords ``"last"`` and ``"hpc"``. If there is no checkpoint file at the path, an exception is raised.
+                keywords ``"last"``, ``"hpc"`` and ``"registry"``.
+                Otherwise, if there is no checkpoint file at the path, an exception is raised.
+
+                - best: the best model checkpoint from the previous ``trainer.fit`` call will be loaded
+                - last: the last model checkpoint from the previous ``trainer.fit`` call will be loaded
+                - registry: the model will be downloaded from the Lightning Model Registry with the following notations:
+
+                    - ``'registry'``: uses the latest/default version of the default model set
+                      with ``Trainer(..., model_registry="my-model")``
+                    - ``'registry:model-name'``: uses the latest/default version of the model `model-name`
+                    - ``'registry:model-name:version:v2'``: uses the specific version 'v2' of the model `model-name`
+                    - ``'registry:version:v2'``: uses the default model set
+                      with ``Trainer(..., model_registry="my-model")`` and version 'v2'
+

         Raises:
             TypeError:
@@ -567,6 +588,8 @@ def _fit_impl(
         )

         assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn,
             ckpt_path,
@@ -596,8 +619,8 @@ def validate(
             Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
             the :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook.

-            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to validate.
-                If ``None`` and the model instance was passed, use the current weights.
+            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
+                to validate. If ``None`` and the model instance was passed, use the current weights.
                 Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
                 if a checkpoint callback is configured.
@@ -675,6 +698,8 @@ def _validate_impl(
         self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)

         assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
@@ -705,8 +730,8 @@ def test(
             Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
             the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook.

-            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to test.
-                If ``None`` and the model instance was passed, use the current weights.
+            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
+                to test. If ``None`` and the model instance was passed, use the current weights.
                 Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
                 if a checkpoint callback is configured.
@@ -784,6 +809,8 @@ def _test_impl(
         self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)

         assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
@@ -820,8 +847,8 @@ def predict(
             return_predictions: Whether to return predictions.
                 ``True`` by default except when an accelerator that spawns processes is used (not supported).

-            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to predict.
-                If ``None`` and the model instance was passed, use the current weights.
+            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
+                to predict. If ``None`` and the model instance was passed, use the current weights.
                 Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
                 if a checkpoint callback is configured.
@@ -893,6 +920,8 @@ def _predict_impl(
         self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

         assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
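Putting the pieces together, a hedged end-to-end sketch of the new flag and keyword (the module and model names are placeholders, and the registry paths only activate when litmodels is installed):

import torch
import lightning.pytorch as pl


class DemoModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# "my-model" is the default registry entry; checkpoints are uploaded under it.
trainer = pl.Trainer(max_epochs=1, model_registry="my-model")

# Besides a path/URL, ckpt_path now accepts the registry notations:
#   "registry"                        -> default model ("my-model"), latest version
#   "registry:other-model"            -> model "other-model", latest version
#   "registry:other-model:version:v2" -> model "other-model", version v2
#   "registry:version:v2"             -> default model ("my-model"), version v2
# e.g. trainer.fit(DemoModel(), train_dataloader, ckpt_path="registry:version:v2")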
