Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 14 additions & 43 deletions src/litmodels/io/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
#
from typing import TYPE_CHECKING, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

from lightning_sdk.lightning_cloud.env import LIGHTNING_CLOUD_URL
from lightning_sdk.teamspace import Teamspace
from lightning_sdk.utils import resolve as sdk_resolvers
from lightning_sdk.models import download_model, upload_model

if TYPE_CHECKING:
from lightning_sdk.models import UploadedModelInfo
Expand All @@ -32,42 +31,18 @@ def _parse_name(name: str) -> Tuple[str, str, str]:
return org_name, teamspace_name, model_name


def _get_teamspace(name: str, organization: str) -> Teamspace:
"""Get a Teamspace object from the SDK."""
from lightning_sdk.api import OrgApi, UserApi

org_api = OrgApi()
user = sdk_resolvers._get_authed_user()
teamspaces = {}
for ts in UserApi()._get_all_teamspace_memberships(""):
if ts.owner_type == "organization":
org = org_api._get_org_by_id(ts.owner_id)
teamspaces[f"{org.name}/{ts.name}"] = {"name": ts.name, "org": org.name}
elif ts.owner_type == "user": # todo: check also the name
teamspaces[f"{user.name}/{ts.name}"] = {"name": ts.name, "user": user}
else:
raise RuntimeError(f"Unknown organization type {ts.organization_type}")

requested_teamspace = f"{organization}/{name}".lower()
if requested_teamspace not in teamspaces:
options = "\n\t".join(teamspaces.keys())
raise RuntimeError(f"Teamspace `{requested_teamspace}` not found. Available teamspaces: \n\t{options}")
return Teamspace(**teamspaces[requested_teamspace])


def _print_model_link(org_name: str, teamspace_name: str, model_name: str, verbose: Union[bool, int]) -> None:
def _print_model_link(name: str, verbose: Union[bool, int]) -> None:
"""Print a link to the uploaded model.

Args:
org_name: Name of the organization.
teamspace_name: Name of the teamspace.
model_name: Name of the model.
name: Name of the model.
verbose: Whether to print the link:

- If set to 0, no link will be printed.
- If set to 1, the link will be printed only once.
- If set to 2, the link will be printed every time.
"""
org_name, teamspace_name, model_name = _parse_name(name)
url = f"{LIGHTNING_CLOUD_URL}/{org_name}/{teamspace_name}/models/{model_name}"
msg = f"Model uploaded successfully. Link to the model: '{url}'"
if int(verbose) > 1:
Expand All @@ -81,7 +56,7 @@ def upload_model_files(
name: str,
path: str,
progress_bar: bool = True,
cluster_id: Optional[str] = None,
cloud_account: Optional[str] = None,
verbose: Union[bool, int] = 1,
) -> "UploadedModelInfo":
"""Upload a local checkpoint file to the model store.
Expand All @@ -91,29 +66,27 @@ def upload_model_files(
where entity is either your username or the name of an organization you are part of.
path: Path to the model file to upload.
progress_bar: Whether to show a progress bar for the upload.
cluster_id: The name of the cluster to use. Only required if it can't be determined
cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined
automatically.
verbose: Whether to print a link to the uploaded model. If set to 0, no link will be printed.

"""
org_name, teamspace_name, model_name = _parse_name(name)
teamspace = _get_teamspace(name=teamspace_name, organization=org_name)
info = teamspace.upload_model(
info = upload_model(
name=name,
path=path,
name=model_name,
progress_bar=progress_bar,
cluster_id=cluster_id,
cloud_account=cloud_account,
)
if verbose:
_print_model_link(org_name, teamspace_name, model_name, verbose)
_print_model_link(info.name, verbose)
return info


def download_model_files(
name: str,
download_dir: str = ".",
progress_bar: bool = True,
) -> str:
) -> Union[str, List[str]]:
"""Download a checkpoint from the model store.

Args:
Expand All @@ -126,10 +99,8 @@ def download_model_files(
Returns:
The absolute path to the downloaded model file or folder.
"""
org_name, teamspace_name, model_name = _parse_name(name)
teamspace = _get_teamspace(name=teamspace_name, organization=org_name)
return teamspace.download_model(
name=model_name,
return download_model(
name=name,
download_dir=download_dir,
progress_bar=progress_bar,
)
10 changes: 5 additions & 5 deletions src/litmodels/io/gateway.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union

from lightning_utilities import module_available

Expand All @@ -21,7 +21,7 @@ def upload_model(
name: str,
model: Union[str, Path, "Module"],
progress_bar: bool = True,
cluster_id: Optional[str] = None,
cloud_account: Optional[str] = None,
staging_dir: Optional[str] = None,
verbose: Union[bool, int] = 1,
) -> "UploadedModelInfo":
Expand All @@ -32,7 +32,7 @@ def upload_model(
where entity is either your username or the name of an organization you are part of.
model: The model to upload. Can be a path to a checkpoint file, a PyTorch model, or a Lightning model.
progress_bar: Whether to show a progress bar for the upload.
cluster_id: The name of the cluster to use. Only required if it can't be determined
cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined
automatically.
staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will
be created and used.
Expand All @@ -57,7 +57,7 @@ def upload_model(
path=path,
name=name,
progress_bar=progress_bar,
cluster_id=cluster_id,
cloud_account=cloud_account,
verbose=verbose,
)

Expand All @@ -66,7 +66,7 @@ def download_model(
name: str,
download_dir: str = ".",
progress_bar: bool = True,
) -> str:
) -> Union[str, List[str]]:
"""Download a checkpoint from the model store.

Args:
Expand Down
25 changes: 11 additions & 14 deletions tests/test_io_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,37 +23,34 @@ def test_wrong_model_name(name):
(Module(), f"%s{os.path.sep}Module.pth", True),
],
)
def test_upload_model(mocker, tmpdir, model, model_path, verbose):
# mocking the _get_teamspace to return another mock
ts_mock = mock.MagicMock()
mocker.patch("litmodels.io.cloud._get_teamspace", return_value=ts_mock)
@mock.patch("litmodels.io.cloud.upload_model")
def test_upload_model(mock_upload_model, tmpdir, model, model_path, verbose):
mock_upload_model.return_value.name = "org-name/teamspace/model-name"

# The lit-logger function is just a wrapper around the SDK function
upload_model(
model=model,
name="org-name/teamspace/model-name",
cluster_id="cluster_id",
cloud_account="cluster_id",
staging_dir=tmpdir,
verbose=verbose,
)
expected_path = model_path % str(tmpdir) if "%" in model_path else model_path
ts_mock.upload_model.assert_called_once_with(
mock_upload_model.assert_called_once_with(
path=expected_path,
name="model-name",
cluster_id="cluster_id",
name="org-name/teamspace/model-name",
cloud_account="cluster_id",
progress_bar=True,
)


def test_download_model(mocker):
# mocking the _get_teamspace to return another mock
ts_mock = mock.MagicMock()
mocker.patch("litmodels.io.cloud._get_teamspace", return_value=ts_mock)
@mock.patch("litmodels.io.cloud.download_model")
def test_download_model(mock_download_model):
# The lit-logger function is just a wrapper around the SDK function
download_model(
name="org-name/teamspace/model-name",
download_dir="where/to/download",
)
ts_mock.download_model.assert_called_once_with(
name="model-name", download_dir="where/to/download", progress_bar=True
mock_download_model.assert_called_once_with(
name="org-name/teamspace/model-name", download_dir="where/to/download", progress_bar=True
)
Loading