Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions src/litmodels/io/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
#
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

from lightning_sdk.api.teamspace_api import UploadedModelInfo
from lightning_sdk.lightning_cloud.env import LIGHTNING_CLOUD_URL
from lightning_sdk.teamspace import Teamspace
from lightning_sdk.utils import resolve as sdk_resolvers

Expand All @@ -15,6 +16,8 @@
# else:
# LightningModule = None

_SHOWED_MODEL_LINKS = []


def _parse_name(name: str) -> Tuple[str, str, str]:
"""Parse the name argument into its components."""
Expand Down Expand Up @@ -50,11 +53,34 @@ def _get_teamspace(name: str, organization: str) -> Teamspace:
return Teamspace(**teamspaces[requested_teamspace])


def _print_model_link(org_name: str, teamspace_name: str, model_name: str, verbose: Union[bool, int]) -> None:
"""Print a link to the uploaded model.

Args:
org_name: Name of the organization.
teamspace_name: Name of the teamspace.
model_name: Name of the model.
verbose: Whether to print the link:

- If set to 0, no link will be printed.
- If set to 1, the link will be printed only once.
- If set to 2, the link will be printed every time.
"""
url = f"{LIGHTNING_CLOUD_URL}/{org_name}/{teamspace_name}/models/{model_name}"
msg = f"Model uploaded successfully. Link to the model: '{url}'"
if int(verbose) > 1:
print(msg)
elif url not in _SHOWED_MODEL_LINKS:
print(msg)
_SHOWED_MODEL_LINKS.append(url)


def upload_model_files(
name: str,
path: str,
progress_bar: bool = True,
cluster_id: Optional[str] = None,
verbose: Union[bool, int] = 0,
) -> UploadedModelInfo:
"""Upload a local checkpoint file to the model store.

Expand All @@ -65,16 +91,20 @@ def upload_model_files(
progress_bar: Whether to show a progress bar for the upload.
cluster_id: The name of the cluster to use. Only required if it can't be determined
automatically.
verbose: Whether to print a link to the uploaded model. If set to 0, no link will be printed.

"""
org_name, teamspace_name, model_name = _parse_name(name)
teamspace = _get_teamspace(name=teamspace_name, organization=org_name)
return teamspace.upload_model(
info = teamspace.upload_model(
path=path,
name=model_name,
progress_bar=progress_bar,
cluster_id=cluster_id,
)
if verbose:
_print_model_link(org_name, teamspace_name, model_name, verbose)
return info


def download_model_files(
Expand Down
3 changes: 3 additions & 0 deletions src/litmodels/io/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def upload_model(
progress_bar: bool = True,
cluster_id: Optional[str] = None,
staging_dir: Optional[str] = None,
verbose: Union[bool, int] = 0,
) -> UploadedModelInfo:
"""Upload a checkpoint to the model store.

Expand All @@ -33,6 +34,7 @@ def upload_model(
automatically.
staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will
be created and used.
verbose: Whether to print some additional information about the uploaded model.

"""
if not staging_dir:
Expand All @@ -54,6 +56,7 @@ def upload_model(
name=name,
progress_bar=progress_bar,
cluster_id=cluster_id,
verbose=verbose,
)


Expand Down
9 changes: 5 additions & 4 deletions tests/test_io_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ def test_wrong_model_name(name):


@pytest.mark.parametrize(
("model", "model_path"),
("model", "model_path", "verbose"),
[
("path/to/checkpoint", "path/to/checkpoint"),
("path/to/checkpoint", "path/to/checkpoint", False),
# (BoringModel(), "%s/BoringModel.ckpt"),
(Module(), f"%s{os.path.sep}Module.pth"),
(Module(), f"%s{os.path.sep}Module.pth", True),
],
)
def test_upload_model(mocker, tmpdir, model, model_path):
def test_upload_model(mocker, tmpdir, model, model_path, verbose):
# mocking the _get_teamspace to return another mock
ts_mock = mock.MagicMock()
mocker.patch("litmodels.io.cloud._get_teamspace", return_value=ts_mock)
Expand All @@ -34,6 +34,7 @@ def test_upload_model(mocker, tmpdir, model, model_path):
name="org-name/teamspace/model-name",
cluster_id="cluster_id",
staging_dir=tmpdir,
verbose=verbose,
)
expected_path = model_path % str(tmpdir) if "%" in model_path else model_path
ts_mock.upload_model.assert_called_once_with(
Expand Down
Loading