@@ -578,9 +578,6 @@ def _fit_impl(
578578 ) -> None :
579579 log .debug (f"{ self .__class__ .__name__ } : trainer fit stage" )
580580
581- if _is_registry (ckpt_path ) and module_available ("litmodels" ):
582- download_model_from_registry (ckpt_path , self )
583-
584581 # if a datamodule comes in as the second arg, then fix it for the user
585582 if isinstance (train_dataloaders , LightningDataModule ):
586583 datamodule = train_dataloaders
@@ -597,6 +594,8 @@ def _fit_impl(
597594 )
598595
599596 assert self .state .fn is not None
597+ if _is_registry (ckpt_path ) and module_available ("litmodels" ):
598+ download_model_from_registry (ckpt_path , self )
600599 ckpt_path = self ._checkpoint_connector ._select_ckpt_path (
601600 self .state .fn ,
602601 ckpt_path ,
@@ -626,8 +625,8 @@ def validate(
626625 Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
627626 the :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook.
628627
629- ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to validate.
630- If ``None`` and the model instance was passed, use the current weights.
628+ ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
629+ to validate. If ``None`` and the model instance was passed, use the current weights.
631630 Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
632631 if a checkpoint callback is configured.
633632
@@ -705,6 +704,8 @@ def _validate_impl(
705704 self ._data_connector .attach_data (model , val_dataloaders = dataloaders , datamodule = datamodule )
706705
707706 assert self .state .fn is not None
707+ if _is_registry (ckpt_path ) and module_available ("litmodels" ):
708+ download_model_from_registry (ckpt_path , self )
708709 ckpt_path = self ._checkpoint_connector ._select_ckpt_path (
709710 self .state .fn , ckpt_path , model_provided = model_provided , model_connected = self .lightning_module is not None
710711 )
@@ -735,8 +736,8 @@ def test(
735736 Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
736737 the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook.
737738
738- ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to test.
739- If ``None`` and the model instance was passed, use the current weights.
739+ ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
740+ to test. If ``None`` and the model instance was passed, use the current weights.
740741 Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
741742 if a checkpoint callback is configured.
742743
@@ -814,6 +815,8 @@ def _test_impl(
814815 self ._data_connector .attach_data (model , test_dataloaders = dataloaders , datamodule = datamodule )
815816
816817 assert self .state .fn is not None
818+ if _is_registry (ckpt_path ) and module_available ("litmodels" ):
819+ download_model_from_registry (ckpt_path , self )
817820 ckpt_path = self ._checkpoint_connector ._select_ckpt_path (
818821 self .state .fn , ckpt_path , model_provided = model_provided , model_connected = self .lightning_module is not None
819822 )
@@ -850,8 +853,8 @@ def predict(
850853 return_predictions: Whether to return predictions.
851854 ``True`` by default except when an accelerator that spawns processes is used (not supported).
852855
853- ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to predict.
854- If ``None`` and the model instance was passed, use the current weights.
856+ ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
857+ to predict. If ``None`` and the model instance was passed, use the current weights.
855858 Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
856859 if a checkpoint callback is configured.
857860
@@ -923,6 +926,8 @@ def _predict_impl(
923926 self ._data_connector .attach_data (model , predict_dataloaders = dataloaders , datamodule = datamodule )
924927
925928 assert self .state .fn is not None
929+ if _is_registry (ckpt_path ) and module_available ("litmodels" ):
930+ download_model_from_registry (ckpt_path , self )
926931 ckpt_path = self ._checkpoint_connector ._select_ckpt_path (
927932 self .state .fn , ckpt_path , model_provided = model_provided , model_connected = self .lightning_module is not None
928933 )
0 commit comments