diff --git a/src/litmodels/io/cloud.py b/src/litmodels/io/cloud.py index bae2e04..235057e 100644 --- a/src/litmodels/io/cloud.py +++ b/src/litmodels/io/cloud.py @@ -5,7 +5,8 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union from lightning_sdk.lightning_cloud.env import LIGHTNING_CLOUD_URL -from lightning_sdk.models import download_model, upload_model +from lightning_sdk.models import download_model as sdk_download_model +from lightning_sdk.models import upload_model as sdk_upload_model if TYPE_CHECKING: from lightning_sdk.models import UploadedModelInfo @@ -71,7 +72,7 @@ def upload_model_files( verbose: Whether to print a link to the uploaded model. If set to 0, no link will be printed. """ - info = upload_model( + info = sdk_upload_model( name=name, path=path, progress_bar=progress_bar, @@ -99,7 +100,7 @@ def download_model_files( Returns: The absolute path to the downloaded model file or folder. """ - return download_model( + return sdk_download_model( name=name, download_dir=download_dir, progress_bar=progress_bar, diff --git a/tests/test_io_cloud.py b/tests/test_io_cloud.py index 9ff8740..457226d 100644 --- a/tests/test_io_cloud.py +++ b/tests/test_io_cloud.py @@ -23,7 +23,7 @@ def test_wrong_model_name(name): (Module(), f"%s{os.path.sep}Module.pth", True), ], ) -@mock.patch("litmodels.io.cloud.upload_model") +@mock.patch("litmodels.io.cloud.sdk_upload_model") def test_upload_model(mock_upload_model, tmpdir, model, model_path, verbose): mock_upload_model.return_value.name = "org-name/teamspace/model-name" @@ -44,7 +44,7 @@ def test_upload_model(mock_upload_model, tmpdir, model, model_path, verbose): ) -@mock.patch("litmodels.io.cloud.download_model") +@mock.patch("litmodels.io.cloud.sdk_download_model") def test_download_model(mock_download_model): # The lit-logger function is just a wrapper around the SDK function download_model(