diff --git a/src/lightning/pytorch/trainer/connectors/callback_connector.py b/src/lightning/pytorch/trainer/connectors/callback_connector.py
index ed8c2b7b5ab53..a0aaafa74b1af 100644
--- a/src/lightning/pytorch/trainer/connectors/callback_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/callback_connector.py
@@ -18,7 +18,7 @@
 from datetime import timedelta
 from typing import Optional, Union
 
-from lightning_utilities import module_available
+from lightning_utilities.core.imports import RequirementCache
 
 import lightning.pytorch as pl
 from lightning.fabric.utilities.registry import _load_external_callbacks
@@ -93,7 +93,7 @@ def _configure_checkpoint_callbacks(self, enable_checkpointing: bool) -> None:
                     " but found `ModelCheckpoint` in callbacks list."
                 )
         elif enable_checkpointing:
-            if module_available("litmodels") and self.trainer._model_registry:
+            if RequirementCache("litmodels >=0.1.7") and self.trainer._model_registry:
                 trainer_source = inspect.getmodule(self.trainer)
                 if trainer_source is None or not isinstance(trainer_source.__package__, str):
                     raise RuntimeError("Unable to determine the source of the trainer.")
@@ -103,11 +103,11 @@ def _configure_checkpoint_callbacks(self, enable_checkpointing: bool) -> None:
                 else:
                     from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint
 
-                model_checkpoint = LitModelCheckpoint(model_name=self.trainer._model_registry)
+                model_checkpoint = LitModelCheckpoint(model_registry=self.trainer._model_registry)
             else:
                 rank_zero_info(
                     "You are using the default ModelCheckpoint callback."
-                    " Install `litmodels` package to use the `LitModelCheckpoint` instead"
+                    " Install `pip install litmodels` package to use the `LitModelCheckpoint` instead"
                     " for seamless uploading to the Lightning model registry."
                 )
                 model_checkpoint = ModelCheckpoint()