Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/litmodels/integrations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@
__all__ += ["LightningModelCheckpoint"]

if _PYTORCHLIGHTNING_AVAILABLE:
from litmodels.integrations.checkpoints import PTLightningModelCheckpoint
from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint

__all__ += ["PTLightningModelCheckpoint"]
__all__ += ["PytorchLightningModelCheckpoint"]
2 changes: 1 addition & 1 deletion src/litmodels/integrations/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _save_checkpoint(self, trainer: "Trainer", filepath: str) -> None:

if _PYTORCHLIGHTNING_AVAILABLE:

class PTLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightningModelCheckpoint):
class PytorchLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightningModelCheckpoint):
"""PyTorch Lightning ModelCheckpoint with LitModel support.

Args:
Expand Down
6 changes: 5 additions & 1 deletion src/litmodels/integrations/imports.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from lightning_utilities import module_available
import operator

from lightning_utilities import compare_version, module_available

_LIGHTNING_AVAILABLE = module_available("lightning")
_LIGHTNING_GREATER_EQUAL_2_5_1 = compare_version("lightning", operator.ge, "2.5.1")
_PYTORCHLIGHTNING_AVAILABLE = module_available("pytorch_lightning")
_PYTORCHLIGHTNING_GREATER_EQUAL_2_5_1 = compare_version("pytorch_lightning", operator.ge, "2.5.1")
13 changes: 12 additions & 1 deletion tests/integrations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import pytest
from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE
from litmodels.integrations.imports import (
_LIGHTNING_AVAILABLE,
_LIGHTNING_GREATER_EQUAL_2_5_1,
_PYTORCHLIGHTNING_AVAILABLE,
_PYTORCHLIGHTNING_GREATER_EQUAL_2_5_1,
)

_SKIP_IF_LIGHTNING_MISSING = pytest.mark.skipif(not _LIGHTNING_AVAILABLE, reason="Lightning not available")
_SKIP_IF_LIGHTNING_BELLOW_2_5_1 = pytest.mark.skipif(
not _LIGHTNING_GREATER_EQUAL_2_5_1, reason="Lightning without integration introduced in 2.5.1"
)
_SKIP_IF_PYTORCHLIGHTNING_MISSING = pytest.mark.skipif(
not _PYTORCHLIGHTNING_AVAILABLE, reason="PyTorch Lightning not available"
)
_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1 = pytest.mark.skipif(
not _PYTORCHLIGHTNING_GREATER_EQUAL_2_5_1, reason="PyTorch Lightning without integration introduced in 2.5.1"
)
4 changes: 2 additions & 2 deletions tests/integrations/test_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, importing,
from lightning.pytorch.demos.boring_classes import BoringModel
from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint
elif importing == "pytorch_lightning":
from litmodels.integrations.checkpoints import PTLightningModelCheckpoint as LitModelCheckpoint
from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel
Expand Down Expand Up @@ -64,7 +64,7 @@ def test_lightning_checkpointing_pickleable(mock_auth, importing):
if importing == "lightning":
from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint
elif importing == "pytorch_lightning":
from litmodels.integrations.checkpoints import PTLightningModelCheckpoint as LitModelCheckpoint
from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint

ckpt = LitModelCheckpoint(model_name="org-name/teamspace/model-name")
pickle.dumps(ckpt)
56 changes: 49 additions & 7 deletions tests/integrations/test_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,29 @@
from io import StringIO

import pytest
from lightning_sdk import Teamspace
from lightning_sdk.lightning_cloud.rest_client import GridRestClient
from lightning_sdk.utils.resolve import _resolve_teamspace
from litmodels import download_model, upload_model

from tests.integrations import _SKIP_IF_LIGHTNING_BELLOW_2_5_1, _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1

LIT_ORG = "lightning-ai"
LIT_TEAMSPACE = "LitModels"


def _cleanup_model(teamspace: Teamspace, model_name: str) -> None:
"""Cleanup model from the teamspace."""
client = GridRestClient()
# cleaning created models as each test run shall have unique model name
model = client.models_store_get_model_by_name(
project_owner_name=teamspace.owner.name,
project_name=teamspace.name,
model_name=model_name,
)
client.models_store_delete_model(project_id=teamspace.id, model_id=model.id)


@pytest.mark.cloud()
def test_upload_download_model(tmp_path):
"""Verify that the model is uploaded to the teamspace"""
Expand Down Expand Up @@ -41,11 +56,38 @@ def test_upload_download_model(tmp_path):
for file in model_files:
assert os.path.isfile(os.path.join(tmp_path, file))

client = GridRestClient()
# cleaning created models with todo: also consider how to delete just this version of the model
model = client.models_store_get_model_by_name(
project_owner_name=teamspace.owner.name,
project_name=teamspace.name,
model_name=model_name,
# CLEANING
_cleanup_model(teamspace, model_name)


@pytest.mark.parametrize(
"importing",
[
pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_BELLOW_2_5_1),
pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1),
],
)
@pytest.mark.cloud()
# todo: mock env variables as it would run in studio
def test_lightning_default_checkpointing(importing, tmp_path):
if importing == "lightning":
from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
elif importing == "pytorch_lightning":
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

# model name with random hash
model_name = f"litmodels_test_integrations+{os.urandom(8).hex()}"
teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None)
org_team = f"{teamspace.owner.name}/{teamspace.name}"

trainer = Trainer(
max_epochs=2,
default_root_dir=tmp_path,
model_registry=f"{org_team}/{model_name}",
)
client.models_store_delete_model(project_id=teamspace.id, model_id=model.id)
trainer.fit(BoringModel())

# CLEANING
_cleanup_model(teamspace, model_name)
Loading