diff --git a/src/litmodels/io/cloud.py b/src/litmodels/io/cloud.py index 79cc62a..a51c70c 100644 --- a/src/litmodels/io/cloud.py +++ b/src/litmodels/io/cloud.py @@ -2,9 +2,10 @@ # Licensed under the Apache License, Version 2.0 (the "License"); # http://www.apache.org/licenses/LICENSE-2.0 # -from typing import Optional, Tuple +from typing import Optional, Tuple, Union from lightning_sdk.api.teamspace_api import UploadedModelInfo +from lightning_sdk.lightning_cloud.env import LIGHTNING_CLOUD_URL from lightning_sdk.teamspace import Teamspace from lightning_sdk.utils import resolve as sdk_resolvers @@ -15,6 +16,8 @@ # else: # LightningModule = None +_SHOWED_MODEL_LINKS = [] + def _parse_name(name: str) -> Tuple[str, str, str]: """Parse the name argument into its components.""" @@ -50,11 +53,34 @@ def _get_teamspace(name: str, organization: str) -> Teamspace: return Teamspace(**teamspaces[requested_teamspace]) +def _print_model_link(org_name: str, teamspace_name: str, model_name: str, verbose: Union[bool, int]) -> None: + """Print a link to the uploaded model. + + Args: + org_name: Name of the organization. + teamspace_name: Name of the teamspace. + model_name: Name of the model. + verbose: Whether to print the link: + + - If set to 0, no link will be printed. + - If set to 1, the link will be printed only once. + - If set to 2, the link will be printed every time. + """ + url = f"{LIGHTNING_CLOUD_URL}/{org_name}/{teamspace_name}/models/{model_name}" + msg = f"Model uploaded successfully. Link to the model: '{url}'" + if int(verbose) > 1: + print(msg) + elif url not in _SHOWED_MODEL_LINKS: + print(msg) + _SHOWED_MODEL_LINKS.append(url) + + def upload_model_files( name: str, path: str, progress_bar: bool = True, cluster_id: Optional[str] = None, + verbose: Union[bool, int] = 0, ) -> UploadedModelInfo: """Upload a local checkpoint file to the model store. @@ -65,16 +91,20 @@ def upload_model_files( progress_bar: Whether to show a progress bar for the upload. cluster_id: The name of the cluster to use. Only required if it can't be determined automatically. + verbose: Whether to print a link to the uploaded model. If set to 0, no link will be printed. """ org_name, teamspace_name, model_name = _parse_name(name) teamspace = _get_teamspace(name=teamspace_name, organization=org_name) - return teamspace.upload_model( + info = teamspace.upload_model( path=path, name=model_name, progress_bar=progress_bar, cluster_id=cluster_id, ) + if verbose: + _print_model_link(org_name, teamspace_name, model_name, verbose) + return info def download_model_files( diff --git a/src/litmodels/io/gateway.py b/src/litmodels/io/gateway.py index 5f00f58..4e0ac4f 100644 --- a/src/litmodels/io/gateway.py +++ b/src/litmodels/io/gateway.py @@ -21,6 +21,7 @@ def upload_model( progress_bar: bool = True, cluster_id: Optional[str] = None, staging_dir: Optional[str] = None, + verbose: Union[bool, int] = 0, ) -> UploadedModelInfo: """Upload a checkpoint to the model store. @@ -33,6 +34,7 @@ def upload_model( automatically. staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will be created and used. + verbose: Whether to print some additional information about the uploaded model. """ if not staging_dir: @@ -54,6 +56,7 @@ def upload_model( name=name, progress_bar=progress_bar, cluster_id=cluster_id, + verbose=verbose, ) diff --git a/tests/test_io_cloud.py b/tests/test_io_cloud.py index 126f076..630f448 100644 --- a/tests/test_io_cloud.py +++ b/tests/test_io_cloud.py @@ -16,14 +16,14 @@ def test_wrong_model_name(name): @pytest.mark.parametrize( - ("model", "model_path"), + ("model", "model_path", "verbose"), [ - ("path/to/checkpoint", "path/to/checkpoint"), + ("path/to/checkpoint", "path/to/checkpoint", False), # (BoringModel(), "%s/BoringModel.ckpt"), - (Module(), f"%s{os.path.sep}Module.pth"), + (Module(), f"%s{os.path.sep}Module.pth", True), ], ) -def test_upload_model(mocker, tmpdir, model, model_path): +def test_upload_model(mocker, tmpdir, model, model_path, verbose): # mocking the _get_teamspace to return another mock ts_mock = mock.MagicMock() mocker.patch("litmodels.io.cloud._get_teamspace", return_value=ts_mock) @@ -34,6 +34,7 @@ def test_upload_model(mocker, tmpdir, model, model_path): name="org-name/teamspace/model-name", cluster_id="cluster_id", staging_dir=tmpdir, + verbose=verbose, ) expected_path = model_path % str(tmpdir) if "%" in model_path else model_path ts_mock.upload_model.assert_called_once_with(