1919from typing import Optional , Union
2020
2121from lightning_utilities import module_available
22+ from lightning_utilities .core .imports import RequirementCache
2223
2324import lightning .pytorch as pl
2425from lightning .fabric .utilities .registry import _load_external_callbacks
@@ -93,7 +94,7 @@ def _configure_checkpoint_callbacks(self, enable_checkpointing: bool) -> None:
9394 " but found `ModelCheckpoint` in callbacks list."
9495 )
9596 elif enable_checkpointing :
96- if module_available ("litmodels" ) and self .trainer ._model_registry :
97+ if RequirementCache ("litmodels >=0.1.7 " ) and self .trainer ._model_registry :
9798 trainer_source = inspect .getmodule (self .trainer )
9899 if trainer_source is None or not isinstance (trainer_source .__package__ , str ):
99100 raise RuntimeError ("Unable to determine the source of the trainer." )
@@ -103,11 +104,11 @@ def _configure_checkpoint_callbacks(self, enable_checkpointing: bool) -> None:
103104 else :
104105 from litmodels .integrations .checkpoints import LightningModelCheckpoint as LitModelCheckpoint
105106
106- model_checkpoint = LitModelCheckpoint (model_name = self .trainer ._model_registry )
107+ model_checkpoint = LitModelCheckpoint (model_registry = self .trainer ._model_registry )
107108 else :
108109 rank_zero_info (
109110 "You are using the default ModelCheckpoint callback."
110- " Install `litmodels` package to use the `LitModelCheckpoint` instead"
111+ " Install `pip install litmodels` package to use the `LitModelCheckpoint` instead"
111112 " for seamless uploading to the Lightning model registry."
112113 )
113114 model_checkpoint = ModelCheckpoint ()
0 commit comments