Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 38 additions & 12 deletions tests/test_io_cloud.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from contextlib import nullcontext
from unittest import mock

import joblib
Expand All @@ -12,21 +13,46 @@
from litmodels import download_model, load_model, upload_model
from litmodels.io import upload_model_files
from litmodels.io.utils import _KERAS_AVAILABLE


@pytest.mark.parametrize(
"name", ["/too/many/slashes", "org/model"]
) # todo: add back "model-name" after next SDk release
def test_upload_wrong_model_name(name):
with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"):
from tests.integrations import LIT_ORG, LIT_TEAMSPACE


@pytest.mark.parametrize("name", ["/too/many/slashes", "org/model", "model-name"])
@pytest.mark.parametrize("in_studio", [True, False])
def test_upload_wrong_model_name(name, in_studio, monkeypatch):
if in_studio:
# mock env variables as it would run in studio
monkeypatch.setenv("LIGHTNING_ORG", LIT_ORG)
monkeypatch.setenv("LIGHTNING_TEAMSPACE", LIT_TEAMSPACE)
monkeypatch.setattr("lightning_sdk.organization.Organization", mock.MagicMock)
monkeypatch.setattr("lightning_sdk.teamspace.Teamspace", mock.MagicMock)
monkeypatch.setattr("lightning_sdk.teamspace.TeamspaceApi", mock.MagicMock)
monkeypatch.setattr("lightning_sdk.models._get_teamspace", mock.MagicMock)

in_studio_only_name = in_studio and name == "model-name"
with (
pytest.raises(ValueError, match=r".*organization/teamspace/model.*")
if not in_studio_only_name
else nullcontext()
):
upload_model_files(path="path/to/checkpoint", name=name)


@pytest.mark.parametrize(
"name", ["/too/many/slashes", "org/model"]
) # todo: add back "model-name" after next SDk release
def test_download_wrong_model_name(name):
with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"):
@pytest.mark.parametrize("name", ["/too/many/slashes", "org/model", "model-name"])
@pytest.mark.parametrize("in_studio", [True, False])
def test_download_wrong_model_name(name, in_studio, monkeypatch):
if in_studio:
# mock env variables as it would run in studio
monkeypatch.setenv("LIGHTNING_ORG", LIT_ORG)
monkeypatch.setenv("LIGHTNING_TEAMSPACE", LIT_TEAMSPACE)
monkeypatch.setattr("lightning_sdk.organization.Organization", mock.MagicMock)
monkeypatch.setattr("lightning_sdk.teamspace.Teamspace", mock.MagicMock)
monkeypatch.setattr("lightning_sdk.models.TeamspaceApi", mock.MagicMock)
in_studio_only_name = in_studio and name == "model-name"
with (
pytest.raises(ValueError, match=r".*organization/teamspace/model.*")
if not in_studio_only_name
else nullcontext()
):
download_model(name=name)


Expand Down
Loading