Skip to content

Commit 3183ae3

Browse files
test: validate scenario being in studio (#90)
* test: validate scenario being in studio * mocking --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1ac9424 commit 3183ae3

File tree

2 files changed

+38
-12
lines changed

2 files changed

+38
-12
lines changed

tests/test_io_cloud.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from contextlib import nullcontext
23
from unittest import mock
34

45
import joblib
@@ -12,21 +13,46 @@
1213
from litmodels import download_model, load_model, upload_model
1314
from litmodels.io import upload_model_files
1415
from litmodels.io.utils import _KERAS_AVAILABLE
15-
16-
17-
@pytest.mark.parametrize(
18-
"name", ["/too/many/slashes", "org/model"]
19-
) # todo: add back "model-name" after next SDk release
20-
def test_upload_wrong_model_name(name):
21-
with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"):
16+
from tests.integrations import LIT_ORG, LIT_TEAMSPACE
17+
18+
19+
@pytest.mark.parametrize("name", ["/too/many/slashes", "org/model", "model-name"])
20+
@pytest.mark.parametrize("in_studio", [True, False])
21+
def test_upload_wrong_model_name(name, in_studio, monkeypatch):
22+
if in_studio:
23+
# mock env variables as it would run in studio
24+
monkeypatch.setenv("LIGHTNING_ORG", LIT_ORG)
25+
monkeypatch.setenv("LIGHTNING_TEAMSPACE", LIT_TEAMSPACE)
26+
monkeypatch.setattr("lightning_sdk.organization.Organization", mock.MagicMock)
27+
monkeypatch.setattr("lightning_sdk.teamspace.Teamspace", mock.MagicMock)
28+
monkeypatch.setattr("lightning_sdk.teamspace.TeamspaceApi", mock.MagicMock)
29+
monkeypatch.setattr("lightning_sdk.models._get_teamspace", mock.MagicMock)
30+
31+
in_studio_only_name = in_studio and name == "model-name"
32+
with (
33+
pytest.raises(ValueError, match=r".*organization/teamspace/model.*")
34+
if not in_studio_only_name
35+
else nullcontext()
36+
):
2237
upload_model_files(path="path/to/checkpoint", name=name)
2338

2439

25-
@pytest.mark.parametrize(
26-
"name", ["/too/many/slashes", "org/model"]
27-
) # todo: add back "model-name" after next SDk release
28-
def test_download_wrong_model_name(name):
29-
with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"):
40+
@pytest.mark.parametrize("name", ["/too/many/slashes", "org/model", "model-name"])
41+
@pytest.mark.parametrize("in_studio", [True, False])
42+
def test_download_wrong_model_name(name, in_studio, monkeypatch):
43+
if in_studio:
44+
# mock env variables as it would run in studio
45+
monkeypatch.setenv("LIGHTNING_ORG", LIT_ORG)
46+
monkeypatch.setenv("LIGHTNING_TEAMSPACE", LIT_TEAMSPACE)
47+
monkeypatch.setattr("lightning_sdk.organization.Organization", mock.MagicMock)
48+
monkeypatch.setattr("lightning_sdk.teamspace.Teamspace", mock.MagicMock)
49+
monkeypatch.setattr("lightning_sdk.models.TeamspaceApi", mock.MagicMock)
50+
in_studio_only_name = in_studio and name == "model-name"
51+
with (
52+
pytest.raises(ValueError, match=r".*organization/teamspace/model.*")
53+
if not in_studio_only_name
54+
else nullcontext()
55+
):
3056
download_model(name=name)
3157

3258

0 commit comments

Comments
 (0)