@@ -30,6 +30,7 @@
 from weakref import proxy
 
 import torch
+from lightning_utilities import module_available
 from torch.optim import Optimizer
 
 import lightning.pytorch as pl
@@ -70,6 +71,7 @@
 from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized, _verify_strategy_supports_compile
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.model_helpers import is_overridden
+from lightning.pytorch.utilities.model_registry import _is_registry, download_model_from_registry
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.pytorch.utilities.seed import isolate_rng
 from lightning.pytorch.utilities.types import (
@@ -129,6 +131,7 @@ def __init__(
         reload_dataloaders_every_n_epochs: int = 0,
         default_root_dir: Optional[_PATH] = None,
         enable_autolog_hparams: bool = True,
+        model_registry: Optional[str] = None,
     ) -> None:
         r"""Customize every aspect of training via flags.
 
@@ -294,6 +297,8 @@ def __init__(
             enable_autolog_hparams: Whether to log hyperparameters at the start of a run.
                 Default: ``True``.
 
+            model_registry: The name of the model being uploaded to the model hub; also the default name used by ``ckpt_path="registry"``.
+
         Raises:
             TypeError:
                 If ``gradient_clip_val`` is not an int or float.
@@ -308,6 +313,9 @@ def __init__(
         if default_root_dir is not None:
             default_root_dir = os.fspath(default_root_dir)
 
+        # remove the version if it was accidentally passed
+        self._model_registry = model_registry.split(":")[0] if model_registry else None
+
         self.barebones = barebones
         if barebones:
             # opt-outs
@@ -525,7 +533,20 @@ def fit(
             the :class:`~lightning.pytorch.core.hooks.DataHooks.train_dataloader` hook.
 
-            ckpt_path: Path/URL of the checkpoint from which training is resumed. Could also be one of two special
-                keywords ``"last"`` and ``"hpc"``. If there is no checkpoint file at the path, an exception is raised.
+            ckpt_path: Path/URL of the checkpoint from which training is resumed. Could also be one of three special
+                keywords ``"last"``, ``"hpc"`` and ``"registry"``.
+                Otherwise, if there is no checkpoint file at the path, an exception is raised.
+
+                - best: the best model checkpoint from the previous ``trainer.fit`` call will be loaded
+                - last: the last model checkpoint from the previous ``trainer.fit`` call will be loaded
+                - registry: the model will be downloaded from the Lightning Model Registry with the following notations:
+
+                  - ``'registry'``: uses the latest/default version of the default model set
+                    with ``Trainer(..., model_registry="my-model")``
+                  - ``'registry:model-name'``: uses the latest/default version of the model ``model-name``
+                  - ``'registry:model-name:version:v2'``: uses the specific version ``v2`` of the model ``model-name``
+                  - ``'registry:version:v2'``: uses version ``v2`` of the default model set
+                    with ``Trainer(..., model_registry="my-model")``
+
 
         Raises:
             TypeError:
@@ -573,6 +594,8 @@ def _fit_impl(
         )
 
         assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn,
             ckpt_path,
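The same two-line guard recurs before each ``_select_ckpt_path`` call in ``_fit_impl``, ``_validate_impl``, ``_test_impl`` and ``_predict_impl``: the download only runs when the path uses the registry notation and the optional ``litmodels`` package is importable, so the feature is skipped without the extra dependency. A minimal sketch of that gate, with ``_is_registry`` re-implemented here purely for illustration (the real helper is imported from ``lightning.pytorch.utilities.model_registry`` and may differ):

# Illustration only: a plausible stand-in for the guard used above.
from typing import Optional

from lightning_utilities import module_available


def _is_registry(ckpt_path: Optional[str]) -> bool:
    # matches "registry", "registry:model-name", "registry:version:v2", ...
    return isinstance(ckpt_path, str) and ckpt_path.lower().startswith("registry")


def maybe_download(ckpt_path: Optional[str]) -> None:
    if _is_registry(ckpt_path) and module_available("litmodels"):
        print(f"would download {ckpt_path!r} via litmodels")  # placeholder action


maybe_download("registry:my-model")  # takes the download branch (if litmodels is installed)
maybe_download("last")               # ignored: not a registry path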
@@ -602,8 +625,8 @@ def validate(
             Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
             the :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook.
 
-            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to validate.
-                If ``None`` and the model instance was passed, use the current weights.
+            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
+                to validate. If ``None`` and the model instance was passed, use the current weights.
                 Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
                 if a checkpoint callback is configured.
 
@@ -681,6 +704,8 @@ def _validate_impl(
         self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)
 
         assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
@@ -711,8 +736,8 @@ def test(
             Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
             the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook.
 
-            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to test.
-                If ``None`` and the model instance was passed, use the current weights.
+            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
+                to test. If ``None`` and the model instance was passed, use the current weights.
                 Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
                 if a checkpoint callback is configured.
 
@@ -790,6 +815,8 @@ def _test_impl(
         self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)
 
         assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
@@ -826,8 +853,8 @@ def predict(
             return_predictions: Whether to return predictions.
                 ``True`` by default except when an accelerator that spawns processes is used (not supported).
 
-            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to predict.
-                If ``None`` and the model instance was passed, use the current weights.
+            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
+                to predict. If ``None`` and the model instance was passed, use the current weights.
                 Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
                 if a checkpoint callback is configured.
 
@@ -899,6 +926,8 @@ def _predict_impl(
         self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)
 
         assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )