diff --git a/src/litmodels/integrations/__init__.py b/src/litmodels/integrations/__init__.py index c0bbf06..fa18f1b 100644 --- a/src/litmodels/integrations/__init__.py +++ b/src/litmodels/integrations/__init__.py @@ -10,6 +10,6 @@ __all__ += ["LightningModelCheckpoint"] if _PYTORCHLIGHTNING_AVAILABLE: - from litmodels.integrations.checkpoints import PTLightningModelCheckpoint + from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint - __all__ += ["PTLightningModelCheckpoint"] + __all__ += ["PytorchLightningModelCheckpoint"] diff --git a/src/litmodels/integrations/checkpoints.py b/src/litmodels/integrations/checkpoints.py index 9b5aebb..2d66e3d 100644 --- a/src/litmodels/integrations/checkpoints.py +++ b/src/litmodels/integrations/checkpoints.py @@ -67,7 +67,7 @@ def _save_checkpoint(self, trainer: "Trainer", filepath: str) -> None: if _PYTORCHLIGHTNING_AVAILABLE: - class PTLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightningModelCheckpoint): + class PytorchLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightningModelCheckpoint): """PyTorch Lightning ModelCheckpoint with LitModel support. Args: diff --git a/src/litmodels/integrations/imports.py b/src/litmodels/integrations/imports.py index ddefb18..a9a0e2c 100644 --- a/src/litmodels/integrations/imports.py +++ b/src/litmodels/integrations/imports.py @@ -1,4 +1,8 @@ -from lightning_utilities import module_available +import operator + +from lightning_utilities import compare_version, module_available _LIGHTNING_AVAILABLE = module_available("lightning") +_LIGHTNING_GREATER_EQUAL_2_5_1 = compare_version("lightning", operator.ge, "2.5.1") _PYTORCHLIGHTNING_AVAILABLE = module_available("pytorch_lightning") +_PYTORCHLIGHTNING_GREATER_EQUAL_2_5_1 = compare_version("pytorch_lightning", operator.ge, "2.5.1") diff --git a/tests/integrations/__init__.py b/tests/integrations/__init__.py index c87963b..82737f0 100644 --- a/tests/integrations/__init__.py +++ b/tests/integrations/__init__.py @@ -1,7 +1,18 @@ import pytest -from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE +from litmodels.integrations.imports import ( + _LIGHTNING_AVAILABLE, + _LIGHTNING_GREATER_EQUAL_2_5_1, + _PYTORCHLIGHTNING_AVAILABLE, + _PYTORCHLIGHTNING_GREATER_EQUAL_2_5_1, +) _SKIP_IF_LIGHTNING_MISSING = pytest.mark.skipif(not _LIGHTNING_AVAILABLE, reason="Lightning not available") +_SKIP_IF_LIGHTNING_BELLOW_2_5_1 = pytest.mark.skipif( + not _LIGHTNING_GREATER_EQUAL_2_5_1, reason="Lightning without integration introduced in 2.5.1" +) _SKIP_IF_PYTORCHLIGHTNING_MISSING = pytest.mark.skipif( not _PYTORCHLIGHTNING_AVAILABLE, reason="PyTorch Lightning not available" ) +_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1 = pytest.mark.skipif( + not _PYTORCHLIGHTNING_GREATER_EQUAL_2_5_1, reason="PyTorch Lightning without integration introduced in 2.5.1" +) diff --git a/tests/integrations/test_checkpoints.py b/tests/integrations/test_checkpoints.py index 5a3a100..f19ad57 100644 --- a/tests/integrations/test_checkpoints.py +++ b/tests/integrations/test_checkpoints.py @@ -23,7 +23,7 @@ def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, importing, from lightning.pytorch.demos.boring_classes import BoringModel from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint elif importing == "pytorch_lightning": - from litmodels.integrations.checkpoints import PTLightningModelCheckpoint as LitModelCheckpoint + from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringModel @@ -64,7 +64,7 @@ def test_lightning_checkpointing_pickleable(mock_auth, importing): if importing == "lightning": from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint elif importing == "pytorch_lightning": - from litmodels.integrations.checkpoints import PTLightningModelCheckpoint as LitModelCheckpoint + from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint ckpt = LitModelCheckpoint(model_name="org-name/teamspace/model-name") pickle.dumps(ckpt) diff --git a/tests/integrations/test_cloud.py b/tests/integrations/test_cloud.py index 01deafb..d0600e6 100644 --- a/tests/integrations/test_cloud.py +++ b/tests/integrations/test_cloud.py @@ -3,14 +3,29 @@ from io import StringIO import pytest +from lightning_sdk import Teamspace from lightning_sdk.lightning_cloud.rest_client import GridRestClient from lightning_sdk.utils.resolve import _resolve_teamspace from litmodels import download_model, upload_model +from tests.integrations import _SKIP_IF_LIGHTNING_BELLOW_2_5_1, _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1 + LIT_ORG = "lightning-ai" LIT_TEAMSPACE = "LitModels" +def _cleanup_model(teamspace: Teamspace, model_name: str) -> None: + """Cleanup model from the teamspace.""" + client = GridRestClient() + # cleaning created models as each test run shall have unique model name + model = client.models_store_get_model_by_name( + project_owner_name=teamspace.owner.name, + project_name=teamspace.name, + model_name=model_name, + ) + client.models_store_delete_model(project_id=teamspace.id, model_id=model.id) + + @pytest.mark.cloud() def test_upload_download_model(tmp_path): """Verify that the model is uploaded to the teamspace""" @@ -41,11 +56,38 @@ def test_upload_download_model(tmp_path): for file in model_files: assert os.path.isfile(os.path.join(tmp_path, file)) - client = GridRestClient() - # cleaning created models with todo: also consider how to delete just this version of the model - model = client.models_store_get_model_by_name( - project_owner_name=teamspace.owner.name, - project_name=teamspace.name, - model_name=model_name, + # CLEANING + _cleanup_model(teamspace, model_name) + + +@pytest.mark.parametrize( + "importing", + [ + pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_BELLOW_2_5_1), + pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1), + ], +) +@pytest.mark.cloud() +# todo: mock env variables as it would run in studio +def test_lightning_default_checkpointing(importing, tmp_path): + if importing == "lightning": + from lightning import Trainer + from lightning.pytorch.demos.boring_classes import BoringModel + elif importing == "pytorch_lightning": + from pytorch_lightning import Trainer + from pytorch_lightning.demos.boring_classes import BoringModel + + # model name with random hash + model_name = f"litmodels_test_integrations+{os.urandom(8).hex()}" + teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None) + org_team = f"{teamspace.owner.name}/{teamspace.name}" + + trainer = Trainer( + max_epochs=2, + default_root_dir=tmp_path, + model_registry=f"{org_team}/{model_name}", ) - client.models_store_delete_model(project_id=teamspace.id, model_id=model.id) + trainer.fit(BoringModel()) + + # CLEANING + _cleanup_model(teamspace, model_name)