Skip to content

Commit 3ff0752

Browse files
committed
cleaning
1 parent 9b0ef2d commit 3ff0752

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

src/litmodels/io/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Root package for Input/output."""
22

3+
from litmodels.io.cloud import download_model_file, upload_model_file
34
from litmodels.io.gateway import download_model, upload_model
45

5-
__all__ = ["download_model", "upload_model"]
6+
__all__ = ["download_model", "upload_model", "download_model_file", "upload_model_file"]

src/litmodels/io/cloud.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def _get_teamspace(name: str, organization: str) -> Teamspace:
5050
return Teamspace(**teamspaces[requested_teamspace])
5151

5252

53-
def upload_model_files(
53+
def upload_model_file(
5454
name: str,
5555
path: str,
5656
progress_bar: bool = True,

tests/test_cloud_io.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
from unittest import mock
33

44
import pytest
5-
from litmodels.cloud_io import download_model, upload_model, upload_model_files
5+
from litmodels import download_model, upload_model
6+
from litmodels.io import upload_model_file
67
from torch.nn import Module
78

89

910
@pytest.mark.parametrize("name", ["org/model", "model-name", "/too/many/slashes"])
1011
def test_wrong_model_name(name):
1112
with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"):
12-
upload_model_files(path="path/to/checkpoint", name=name)
13+
upload_model_file(path="path/to/checkpoint", name=name)
1314
with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"):
1415
download_model(name=name)
1516

@@ -25,7 +26,7 @@ def test_wrong_model_name(name):
2526
def test_upload_model(mocker, tmpdir, model, model_path):
2627
# mocking the _get_teamspace to return another mock
2728
ts_mock = mock.MagicMock()
28-
mocker.patch("litmodels.cloud_io._get_teamspace", return_value=ts_mock)
29+
mocker.patch("litmodels.io.cloud._get_teamspace", return_value=ts_mock)
2930

3031
# The lit-logger function is just a wrapper around the SDK function
3132
upload_model(

0 commit comments

Comments
 (0)