|
3 | 3 | from io import StringIO |
4 | 4 |
|
5 | 5 | import pytest |
| 6 | +from lightning_sdk import Teamspace |
6 | 7 | from lightning_sdk.lightning_cloud.rest_client import GridRestClient |
7 | 8 | from lightning_sdk.utils.resolve import _resolve_teamspace |
8 | 9 | from litmodels import download_model, upload_model |
|
13 | 14 | LIT_TEAMSPACE = "LitModels" |
14 | 15 |
|
15 | 16 |
|
| 17 | +def _cleanup_model(teamspace: Teamspace, model_name: str) -> None: |
| 18 | + """Cleanup model from the teamspace.""" |
| 19 | + client = GridRestClient() |
| 20 | + # cleaning created models as each test run shall have unique model name |
| 21 | + model = client.models_store_get_model_by_name( |
| 22 | + project_owner_name=teamspace.owner.name, |
| 23 | + project_name=teamspace.name, |
| 24 | + model_name=model_name, |
| 25 | + ) |
| 26 | + client.models_store_delete_model(project_id=teamspace.id, model_id=model.id) |
| 27 | + |
| 28 | + |
16 | 29 | @pytest.mark.cloud() |
17 | 30 | def test_upload_download_model(tmp_path): |
18 | 31 | """Verify that the model is uploaded to the teamspace""" |
@@ -44,14 +57,7 @@ def test_upload_download_model(tmp_path): |
44 | 57 | assert os.path.isfile(os.path.join(tmp_path, file)) |
45 | 58 |
|
46 | 59 | # CLEANING |
47 | | - client = GridRestClient() |
48 | | - # cleaning created models as each test run shall have unique model name |
49 | | - model = client.models_store_get_model_by_name( |
50 | | - project_owner_name=teamspace.owner.name, |
51 | | - project_name=teamspace.name, |
52 | | - model_name=model_name, |
53 | | - ) |
54 | | - client.models_store_delete_model(project_id=teamspace.id, model_id=model.id) |
| 60 | + _cleanup_model(teamspace, model_name) |
55 | 61 |
|
56 | 62 |
|
57 | 63 | @pytest.mark.parametrize( |
@@ -84,11 +90,4 @@ def test_lightning_default_checkpointing(importing, tmp_path): |
84 | 90 | trainer.fit(BoringModel()) |
85 | 91 |
|
86 | 92 | # CLEANING |
87 | | - client = GridRestClient() |
88 | | - # cleaning created models as each test run shall have unique model name |
89 | | - model = client.models_store_get_model_by_name( |
90 | | - project_owner_name=teamspace.owner.name, |
91 | | - project_name=teamspace.name, |
92 | | - model_name=model_name, |
93 | | - ) |
94 | | - client.models_store_delete_model(project_id=teamspace.id, model_id=model.id) |
| 93 | + _cleanup_model(teamspace, model_name) |
0 commit comments