diff --git a/src/litmodels/io/cloud.py b/src/litmodels/io/cloud.py index ad31b00..bae2e04 100644 --- a/src/litmodels/io/cloud.py +++ b/src/litmodels/io/cloud.py @@ -2,11 +2,10 @@ # Licensed under the Apache License, Version 2.0 (the "License"); # http://www.apache.org/licenses/LICENSE-2.0 # -from typing import TYPE_CHECKING, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union from lightning_sdk.lightning_cloud.env import LIGHTNING_CLOUD_URL -from lightning_sdk.teamspace import Teamspace -from lightning_sdk.utils import resolve as sdk_resolvers +from lightning_sdk.models import download_model, upload_model if TYPE_CHECKING: from lightning_sdk.models import UploadedModelInfo @@ -32,42 +31,18 @@ def _parse_name(name: str) -> Tuple[str, str, str]: return org_name, teamspace_name, model_name -def _get_teamspace(name: str, organization: str) -> Teamspace: - """Get a Teamspace object from the SDK.""" - from lightning_sdk.api import OrgApi, UserApi - - org_api = OrgApi() - user = sdk_resolvers._get_authed_user() - teamspaces = {} - for ts in UserApi()._get_all_teamspace_memberships(""): - if ts.owner_type == "organization": - org = org_api._get_org_by_id(ts.owner_id) - teamspaces[f"{org.name}/{ts.name}"] = {"name": ts.name, "org": org.name} - elif ts.owner_type == "user": # todo: check also the name - teamspaces[f"{user.name}/{ts.name}"] = {"name": ts.name, "user": user} - else: - raise RuntimeError(f"Unknown organization type {ts.organization_type}") - - requested_teamspace = f"{organization}/{name}".lower() - if requested_teamspace not in teamspaces: - options = "\n\t".join(teamspaces.keys()) - raise RuntimeError(f"Teamspace `{requested_teamspace}` not found. Available teamspaces: \n\t{options}") - return Teamspace(**teamspaces[requested_teamspace]) - - -def _print_model_link(org_name: str, teamspace_name: str, model_name: str, verbose: Union[bool, int]) -> None: +def _print_model_link(name: str, verbose: Union[bool, int]) -> None: """Print a link to the uploaded model. Args: - org_name: Name of the organization. - teamspace_name: Name of the teamspace. - model_name: Name of the model. + name: Name of the model. verbose: Whether to print the link: - If set to 0, no link will be printed. - If set to 1, the link will be printed only once. - If set to 2, the link will be printed every time. """ + org_name, teamspace_name, model_name = _parse_name(name) url = f"{LIGHTNING_CLOUD_URL}/{org_name}/{teamspace_name}/models/{model_name}" msg = f"Model uploaded successfully. Link to the model: '{url}'" if int(verbose) > 1: @@ -81,7 +56,7 @@ def upload_model_files( name: str, path: str, progress_bar: bool = True, - cluster_id: Optional[str] = None, + cloud_account: Optional[str] = None, verbose: Union[bool, int] = 1, ) -> "UploadedModelInfo": """Upload a local checkpoint file to the model store. @@ -91,21 +66,19 @@ def upload_model_files( where entity is either your username or the name of an organization you are part of. path: Path to the model file to upload. progress_bar: Whether to show a progress bar for the upload. - cluster_id: The name of the cluster to use. Only required if it can't be determined + cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined automatically. verbose: Whether to print a link to the uploaded model. If set to 0, no link will be printed. """ - org_name, teamspace_name, model_name = _parse_name(name) - teamspace = _get_teamspace(name=teamspace_name, organization=org_name) - info = teamspace.upload_model( + info = upload_model( + name=name, path=path, - name=model_name, progress_bar=progress_bar, - cluster_id=cluster_id, + cloud_account=cloud_account, ) if verbose: - _print_model_link(org_name, teamspace_name, model_name, verbose) + _print_model_link(info.name, verbose) return info @@ -113,7 +86,7 @@ def download_model_files( name: str, download_dir: str = ".", progress_bar: bool = True, -) -> str: +) -> Union[str, List[str]]: """Download a checkpoint from the model store. Args: @@ -126,10 +99,8 @@ def download_model_files( Returns: The absolute path to the downloaded model file or folder. """ - org_name, teamspace_name, model_name = _parse_name(name) - teamspace = _get_teamspace(name=teamspace_name, organization=org_name) - return teamspace.download_model( - name=model_name, + return download_model( + name=name, download_dir=download_dir, progress_bar=progress_bar, ) diff --git a/src/litmodels/io/gateway.py b/src/litmodels/io/gateway.py index cd437c0..662fba5 100644 --- a/src/litmodels/io/gateway.py +++ b/src/litmodels/io/gateway.py @@ -1,7 +1,7 @@ import os import tempfile from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union from lightning_utilities import module_available @@ -21,7 +21,7 @@ def upload_model( name: str, model: Union[str, Path, "Module"], progress_bar: bool = True, - cluster_id: Optional[str] = None, + cloud_account: Optional[str] = None, staging_dir: Optional[str] = None, verbose: Union[bool, int] = 1, ) -> "UploadedModelInfo": @@ -32,7 +32,7 @@ def upload_model( where entity is either your username or the name of an organization you are part of. model: The model to upload. Can be a path to a checkpoint file, a PyTorch model, or a Lightning model. progress_bar: Whether to show a progress bar for the upload. - cluster_id: The name of the cluster to use. Only required if it can't be determined + cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined automatically. staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will be created and used. @@ -57,7 +57,7 @@ def upload_model( path=path, name=name, progress_bar=progress_bar, - cluster_id=cluster_id, + cloud_account=cloud_account, verbose=verbose, ) @@ -66,7 +66,7 @@ def download_model( name: str, download_dir: str = ".", progress_bar: bool = True, -) -> str: +) -> Union[str, List[str]]: """Download a checkpoint from the model store. Args: diff --git a/tests/test_io_cloud.py b/tests/test_io_cloud.py index 630f448..9ff8740 100644 --- a/tests/test_io_cloud.py +++ b/tests/test_io_cloud.py @@ -23,37 +23,34 @@ def test_wrong_model_name(name): (Module(), f"%s{os.path.sep}Module.pth", True), ], ) -def test_upload_model(mocker, tmpdir, model, model_path, verbose): - # mocking the _get_teamspace to return another mock - ts_mock = mock.MagicMock() - mocker.patch("litmodels.io.cloud._get_teamspace", return_value=ts_mock) +@mock.patch("litmodels.io.cloud.upload_model") +def test_upload_model(mock_upload_model, tmpdir, model, model_path, verbose): + mock_upload_model.return_value.name = "org-name/teamspace/model-name" # The lit-logger function is just a wrapper around the SDK function upload_model( model=model, name="org-name/teamspace/model-name", - cluster_id="cluster_id", + cloud_account="cluster_id", staging_dir=tmpdir, verbose=verbose, ) expected_path = model_path % str(tmpdir) if "%" in model_path else model_path - ts_mock.upload_model.assert_called_once_with( + mock_upload_model.assert_called_once_with( path=expected_path, - name="model-name", - cluster_id="cluster_id", + name="org-name/teamspace/model-name", + cloud_account="cluster_id", progress_bar=True, ) -def test_download_model(mocker): - # mocking the _get_teamspace to return another mock - ts_mock = mock.MagicMock() - mocker.patch("litmodels.io.cloud._get_teamspace", return_value=ts_mock) +@mock.patch("litmodels.io.cloud.download_model") +def test_download_model(mock_download_model): # The lit-logger function is just a wrapper around the SDK function download_model( name="org-name/teamspace/model-name", download_dir="where/to/download", ) - ts_mock.download_model.assert_called_once_with( - name="model-name", download_dir="where/to/download", progress_bar=True + mock_download_model.assert_called_once_with( + name="org-name/teamspace/model-name", download_dir="where/to/download", progress_bar=True )