Skip to content

Commit 6e56c27

Browse files
authored
CTA: show model URL after uploading (#21)
1 parent 59ef9b7 commit 6e56c27

File tree

3 files changed

+40
-6
lines changed

3 files changed

+40
-6
lines changed

src/litmodels/io/cloud.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# http://www.apache.org/licenses/LICENSE-2.0
44
#
5-
from typing import Optional, Tuple
5+
from typing import Optional, Tuple, Union
66

77
from lightning_sdk.api.teamspace_api import UploadedModelInfo
8+
from lightning_sdk.lightning_cloud.env import LIGHTNING_CLOUD_URL
89
from lightning_sdk.teamspace import Teamspace
910
from lightning_sdk.utils import resolve as sdk_resolvers
1011

@@ -15,6 +16,8 @@
1516
# else:
1617
# LightningModule = None
1718

19+
_SHOWED_MODEL_LINKS = []
20+
1821

1922
def _parse_name(name: str) -> Tuple[str, str, str]:
2023
"""Parse the name argument into its components."""
@@ -50,11 +53,34 @@ def _get_teamspace(name: str, organization: str) -> Teamspace:
5053
return Teamspace(**teamspaces[requested_teamspace])
5154

5255

56+
def _print_model_link(org_name: str, teamspace_name: str, model_name: str, verbose: Union[bool, int]) -> None:
57+
"""Print a link to the uploaded model.
58+
59+
Args:
60+
org_name: Name of the organization.
61+
teamspace_name: Name of the teamspace.
62+
model_name: Name of the model.
63+
verbose: Whether to print the link:
64+
65+
- If set to 0, no link will be printed.
66+
- If set to 1, the link will be printed only once.
67+
- If set to 2, the link will be printed every time.
68+
"""
69+
url = f"{LIGHTNING_CLOUD_URL}/{org_name}/{teamspace_name}/models/{model_name}"
70+
msg = f"Model uploaded successfully. Link to the model: '{url}'"
71+
if int(verbose) > 1:
72+
print(msg)
73+
elif url not in _SHOWED_MODEL_LINKS:
74+
print(msg)
75+
_SHOWED_MODEL_LINKS.append(url)
76+
77+
5378
def upload_model_files(
5479
name: str,
5580
path: str,
5681
progress_bar: bool = True,
5782
cluster_id: Optional[str] = None,
83+
verbose: Union[bool, int] = 0,
5884
) -> UploadedModelInfo:
5985
"""Upload a local checkpoint file to the model store.
6086
@@ -65,16 +91,20 @@ def upload_model_files(
6591
progress_bar: Whether to show a progress bar for the upload.
6692
cluster_id: The name of the cluster to use. Only required if it can't be determined
6793
automatically.
94+
verbose: Whether to print a link to the uploaded model. If set to 0, no link will be printed.
6895
6996
"""
7097
org_name, teamspace_name, model_name = _parse_name(name)
7198
teamspace = _get_teamspace(name=teamspace_name, organization=org_name)
72-
return teamspace.upload_model(
99+
info = teamspace.upload_model(
73100
path=path,
74101
name=model_name,
75102
progress_bar=progress_bar,
76103
cluster_id=cluster_id,
77104
)
105+
if verbose:
106+
_print_model_link(org_name, teamspace_name, model_name, verbose)
107+
return info
78108

79109

80110
def download_model_files(

src/litmodels/io/gateway.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def upload_model(
2121
progress_bar: bool = True,
2222
cluster_id: Optional[str] = None,
2323
staging_dir: Optional[str] = None,
24+
verbose: Union[bool, int] = 0,
2425
) -> UploadedModelInfo:
2526
"""Upload a checkpoint to the model store.
2627
@@ -33,6 +34,7 @@ def upload_model(
3334
automatically.
3435
staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will
3536
be created and used.
37+
verbose: Whether to print some additional information about the uploaded model.
3638
3739
"""
3840
if not staging_dir:
@@ -54,6 +56,7 @@ def upload_model(
5456
name=name,
5557
progress_bar=progress_bar,
5658
cluster_id=cluster_id,
59+
verbose=verbose,
5760
)
5861

5962

tests/test_io_cloud.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@ def test_wrong_model_name(name):
1616

1717

1818
@pytest.mark.parametrize(
19-
("model", "model_path"),
19+
("model", "model_path", "verbose"),
2020
[
21-
("path/to/checkpoint", "path/to/checkpoint"),
21+
("path/to/checkpoint", "path/to/checkpoint", False),
2222
# (BoringModel(), "%s/BoringModel.ckpt"),
23-
(Module(), f"%s{os.path.sep}Module.pth"),
23+
(Module(), f"%s{os.path.sep}Module.pth", True),
2424
],
2525
)
26-
def test_upload_model(mocker, tmpdir, model, model_path):
26+
def test_upload_model(mocker, tmpdir, model, model_path, verbose):
2727
# mocking the _get_teamspace to return another mock
2828
ts_mock = mock.MagicMock()
2929
mocker.patch("litmodels.io.cloud._get_teamspace", return_value=ts_mock)
@@ -34,6 +34,7 @@ def test_upload_model(mocker, tmpdir, model, model_path):
3434
name="org-name/teamspace/model-name",
3535
cluster_id="cluster_id",
3636
staging_dir=tmpdir,
37+
verbose=verbose,
3738
)
3839
expected_path = model_path % str(tmpdir) if "%" in model_path else model_path
3940
ts_mock.upload_model.assert_called_once_with(

0 commit comments

Comments
 (0)