Skip to content

Commit 1c149bd

Browse files
committed
predict validate tests
1 parent 4133bba commit 1c149bd

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

src/lightning/pytorch/trainer/trainer.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -578,9 +578,6 @@ def _fit_impl(
578578
) -> None:
579579
log.debug(f"{self.__class__.__name__}: trainer fit stage")
580580

581-
if _is_registry(ckpt_path) and module_available("litmodels"):
582-
download_model_from_registry(ckpt_path, self)
583-
584581
# if a datamodule comes in as the second arg, then fix it for the user
585582
if isinstance(train_dataloaders, LightningDataModule):
586583
datamodule = train_dataloaders
@@ -597,6 +594,8 @@ def _fit_impl(
597594
)
598595

599596
assert self.state.fn is not None
597+
if _is_registry(ckpt_path) and module_available("litmodels"):
598+
download_model_from_registry(ckpt_path, self)
600599
ckpt_path = self._checkpoint_connector._select_ckpt_path(
601600
self.state.fn,
602601
ckpt_path,
@@ -626,8 +625,8 @@ def validate(
626625
Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
627626
the :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook.
628627
629-
ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to validate.
630-
If ``None`` and the model instance was passed, use the current weights.
628+
ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
629+
to validate. If ``None`` and the model instance was passed, use the current weights.
631630
Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
632631
if a checkpoint callback is configured.
633632
@@ -705,6 +704,8 @@ def _validate_impl(
705704
self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)
706705

707706
assert self.state.fn is not None
707+
if _is_registry(ckpt_path) and module_available("litmodels"):
708+
download_model_from_registry(ckpt_path, self)
708709
ckpt_path = self._checkpoint_connector._select_ckpt_path(
709710
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
710711
)
@@ -735,8 +736,8 @@ def test(
735736
Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
736737
the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook.
737738
738-
ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to test.
739-
If ``None`` and the model instance was passed, use the current weights.
739+
ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
740+
to test. If ``None`` and the model instance was passed, use the current weights.
740741
Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
741742
if a checkpoint callback is configured.
742743
@@ -814,6 +815,8 @@ def _test_impl(
814815
self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)
815816

816817
assert self.state.fn is not None
818+
if _is_registry(ckpt_path) and module_available("litmodels"):
819+
download_model_from_registry(ckpt_path, self)
817820
ckpt_path = self._checkpoint_connector._select_ckpt_path(
818821
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
819822
)
@@ -850,8 +853,8 @@ def predict(
850853
return_predictions: Whether to return predictions.
851854
``True`` by default except when an accelerator that spawns processes is used (not supported).
852855
853-
ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to predict.
854-
If ``None`` and the model instance was passed, use the current weights.
856+
ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
857+
to predict. If ``None`` and the model instance was passed, use the current weights.
855858
Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
856859
if a checkpoint callback is configured.
857860
@@ -923,6 +926,8 @@ def _predict_impl(
923926
self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)
924927

925928
assert self.state.fn is not None
929+
if _is_registry(ckpt_path) and module_available("litmodels"):
930+
download_model_from_registry(ckpt_path, self)
926931
ckpt_path = self._checkpoint_connector._select_ckpt_path(
927932
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
928933
)

0 commit comments

Comments
 (0)