Skip to content

Commit 35ee017

Browse files
committed
feat: enable upload Torch's nn.Module
1 parent 6bfa873 commit 35ee017

File tree

7 files changed

+94
-22
lines changed

7 files changed

+94
-22
lines changed

examples/demo-upload-download.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
torch.save(model.state_dict(), "./boring-checkpoint.pt")
1111

1212
# Upload the model checkpoint
13-
litmodels.upload_model(
13+
litmodels.upload_model_files(
1414
"./boring-checkpoint.pt",
1515
"jirka/kaggle/boring-model",
1616
)
1717

1818
# Download the model checkpoint
19-
model_path = litmodels.download_model("jirka/kaggle/boring-model", download_dir="./my-models")
19+
model_path = litmodels.download_model_files("jirka/kaggle/boring-model", download_dir="./my-models")
2020
print(f"Model downloaded to {model_path}")
2121

2222
# Load the model checkpoint

examples/train-callback.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch.utils.data as data
22
import torchvision as tv
33
from lightning import Callback, Trainer
4-
from litmodels import upload_model
4+
from litmodels import upload_model_files
55
from sample_model import LitAutoEncoder
66

77

@@ -11,7 +11,7 @@ def on_train_epoch_end(self, trainer, pl_module):
1111
best_model_path = trainer.checkpoint_callback.best_model_path
1212
if best_model_path:
1313
print(f"Uploading model: {best_model_path}")
14-
upload_model(path=best_model_path, name="jirka/kaggle/lit-auto-encoder-callback")
14+
upload_model_files(path=best_model_path, name="jirka/kaggle/lit-auto-encoder-callback")
1515

1616

1717
if __name__ == "__main__":

examples/train-resume.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import torch.utils.data as data
22
import torchvision as tv
33
from lightning import Trainer
4-
from litmodels import download_model
4+
from litmodels import download_model_files
55
from sample_model import LitAutoEncoder
66

77
if __name__ == "__main__":
88
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
99
train, val = data.random_split(dataset, [55000, 5000])
1010

11-
model_path = download_model(name="jirka/kaggle/lit-auto-encoder-simple", download_dir="my_models")
11+
model_path = download_model_files(name="jirka/kaggle/lit-auto-encoder-simple", download_dir="my_models")
1212
print(f"model: {model_path}")
1313
# autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint_path=model_path)
1414

examples/train-simple.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torchvision as tv
33
from lightning import Trainer
44
from lightning.pytorch.callbacks import ModelCheckpoint
5-
from litmodels import upload_model
5+
from litmodels import upload_model_files
66
from sample_model import LitAutoEncoder
77

88
if __name__ == "__main__":
@@ -30,4 +30,4 @@
3030
data.DataLoader(val, batch_size=256),
3131
)
3232
print(f"last: {vars(checkpoint_callback)}")
33-
upload_model(path=checkpoint_callback.last_model_path, name="jirka/kaggle/lit-auto-encoder-simple")
33+
upload_model_files(path=checkpoint_callback.last_model_path, name="jirka/kaggle/lit-auto-encoder-simple")

src/litmodels/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@
77
_PACKAGE_ROOT = os.path.dirname(__file__)
88
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
99

10-
from litmodels.cloud_io import download_model, upload_model
10+
from litmodels.cloud_io import download_model_files, upload_model_files
1111

12-
__all__ = ["download_model", "upload_model"]
12+
__all__ = ["download_model_files", "upload_model_files"]

src/litmodels/cloud_io.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,31 @@
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# http://www.apache.org/licenses/LICENSE-2.0
44
#
5-
6-
from typing import Optional, Tuple
5+
import os
6+
import tempfile
7+
from pathlib import Path
8+
from typing import TYPE_CHECKING, Optional, Tuple, Union
79

810
from lightning_sdk.api.teamspace_api import UploadedModelInfo
911
from lightning_sdk.teamspace import Teamspace
1012
from lightning_sdk.utils import resolve as sdk_resolvers
13+
from lightning_utilities import module_available
14+
15+
if TYPE_CHECKING:
16+
from torch import nn
17+
18+
if module_available("torch"):
19+
import torch
20+
from torch import nn
21+
else:
22+
torch = None
23+
24+
# if module_available("lightning"):
25+
# from lightning import LightningModule
26+
# elif module_available("pytorch_lightning"):
27+
# from pytorch_lightning import LightningModule
28+
# else:
29+
# LightningModule = None
1130

1231

1332
def _parse_name(name: str) -> Tuple[str, str, str]:
@@ -45,6 +64,48 @@ def _get_teamspace(name: str, organization: str) -> Teamspace:
4564

4665

