@@ -30,6 +30,7 @@
from weakref import proxy

import torch
+from lightning_utilities import module_available
from torch.optim import Optimizer

import lightning.pytorch as pl
@@ -70,6 +71,7 @@
from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized, _verify_strategy_supports_compile
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.model_helpers import is_overridden
+from lightning.pytorch.utilities.model_registry import _is_registry, download_model_from_registry
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.pytorch.utilities.seed import isolate_rng
from lightning.pytorch.utilities.types import (
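Note: ``module_available`` keeps ``litmodels`` an optional dependency; it checks importability (via ``importlib.util.find_spec``) without actually importing the module. A minimal sketch of the pattern, with the fallback branch illustrative rather than taken from this diff:

from lightning_utilities import module_available

if module_available("litmodels"):
    import litmodels  # the optional registry integration is installed
else:
    litmodels = None  # registry downloads are silently skipped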
@@ -129,6 +131,7 @@ def __init__(
        reload_dataloaders_every_n_epochs: int = 0,
        default_root_dir: Optional[_PATH] = None,
        enable_autolog_hparams: bool = True,
+        model_registry: Optional[str] = None,
    ) -> None:
        r"""Customize every aspect of training via flags.

@@ -294,6 +297,8 @@ def __init__(
            enable_autolog_hparams: Whether to log hyperparameters at the start of a run.
                Default: ``True``.

+            model_registry: The name of the model being uploaded to the model hub.
+
        Raises:
            TypeError:
                If ``gradient_clip_val`` is not an int or float.
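As the ``fit`` docstring further down spells out, this name also serves as the default model for ``ckpt_path="registry"``. A hedged usage sketch (``my-model`` is a placeholder, not a real registry entry):

import lightning.pytorch as pl

# set "my-model" as the default name for registry-based checkpoints
trainer = pl.Trainer(max_epochs=5, model_registry="my-model")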
@@ -308,6 +313,9 @@ def __init__(
        if default_root_dir is not None:
            default_root_dir = os.fspath(default_root_dir)

+        # remove version if accidentally passed
+        self._model_registry = model_registry.split(":")[0] if model_registry else None
+
        self.barebones = barebones
        if barebones:
            # opt-outs
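The new guard strips a version tag that was passed by mistake, keeping only the model name. Standalone, with illustrative values, the expression behaves like this:

# any ":"-suffixed tag is dropped; a bare name passes through unchanged
assert "my-model:v2".split(":")[0] == "my-model"
assert "my-model".split(":")[0] == "my-model"
# a None model_registry stays None via the `else` branch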
@@ -525,7 +533,20 @@ def fit(
                the :class:`~lightning.pytorch.core.hooks.DataHooks.train_dataloader` hook.

-           ckpt_path: Path/URL of the checkpoint from which training is resumed. Could also be one of two special
-               keywords ``"last"`` and ``"hpc"``. If there is no checkpoint file at the path, an exception is raised.
+           ckpt_path: Path/URL of the checkpoint from which training is resumed. Could also be one of the special
+               keywords ``"last"``, ``"hpc"`` and ``"registry"``.
+               Otherwise, if there is no checkpoint file at the path, an exception is raised.
+
+               - best: the best model checkpoint from the previous ``trainer.fit`` call will be loaded
+               - last: the last model checkpoint from the previous ``trainer.fit`` call will be loaded
+               - registry: the model will be downloaded from the Lightning Model Registry with the following notations:
+
+                 - ``'registry'``: uses the latest/default version of the default model, set
+                   with ``Trainer(..., model_registry="my-model")``
+                 - ``'registry:model-name'``: uses the latest/default version of the model ``model-name``
+                 - ``'registry:model-name:version:v2'``: uses the specific version ``'v2'`` of the model ``model-name``
+                 - ``'registry:version:v2'``: uses version ``'v2'`` of the default model, set
+                   with ``Trainer(..., model_registry="my-model")``
+

        Raises:
            TypeError:
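A hedged end-to-end sketch of the notations documented above (``my-model``, ``LitModel``, and ``datamodule`` are placeholders):

import lightning.pytorch as pl

trainer = pl.Trainer(max_epochs=5, model_registry="my-model")

# latest/default version of the default model ("my-model")
trainer.fit(LitModel(), datamodule=datamodule, ckpt_path="registry")

# pin an explicit model and version instead
trainer.fit(LitModel(), datamodule=datamodule, ckpt_path="registry:my-model:version:v2")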
@@ -573,6 +594,8 @@ def _fit_impl(
        )

        assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
        ckpt_path = self._checkpoint_connector._select_ckpt_path(
            self.state.fn,
            ckpt_path,
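The same two-line guard recurs in ``_validate_impl``, ``_test_impl``, and ``_predict_impl`` below. A plausible sketch of the check it relies on, assuming ``_is_registry`` simply matches the documented ``registry`` prefix (the real helper lives in ``lightning.pytorch.utilities.model_registry`` and may differ):

def _is_registry_sketch(ckpt_path) -> bool:
    # matches "registry", "registry:my-model", "registry:my-model:version:v2", ...
    return isinstance(ckpt_path, str) and ckpt_path.lower().startswith("registry")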
@@ -602,8 +625,8 @@ def validate(
            Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
            the :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook.

-           ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to validate.
-               If ``None`` and the model instance was passed, use the current weights.
+           ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
+               to validate. If ``None`` and the model instance was passed, use the current weights.
                Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
                if a checkpoint callback is configured.

@@ -681,6 +704,8 @@ def _validate_impl(
        self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)

        assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
        ckpt_path = self._checkpoint_connector._select_ckpt_path(
            self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
        )
@@ -711,8 +736,8 @@ def test(
            Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
            the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook.

-           ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to test.
-               If ``None`` and the model instance was passed, use the current weights.
+           ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
+               to test. If ``None`` and the model instance was passed, use the current weights.
                Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
                if a checkpoint callback is configured.

@@ -790,6 +815,8 @@ def _test_impl(
        self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)

        assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
        ckpt_path = self._checkpoint_connector._select_ckpt_path(
            self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
        )
@@ -826,8 +853,8 @@ def predict(
            return_predictions: Whether to return predictions.
                ``True`` by default except when an accelerator that spawns processes is used (not supported).

-           ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to predict.
-               If ``None`` and the model instance was passed, use the current weights.
+           ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
+               to predict. If ``None`` and the model instance was passed, use the current weights.
                Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
                if a checkpoint callback is configured.

@@ -899,6 +926,8 @@ def _predict_impl(
        self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

        assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
        ckpt_path = self._checkpoint_connector._select_ckpt_path(
            self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
        )