|
1 | 1 | import os |
| 2 | +from contextlib import nullcontext |
2 | 3 | from unittest import mock |
3 | 4 |
|
4 | 5 | import joblib |
|
12 | 13 | from litmodels import download_model, load_model, upload_model |
13 | 14 | from litmodels.io import upload_model_files |
14 | 15 | from litmodels.io.utils import _KERAS_AVAILABLE |
15 | | - |
16 | | - |
17 | | -@pytest.mark.parametrize( |
18 | | - "name", ["/too/many/slashes", "org/model"] |
19 | | -) # todo: add back "model-name" after next SDk release |
20 | | -def test_upload_wrong_model_name(name): |
21 | | - with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"): |
| 16 | +from tests.integrations import LIT_ORG, LIT_TEAMSPACE |
| 17 | + |
| 18 | + |
| 19 | +@pytest.mark.parametrize("name", ["/too/many/slashes", "org/model", "model-name"]) |
| 20 | +@pytest.mark.parametrize("in_studio", [True, False]) |
| 21 | +def test_upload_wrong_model_name(name, in_studio, monkeypatch): |
| 22 | + if in_studio: |
| 23 | + # mock env variables as it would run in studio |
| 24 | + monkeypatch.setenv("LIGHTNING_ORG", LIT_ORG) |
| 25 | + monkeypatch.setenv("LIGHTNING_TEAMSPACE", LIT_TEAMSPACE) |
| 26 | + monkeypatch.setattr("lightning_sdk.organization.Organization", mock.MagicMock) |
| 27 | + monkeypatch.setattr("lightning_sdk.teamspace.Teamspace", mock.MagicMock) |
| 28 | + monkeypatch.setattr("lightning_sdk.teamspace.TeamspaceApi", mock.MagicMock) |
| 29 | + monkeypatch.setattr("lightning_sdk.models._get_teamspace", mock.MagicMock) |
| 30 | + |
| 31 | + in_studio_only_name = in_studio and name == "model-name" |
| 32 | + with ( |
| 33 | + pytest.raises(ValueError, match=r".*organization/teamspace/model.*") |
| 34 | + if not in_studio_only_name |
| 35 | + else nullcontext() |
| 36 | + ): |
22 | 37 | upload_model_files(path="path/to/checkpoint", name=name) |
23 | 38 |
|
24 | 39 |
|
25 | | -@pytest.mark.parametrize( |
26 | | - "name", ["/too/many/slashes", "org/model"] |
27 | | -) # todo: add back "model-name" after next SDk release |
28 | | -def test_download_wrong_model_name(name): |
29 | | - with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"): |
| 40 | +@pytest.mark.parametrize("name", ["/too/many/slashes", "org/model", "model-name"]) |
| 41 | +@pytest.mark.parametrize("in_studio", [True, False]) |
| 42 | +def test_download_wrong_model_name(name, in_studio, monkeypatch): |
| 43 | + if in_studio: |
| 44 | + # mock env variables as it would run in studio |
| 45 | + monkeypatch.setenv("LIGHTNING_ORG", LIT_ORG) |
| 46 | + monkeypatch.setenv("LIGHTNING_TEAMSPACE", LIT_TEAMSPACE) |
| 47 | + monkeypatch.setattr("lightning_sdk.organization.Organization", mock.MagicMock) |
| 48 | + monkeypatch.setattr("lightning_sdk.teamspace.Teamspace", mock.MagicMock) |
| 49 | + monkeypatch.setattr("lightning_sdk.models.TeamspaceApi", mock.MagicMock) |
| 50 | + in_studio_only_name = in_studio and name == "model-name" |
| 51 | + with ( |
| 52 | + pytest.raises(ValueError, match=r".*organization/teamspace/model.*") |
| 53 | + if not in_studio_only_name |
| 54 | + else nullcontext() |
| 55 | + ): |
30 | 56 | download_model(name=name) |
31 | 57 |
|
32 | 58 |
|
|
0 commit comments