@@ -30,6 +30,7 @@
 from weakref import proxy
 
 import torch
+from lightning_utilities import module_available
 from torch.optim import Optimizer
 
 import lightning.pytorch as pl
@@ -70,6 +71,7 @@
 from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized, _verify_strategy_supports_compile
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.model_helpers import is_overridden
+from lightning.pytorch.utilities.model_registry import _is_registry, download_model_from_registry
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.pytorch.utilities.seed import isolate_rng
 from lightning.pytorch.utilities.types import (
@@ -128,6 +130,7 @@ def __init__(
         sync_batchnorm: bool = False,
         reload_dataloaders_every_n_epochs: int = 0,
         default_root_dir: Optional[_PATH] = None,
+        model_registry: Optional[str] = None,
     ) -> None:
         r"""Customize every aspect of training via flags.
 
@@ -290,6 +293,8 @@ def __init__(
                 Default: ``os.getcwd()``.
                 Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
 
+            model_registry: The name of the model to be uploaded to the model hub.
+
         Raises:
             TypeError:
                 If ``gradient_clip_val`` is not an int or float.
@@ -304,6 +309,9 @@ def __init__(
         if default_root_dir is not None:
             default_root_dir = os.fspath(default_root_dir)
 
+        # remove version if accidentally passed
+        self._model_registry = model_registry.split(":")[0] if model_registry else None
+
         self.barebones = barebones
         if barebones:
             # opt-outs
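
As the comment in the hunk above notes, any ``:version`` suffix is dropped at construction time. A minimal illustration of that stripping (the model name is hypothetical, not taken from the diff):

from lightning.pytorch import Trainer

# Everything after the first ":" is stripped before the name is stored,
# so both spellings configure the same default registry model.
assert "my-model:v2".split(":")[0] == "my-model"

trainer = Trainer(model_registry="my-model:v2")  # stored internally as "my-model"
trainer = Trainer(model_registry="my-model")     # equivalent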
@@ -519,7 +527,20 @@ def fit(
                 the :class:`~lightning.pytorch.core.hooks.DataHooks.train_dataloader` hook.
 
             ckpt_path: Path/URL of the checkpoint from which training is resumed. Could also be one of two special
-                keywords ``"last"`` and ``"hpc"``. If there is no checkpoint file at the path, an exception is raised.
+                keywords ``"last"``, ``"hpc"`` and ``"registry"``.
+                Otherwise, if there is no checkpoint file at the path, an exception is raised.
+
+                - best: the best model checkpoint from the previous ``trainer.fit`` call will be loaded
+                - last: the last model checkpoint from the previous ``trainer.fit`` call will be loaded
+                - registry: the model will be downloaded from the Lightning Model Registry with the following notations:
+
+                  - ``'registry'``: uses the latest/default version of the default model set
+                    with ``Trainer(..., model_registry="my-model")``
+                  - ``'registry:model-name'``: uses the latest/default version of this model `model-name`
+                  - ``'registry:model-name:version:v2'``: uses the specific version 'v2' of the model `model-name`
+                  - ``'registry:version:v2'``: uses the default model set
+                    with ``Trainer(..., model_registry="my-model")`` and version 'v2'
+
 
         Raises:
             TypeError:
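
A hedged usage sketch of the notations listed above (assumes the optional ``litmodels`` package is installed and that a model was previously uploaded under the hypothetical name ``my-model``; ``LitModel`` is a placeholder for any ``LightningModule``):

import lightning.pytorch as pl
from my_project import LitModel  # hypothetical LightningModule

trainer = pl.Trainer(max_epochs=5, model_registry="my-model")

# "registry" resolves to the latest/default version of "my-model":
trainer.fit(LitModel(), ckpt_path="registry")

# or pin an explicit model name and version:
trainer.fit(LitModel(), ckpt_path="registry:my-model:version:v2")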
@@ -567,6 +588,8 @@ def _fit_impl(
         )
 
         assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn,
             ckpt_path,
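
The ``module_available`` check keeps ``litmodels`` an optional dependency: the registry checkpoint is only downloaded when the package is importable, and the path otherwise falls through to the usual checkpoint resolution. A minimal sketch of the same guard pattern:

from lightning_utilities import module_available

# Import an optional integration only when it is installed, so the core
# code path keeps working without the extra dependency.
if module_available("litmodels"):
    import litmodels  # noqa: F401
else:
    litmodels = None  # feature silently disabled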
@@ -596,8 +619,8 @@ def validate(
                 Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
                 the :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook.
 
-            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to validate.
-                If ``None`` and the model instance was passed, use the current weights.
+            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
+                to validate. If ``None`` and the model instance was passed, use the current weights.
                 Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
                 if a checkpoint callback is configured.
 
@@ -675,6 +698,8 @@ def _validate_impl(
         self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)
 
         assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
@@ -705,8 +730,8 @@ def test(
                 Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
                 the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook.
 
-            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to test.
-                If ``None`` and the model instance was passed, use the current weights.
+            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
+                to test. If ``None`` and the model instance was passed, use the current weights.
                 Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
                 if a checkpoint callback is configured.
 
@@ -784,6 +809,8 @@ def _test_impl(
         self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)
 
         assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
@@ -820,8 +847,8 @@ def predict(
             return_predictions: Whether to return predictions.
                 ``True`` by default except when an accelerator that spawns processes is used (not supported).
 
-            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to predict.
-                If ``None`` and the model instance was passed, use the current weights.
+            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
+                to predict. If ``None`` and the model instance was passed, use the current weights.
                 Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
                 if a checkpoint callback is configured.
 
@@ -893,6 +920,8 @@ def _predict_impl(
         self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)
 
         assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
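
The same guard is repeated in every entry point, so validation, testing, and prediction all accept the registry notation as well; a hedged sketch continuing the earlier example (``LitModel`` remains hypothetical):

# Each call downloads the registry checkpoint first when "litmodels" is installed:
trainer.validate(LitModel(), ckpt_path="registry")
trainer.test(LitModel(), ckpt_path="registry:my-model")
trainer.predict(LitModel(), ckpt_path="registry:my-model:version:v2")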