Skip to content

Commit 360e477

Browse files
committed
show model URL after uploading
1 parent fc9c21e commit 360e477

File tree

1 file changed

+31
-2
lines changed

1 file changed

+31
-2
lines changed

src/litmodels/io/cloud.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# http://www.apache.org/licenses/LICENSE-2.0
44
#
5-
from typing import Optional, Tuple
5+
from typing import Optional, Tuple, Union
66

77
from lightning_sdk.api.teamspace_api import UploadedModelInfo
88
from lightning_sdk.teamspace import Teamspace
@@ -15,6 +15,8 @@
1515
# else:
1616
# LightningModule = None
1717

18+
_SHOWED_MODEL_LINKS = []
19+
1820

1921
def _parse_name(name: str) -> Tuple[str, str, str]:
2022
"""Parse the name argument into its components."""
@@ -50,11 +52,34 @@ def _get_teamspace(name: str, organization: str) -> Teamspace:
5052
return Teamspace(**teamspaces[requested_teamspace])
5153

5254

55+
def _print_model_link(org_name: str, teamspace_name: str, model_name: str, verbose: Union[bool, int]) -> None:
56+
"""Print a link to the uploaded model.
57+
58+
Args:
59+
org_name: Name of the organization.
60+
teamspace_name: Name of the teamspace.
61+
model_name: Name of the model.
62+
verbose: Whether to print the link:
63+
64+
- If set to 0, no link will be printed.
65+
- If set to 1, the link will be printed only once.
66+
- If set to 2, the link will be printed every time.
67+
"""
68+
url = f"https://lightning.ai/{org_name}/{teamspace_name}/models/{model_name}"
69+
msg = f"Model uploaded successfully. Link to the model: '{url}'"
70+
if int(verbose) > 1:
71+
print(msg)
72+
elif url not in _SHOWED_MODEL_LINKS:
73+
print(msg)
74+
_SHOWED_MODEL_LINKS.append(url)
75+
76+
5377
def upload_model_files(
5478
name: str,
5579
path: str,
5680
progress_bar: bool = True,
5781
cluster_id: Optional[str] = None,
82+
verbose: Union[bool, int] = 0,
5883
) -> UploadedModelInfo:
5984
"""Upload a local checkpoint file to the model store.
6085
@@ -65,16 +90,20 @@ def upload_model_files(
6590
progress_bar: Whether to show a progress bar for the upload.
6691
cluster_id: The name of the cluster to use. Only required if it can't be determined
6792
automatically.
93+
verbose: Whether to print a link to the uploaded model. If set to 0, no link will be printed.
6894
6995
"""
7096
org_name, teamspace_name, model_name = _parse_name(name)
7197
teamspace = _get_teamspace(name=teamspace_name, organization=org_name)
72-
return teamspace.upload_model(
98+
info = teamspace.upload_model(
7399
path=path,
74100
name=model_name,
75101
progress_bar=progress_bar,
76102
cluster_id=cluster_id,
77103
)
104+
if verbose:
105+
_print_model_link(org_name, teamspace_name, model_name, verbose)
106+
return info
78107

79108

80109
def download_model_files(

0 commit comments

Comments
 (0)