4766
def upload_model(
    model: Union[str, Path, "nn.Module"],
    name: str,
    progress_bar: bool = True,
    cluster_id: Optional[str] = None,
    staging_dir: Optional[str] = None,
) -> "UploadedModelInfo":
    """Upload a model to the model store.

    A path is forwarded as-is; a ``torch.nn.Module`` has its ``state_dict`` saved to
    ``<staging_dir>/<ClassName>.pth`` first and that file is uploaded.

    Args:
        model: The model to upload. Can be a path to a checkpoint file (``str`` or ``Path``)
            or a PyTorch ``nn.Module``.
        name: Name tag of the model to upload. Must be in the format 'organization/teamspace/modelname'
            where organization is either your username or the name of an organization you are part of.
        progress_bar: Whether to show a progress bar for the upload.
        cluster_id: The name of the cluster to use. Only required if it can't be determined
            automatically.
        staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will
            be created and used. Only used when ``model`` is an ``nn.Module``.

    Returns:
        Whatever ``upload_model_files`` returns for the uploaded file.

    Raises:
        ValueError: If ``model`` is neither a path nor a ``torch.nn.Module``.

    """
    # NOTE: dedicated LightningModule handling (saving a full ``.ckpt``) is not implemented yet.
    if isinstance(model, (str, Path)):
        # Paths need no staging; do not create a temporary directory for them.
        path = str(model)
    elif torch and isinstance(model, nn.Module):
        # The staging directory is resolved here, independently of the type dispatch, so that
        # omitting ``staging_dir`` no longer skips the branch that assigns ``path``.
        if not staging_dir:
            staging_dir = tempfile.mkdtemp()
        path = os.path.join(staging_dir, f"{model.__class__.__name__}.pth")
        torch.save(model.state_dict(), path)
    else:
        raise ValueError(f"Unsupported model type {type(model)}")
    return upload_model_files(
        path=path,
        name=name,
        progress_bar=progress_bar,
        cluster_id=cluster_id,
    )
106+
107+
108+
def upload_model_files(
48109
path: str,
49110
name: str,
50111
progress_bar: bool = True,
@@ -71,7 +132,7 @@ def upload_model(
71132
)
72133

73134

74-
def download_model(
135+
def download_model_files(
75136
name: str,
76137
download_dir: str = ".",
77138
progress_bar: bool = True,

tests/test_cloud_io.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,41 @@
11
from unittest import mock
22

33
import pytest
4-
from litmodels.cloud_io import download_model, upload_model
4+
from litmodels.cloud_io import download_model_files, upload_model, upload_model_files
5+
from torch.nn import Module
56

67

78
@pytest.mark.parametrize("name", ["org/model", "model-name", "/too/many/slashes"])
89
def test_wrong_model_name(name):
910
with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"):
10-
upload_model(path="path/to/checkpoint", name=name)
11+
upload_model_files(path="path/to/checkpoint", name=name)
1112
with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"):
12-
download_model(name=name)
13-
14-
15-
def test_upload_model(mocker):
13+
download_model_files(name=name)
14+
15+
16+
@pytest.mark.parametrize(
17+
"model, model_path",
18+
[
19+
("path/to/checkpoint", "path/to/checkpoint"),
20+
# (BoringModel(), "%s/BoringModel.ckpt"),
21+
(Module(), "%s/Module.pth"),
22+
],
23+
)
24+
def test_upload_model(mocker, tmpdir, model, model_path):
1625
# mocking the _get_teamspace to return another mock
1726
ts_mock = mock.MagicMock()
1827
mocker.patch("litmodels.cloud_io._get_teamspace", return_value=ts_mock)
1928

2029
# The lit-logger function is just a wrapper around the SDK function
2130
upload_model(
22-
path="path/to/checkpoint",
31+
model,
2332
name="org-name/teamspace/model-name",
2433
cluster_id="cluster_id",
34+
staging_dir=tmpdir,
2535
)
36+
expected_path = model_path % str(tmpdir) if "%" in model_path else model_path
2637
ts_mock.upload_model.assert_called_once_with(
27-
path="path/to/checkpoint",
38+
path=expected_path,
2839
name="model-name",
2940
cluster_id="cluster_id",
3041
progress_bar=True,
@@ -36,7 +47,7 @@ def test_download_model(mocker):
3647
ts_mock = mock.MagicMock()
3748
mocker.patch("litmodels.cloud_io._get_teamspace", return_value=ts_mock)
3849
# The lit-logger function is just a wrapper around the SDK function
39-
download_model(
50+
download_model_files(
4051
name="org-name/teamspace/model-name",
4152
download_dir="where/to/download",
4253
)

0 commit comments

Comments
 (0)