Skip to content

Commit 9b1971b

Browse files
committed
download_model
1 parent 92f8298 commit 9b1971b

File tree

5 files changed

+9
-9
lines changed

5 files changed

+9
-9
lines changed

examples/demo-upload-download.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
)
1717

1818
# Download the model checkpoint
19-
model_path = litmodels.download_model_files("jirka/kaggle/boring-model", download_dir="./my-models")
19+
model_path = litmodels.download_model("jirka/kaggle/boring-model", download_dir="./my-models")
2020
print(f"Model downloaded to {model_path}")
2121

2222
# Load the model checkpoint

examples/train-resume.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import torch.utils.data as data
22
import torchvision as tv
33
from lightning import Trainer
4-
from litmodels import download_model_files
4+
from litmodels import download_model
55
from sample_model import LitAutoEncoder
66

77
if __name__ == "__main__":
88
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
99
train, val = data.random_split(dataset, [55000, 5000])
1010

11-
model_path = download_model_files(name="jirka/kaggle/lit-auto-encoder-simple", download_dir="my_models")
11+
model_path = download_model(name="jirka/kaggle/lit-auto-encoder-simple", download_dir="my_models")
1212
print(f"model: {model_path}")
1313
# autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint_path=model_path)
1414

src/litmodels/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@
77
_PACKAGE_ROOT = os.path.dirname(__file__)
88
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
99

10-
from litmodels.cloud_io import download_model_files, upload_model, upload_model_files
10+
from litmodels.cloud_io import download_model, upload_model, upload_model_files
1111

12-
__all__ = ["download_model_files", "upload_model", "upload_model_files"]
12+
__all__ = ["download_model", "upload_model", "upload_model_files"]

src/litmodels/cloud_io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def upload_model_files(
132132
)
133133

134134

135-
def download_model_files(
135+
def download_model(
136136
name: str,
137137
download_dir: str = ".",
138138
progress_bar: bool = True,

tests/test_cloud_io.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from unittest import mock
33

44
import pytest
5-
from litmodels.cloud_io import download_model_files, upload_model, upload_model_files
5+
from litmodels.cloud_io import download_model, upload_model, upload_model_files
66
from torch.nn import Module
77

88

@@ -11,7 +11,7 @@ def test_wrong_model_name(name):
1111
with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"):
1212
upload_model_files(path="path/to/checkpoint", name=name)
1313
with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"):
14-
download_model_files(name=name)
14+
download_model(name=name)
1515

1616

1717
@pytest.mark.parametrize(
@@ -48,7 +48,7 @@ def test_download_model(mocker):
4848
ts_mock = mock.MagicMock()
4949
mocker.patch("litmodels.cloud_io._get_teamspace", return_value=ts_mock)
5050
# The lit-logger function is just a wrapper around the SDK function
51-
download_model_files(
51+
download_model(
5252
name="org-name/teamspace/model-name",
5353
download_dir="where/to/download",
5454
)

0 commit comments

Comments
 (0)