Skip to content

Commit 7aa2912

Browse files
authored
test: integration Trainer's default ckpt with Prod (#59)
* integration Trainer's default ckpt with Prod * tmp & _cleanup_model * `PytorchLightningModelCheckpoint`
1 parent 54885c6 commit 7aa2912

File tree

6 files changed

+71
-14
lines changed

6 files changed

+71
-14
lines changed

src/litmodels/integrations/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@
1010
__all__ += ["LightningModelCheckpoint"]
1111

1212
if _PYTORCHLIGHTNING_AVAILABLE:
13-
from litmodels.integrations.checkpoints import PTLightningModelCheckpoint
13+
from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint
1414

15-
__all__ += ["PTLightningModelCheckpoint"]
15+
__all__ += ["PytorchLightningModelCheckpoint"]

src/litmodels/integrations/checkpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def _save_checkpoint(self, trainer: "Trainer", filepath: str) -> None:
6767

6868
if _PYTORCHLIGHTNING_AVAILABLE:
6969

70-
class PTLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightningModelCheckpoint):
70+
class PytorchLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightningModelCheckpoint):
7171
"""PyTorch Lightning ModelCheckpoint with LitModel support.
7272
7373
Args:
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
from lightning_utilities import module_available
1+
import operator
2+
3+
from lightning_utilities import compare_version, module_available
24

35
_LIGHTNING_AVAILABLE = module_available("lightning")
6+
_LIGHTNING_GREATER_EQUAL_2_5_1 = compare_version("lightning", operator.ge, "2.5.1")
47
_PYTORCHLIGHTNING_AVAILABLE = module_available("pytorch_lightning")
8+
_PYTORCHLIGHTNING_GREATER_EQUAL_2_5_1 = compare_version("pytorch_lightning", operator.ge, "2.5.1")

tests/integrations/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
11
import pytest
2-
from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE
2+
from litmodels.integrations.imports import (
3+
_LIGHTNING_AVAILABLE,
4+
_LIGHTNING_GREATER_EQUAL_2_5_1,
5+
_PYTORCHLIGHTNING_AVAILABLE,
6+
_PYTORCHLIGHTNING_GREATER_EQUAL_2_5_1,
7+
)
38

49
_SKIP_IF_LIGHTNING_MISSING = pytest.mark.skipif(not _LIGHTNING_AVAILABLE, reason="Lightning not available")
10+
_SKIP_IF_LIGHTNING_BELLOW_2_5_1 = pytest.mark.skipif(
11+
not _LIGHTNING_GREATER_EQUAL_2_5_1, reason="Lightning without integration introduced in 2.5.1"
12+
)
513
_SKIP_IF_PYTORCHLIGHTNING_MISSING = pytest.mark.skipif(
614
not _PYTORCHLIGHTNING_AVAILABLE, reason="PyTorch Lightning not available"
715
)
16+
_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1 = pytest.mark.skipif(
17+
not _PYTORCHLIGHTNING_GREATER_EQUAL_2_5_1, reason="PyTorch Lightning without integration introduced in 2.5.1"
18+
)

tests/integrations/test_checkpoints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, importing,
2323
from lightning.pytorch.demos.boring_classes import BoringModel
2424
from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint
2525
elif importing == "pytorch_lightning":
26-
from litmodels.integrations.checkpoints import PTLightningModelCheckpoint as LitModelCheckpoint
26+
from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint
2727
from pytorch_lightning import Trainer
2828
from pytorch_lightning.callbacks import ModelCheckpoint
2929
from pytorch_lightning.demos.boring_classes import BoringModel
@@ -64,7 +64,7 @@ def test_lightning_checkpointing_pickleable(mock_auth, importing):
6464
if importing == "lightning":
6565
from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint
6666
elif importing == "pytorch_lightning":
67-
from litmodels.integrations.checkpoints import PTLightningModelCheckpoint as LitModelCheckpoint
67+
from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint
6868

6969
ckpt = LitModelCheckpoint(model_name="org-name/teamspace/model-name")
7070
pickle.dumps(ckpt)

tests/integrations/test_cloud.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,29 @@
33
from io import StringIO
44

55
import pytest
6+
from lightning_sdk import Teamspace
67
from lightning_sdk.lightning_cloud.rest_client import GridRestClient
78
from lightning_sdk.utils.resolve import _resolve_teamspace
89
from litmodels import download_model, upload_model
910

11+
from tests.integrations import _SKIP_IF_LIGHTNING_BELLOW_2_5_1, _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1
12+
1013
LIT_ORG = "lightning-ai"
1114
LIT_TEAMSPACE = "LitModels"
1215

1316

17+
def _cleanup_model(teamspace: Teamspace, model_name: str) -> None:
18+
"""Cleanup model from the teamspace."""
19+
client = GridRestClient()
20+
# cleaning created models as each test run shall have unique model name
21+
model = client.models_store_get_model_by_name(
22+
project_owner_name=teamspace.owner.name,
23+
project_name=teamspace.name,
24+
model_name=model_name,
25+
)
26+
client.models_store_delete_model(project_id=teamspace.id, model_id=model.id)
27+
28+
1429
@pytest.mark.cloud()
1530
def test_upload_download_model(tmp_path):
1631
"""Verify that the model is uploaded to the teamspace"""
@@ -41,11 +56,38 @@ def test_upload_download_model(tmp_path):
4156
for file in model_files:
4257
assert os.path.isfile(os.path.join(tmp_path, file))
4358

44-
client = GridRestClient()
45-
# cleaning created models with todo: also consider how to delete just this version of the model
46-
model = client.models_store_get_model_by_name(
47-
project_owner_name=teamspace.owner.name,
48-
project_name=teamspace.name,
49-
model_name=model_name,
59+
# CLEANING
60+
_cleanup_model(teamspace, model_name)
61+
62+
63+
@pytest.mark.parametrize(
64+
"importing",
65+
[
66+
pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_BELLOW_2_5_1),
67+
pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1),
68+
],
69+
)
70+
@pytest.mark.cloud()
71+
# todo: mock env variables as it would run in studio
72+
def test_lightning_default_checkpointing(importing, tmp_path):
73+
if importing == "lightning":
74+
from lightning import Trainer
75+
from lightning.pytorch.demos.boring_classes import BoringModel
76+
elif importing == "pytorch_lightning":
77+
from pytorch_lightning import Trainer
78+
from pytorch_lightning.demos.boring_classes import BoringModel
79+
80+
# model name with random hash
81+
model_name = f"litmodels_test_integrations+{os.urandom(8).hex()}"
82+
teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None)
83+
org_team = f"{teamspace.owner.name}/{teamspace.name}"
84+
85+
trainer = Trainer(
86+
max_epochs=2,
87+
default_root_dir=tmp_path,
88+
model_registry=f"{org_team}/{model_name}",
5089
)
51-
client.models_store_delete_model(project_id=teamspace.id, model_id=model.id)
90+
trainer.fit(BoringModel())
91+
92+
# CLEANING
93+
_cleanup_model(teamspace, model_name)

0 commit comments

Comments
 (0)