diff --git a/examples/train-model-with-lightning-callback.py b/examples/train-model-with-lightning-callback.py index 3bfdf3e..100ce9c 100644 --- a/examples/train-model-with-lightning-callback.py +++ b/examples/train-model-with-lightning-callback.py @@ -5,7 +5,7 @@ import torch.utils.data as data import torchvision as tv from lightning import Trainer -from litmodels.integrations.lightning_checkpoint import LitModelCheckpoint +from litmodels.integrations import LitModelCheckpoint from sample_model import LitAutoEncoder # Define the model name - this should be unique to your model diff --git a/src/litmodels/__about__.py b/src/litmodels/__about__.py index 441411f..5d0b459 100644 --- a/src/litmodels/__about__.py +++ b/src/litmodels/__about__.py @@ -1,4 +1,4 @@ -__version__ = "0.1.0.post0" +__version__ = "0.1.1" __author__ = "Lightning-AI et al." __author_email__ = "community@lightning.ai" __license__ = "Apache-2.0" diff --git a/src/litmodels/integrations/__init__.py b/src/litmodels/integrations/__init__.py index 9765da4..b8a1b71 100644 --- a/src/litmodels/integrations/__init__.py +++ b/src/litmodels/integrations/__init__.py @@ -1 +1,5 @@ """Integrations with training frameworks like PyTorch Lightning, TensorFlow, and others.""" + +from litmodels.integrations.checkpoints import LitModelCheckpoint + +__all__ = ["LitModelCheckpoint"] diff --git a/src/litmodels/integrations/lightning_checkpoint.py b/src/litmodels/integrations/checkpoints.py similarity index 100% rename from src/litmodels/integrations/lightning_checkpoint.py rename to src/litmodels/integrations/checkpoints.py diff --git a/tests/integrations/test_lightning.py b/tests/integrations/test_lightning.py index 549188e..11536c3 100644 --- a/tests/integrations/test_lightning.py +++ b/tests/integrations/test_lightning.py @@ -1,8 +1,8 @@ import re from unittest import mock +from litmodels.integrations.checkpoints import LitModelCheckpoint from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE -from litmodels.integrations.lightning_checkpoint import LitModelCheckpoint if _LIGHTNING_AVAILABLE: from lightning import Trainer @@ -13,7 +13,7 @@ @mock.patch("litmodels.io.cloud.sdk_upload_model") -@mock.patch("litmodels.integrations.lightning_checkpoint.Auth") +@mock.patch("litmodels.integrations.checkpoints.Auth") def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, tmp_path): mock_upload_model.return_value.name = "org-name/teamspace/model-name"