|
3 | 3 | from io import StringIO |
4 | 4 |
|
5 | 5 | import pytest |
| 6 | +from lightning_sdk import Teamspace |
6 | 7 | from lightning_sdk.lightning_cloud.rest_client import GridRestClient |
7 | 8 | from lightning_sdk.utils.resolve import _resolve_teamspace |
8 | 9 | from litmodels import download_model, upload_model |
9 | 10 |
|
| 11 | +from tests.integrations import _SKIP_IF_LIGHTNING_BELLOW_2_5_1, _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1 |
| 12 | + |
10 | 13 | LIT_ORG = "lightning-ai" |
11 | 14 | LIT_TEAMSPACE = "LitModels" |
12 | 15 |
|
13 | 16 |
|
| 17 | +def _cleanup_model(teamspace: Teamspace, model_name: str) -> None: |
| 18 | + """Cleanup model from the teamspace.""" |
| 19 | + client = GridRestClient() |
| 20 | + # cleaning created models as each test run shall have unique model name |
| 21 | + model = client.models_store_get_model_by_name( |
| 22 | + project_owner_name=teamspace.owner.name, |
| 23 | + project_name=teamspace.name, |
| 24 | + model_name=model_name, |
| 25 | + ) |
| 26 | + client.models_store_delete_model(project_id=teamspace.id, model_id=model.id) |
| 27 | + |
| 28 | + |
14 | 29 | @pytest.mark.cloud() |
15 | 30 | def test_upload_download_model(tmp_path): |
16 | 31 | """Verify that the model is uploaded to the teamspace""" |
@@ -41,11 +56,38 @@ def test_upload_download_model(tmp_path): |
41 | 56 | for file in model_files: |
42 | 57 | assert os.path.isfile(os.path.join(tmp_path, file)) |
43 | 58 |
|
44 | | - client = GridRestClient() |
45 | | - # cleaning created models with todo: also consider how to delete just this version of the model |
46 | | - model = client.models_store_get_model_by_name( |
47 | | - project_owner_name=teamspace.owner.name, |
48 | | - project_name=teamspace.name, |
49 | | - model_name=model_name, |
| 59 | + # CLEANING |
| 60 | + _cleanup_model(teamspace, model_name) |
| 61 | + |
| 62 | + |
| 63 | +@pytest.mark.parametrize( |
| 64 | + "importing", |
| 65 | + [ |
| 66 | + pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_BELLOW_2_5_1), |
| 67 | + pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1), |
| 68 | + ], |
| 69 | +) |
| 70 | +@pytest.mark.cloud() |
| 71 | +# todo: mock env variables as it would run in studio |
| 72 | +def test_lightning_default_checkpointing(importing, tmp_path): |
| 73 | + if importing == "lightning": |
| 74 | + from lightning import Trainer |
| 75 | + from lightning.pytorch.demos.boring_classes import BoringModel |
| 76 | + elif importing == "pytorch_lightning": |
| 77 | + from pytorch_lightning import Trainer |
| 78 | + from pytorch_lightning.demos.boring_classes import BoringModel |
| 79 | + |
| 80 | + # model name with random hash |
| 81 | + model_name = f"litmodels_test_integrations+{os.urandom(8).hex()}" |
| 82 | + teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None) |
| 83 | + org_team = f"{teamspace.owner.name}/{teamspace.name}" |
| 84 | + |
| 85 | + trainer = Trainer( |
| 86 | + max_epochs=2, |
| 87 | + default_root_dir=tmp_path, |
| 88 | + model_registry=f"{org_team}/{model_name}", |
50 | 89 | ) |
51 | | - client.models_store_delete_model(project_id=teamspace.id, model_id=model.id) |
| 90 | + trainer.fit(BoringModel()) |
| 91 | + |
| 92 | + # CLEANING |
| 93 | + _cleanup_model(teamspace, model_name) |
0 commit comments