Skip to content

Commit 4c381c6

Browse files
split upload_model to save & upload (#99)
* split `upload_model` to save & upload * noqa: F401 * mock * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0eecc43 commit 4c381c6

File tree

7 files changed

+73
-26
lines changed

7 files changed

+73
-26
lines changed

README.md

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,17 @@ Save model:
6666

6767
```python
6868
import torch
69-
from litmodels import load_model, upload_model
69+
from litmodels import save_model
7070

7171
model = torch.nn.Module()
72-
upload_model(model=model, name="your_org/your_team/torch-model")
72+
save_model(model=model, name="your_org/your_team/torch-model")
7373
```
7474

7575
Load model:
7676

7777
```python
78+
from litmodels import load_model
79+
7880
model_ = load_model(name="your_org/your_team/torch-model")
7981
```
8082

@@ -131,7 +133,7 @@ Save model:
131133
```python
132134
from tensorflow import keras
133135

134-
from litmodels import upload_model
136+
from litmodels import save_model
135137

136138
# Define the model
137139
model = keras.Sequential(
@@ -145,7 +147,7 @@ model = keras.Sequential(
145147
model.compile(optimizer="adam", loss="categorical_crossentropy")
146148

147149
# Save the model
148-
upload_model("lightning-ai/jirka/sample-tf-keras-model", model=model)
150+
save_model("lightning-ai/jirka/sample-tf-keras-model", model=model)
149151
```
150152

151153
Load model:
@@ -167,7 +169,7 @@ Save model:
167169

168170
```python
169171
from sklearn import datasets, model_selection, svm
170-
from litmodels import upload_model
172+
from litmodels import save_model
171173

172174
# Load example dataset
173175
iris = datasets.load_iris()
@@ -183,7 +185,7 @@ model = svm.SVC()
183185
model.fit(X_train, y_train)
184186

185187
# Upload the saved model using litmodels
186-
upload_model(model=model, name="your_org/your_team/sklearn-svm-model")
188+
save_model(model=model, name="your_org/your_team/sklearn-svm-model")
187189
```
188190

189191
Use model:

examples/demo-tensorflow-keras.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from tensorflow import keras
22

3-
from litmodels import load_model, upload_model
3+
from litmodels import load_model, save_model
44

55
if __name__ == "__main__":
66
# Define the model
@@ -13,7 +13,7 @@
1313
model.compile(optimizer="adam", loss="categorical_crossentropy")
1414

1515
# Save the model
16-
upload_model("lightning-ai/jirka/sample-tf-keras-model", model=model)
16+
save_model("lightning-ai/jirka/sample-tf-keras-model", model=model)
1717

1818
# Load the model
1919
model_ = load_model("lightning-ai/jirka/sample-tf-keras-model", download_dir="./my-model")

src/litmodels/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@
77
_PACKAGE_ROOT = os.path.dirname(__file__)
88
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
99

10-
from litmodels.io import download_model, load_model, upload_model
10+
from litmodels.io import download_model, load_model, save_model, upload_model
1111

12-
__all__ = ["download_model", "upload_model", "load_model"]
12+
__all__ = ["download_model", "upload_model", "load_model", "save_model"]

src/litmodels/io/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Root package for Input/output."""
22

3-
from litmodels.io.cloud import download_model_files, upload_model_files
4-
from litmodels.io.gateway import download_model, load_model, upload_model
3+
from litmodels.io.cloud import download_model_files, upload_model_files # noqa: F401
4+
from litmodels.io.gateway import download_model, load_model, save_model, upload_model
55

6-
__all__ = ["download_model", "upload_model", "download_model_files", "upload_model_files", "load_model"]
6+
__all__ = ["download_model", "upload_model", "load_model", "save_model"]

src/litmodels/io/gateway.py

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,44 @@
1818

1919
def upload_model(
2020
name: str,
21-
model: Union[str, Path, "torch.nn.Module", Any],
21+
model: Union[str, Path],
22+
progress_bar: bool = True,
23+
cloud_account: Optional[str] = None,
24+
verbose: Union[bool, int] = 1,
25+
metadata: Optional[Dict[str, str]] = None,
26+
) -> "UploadedModelInfo":
27+
"""Upload a checkpoint to the model store.
28+
29+
Args:
30+
name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
31+
where entity is either your username or the name of an organization you are part of.
32+
model: The model to upload. Can be a path to a checkpoint file or a folder.
33+
progress_bar: Whether to show a progress bar for the upload.
34+
cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined
35+
automatically.
36+
verbose: Whether to print some additional information about the uploaded model.
37+
metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.
38+
39+
"""
40+
if not isinstance(model, (str, Path)):
41+
raise ValueError(
42+
"The `model` argument should be a path to a file or folder, not an python object."
43+
" For smooth integrations with PyTorch model, Lightning model and many more, use `save_model` instead."
44+
)
45+
46+
return upload_model_files(
47+
path=model,
48+
name=name,
49+
progress_bar=progress_bar,
50+
cloud_account=cloud_account,
51+
verbose=verbose,
52+
metadata=metadata,
53+
)
54+
55+
56+
def save_model(
57+
name: str,
58+
model: Union["torch.nn.Module", Any],
2259
progress_bar: bool = True,
2360
cloud_account: Optional[str] = None,
2461
staging_dir: Optional[str] = None,
@@ -30,7 +67,7 @@ def upload_model(
3067
Args:
3168
name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
3269
where entity is either your username or the name of an organization you are part of.
33-
model: The model to upload. Can be a path to a checkpoint file, a PyTorch model, or a Lightning model.
70+
model: The model to upload. Can be a PyTorch model, or a Lightning model a.
3471
progress_bar: Whether to show a progress bar for the upload.
3572
cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined
3673
automatically.
@@ -40,14 +77,18 @@ def upload_model(
4077
metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.
4178
4279
"""
80+
if isinstance(model, (str, Path)):
81+
raise ValueError(
82+
"The `model` argument should be a PyTorch model or a Lightning model, not a path to a file."
83+
" With file or folder path use `upload_model` instead."
84+
)
85+
4386
if not staging_dir:
4487
staging_dir = tempfile.mkdtemp()
45-
if isinstance(model, (str, Path)):
46-
path = model
4788
# if LightningModule and isinstance(model, LightningModule):
4889
# path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt")
4990
# model.save_checkpoint(path)
50-
elif _PYTORCH_AVAILABLE and isinstance(model, torch.jit.ScriptModule):
91+
if _PYTORCH_AVAILABLE and isinstance(model, torch.jit.ScriptModule):
5192
path = os.path.join(staging_dir, f"{model.__class__.__name__}.ts")
5293
model.save(path)
5394
elif _PYTORCH_AVAILABLE and isinstance(model, torch.nn.Module):
@@ -60,8 +101,12 @@ def upload_model(
60101
path = os.path.join(staging_dir, f"{model.__class__.__name__}.pkl")
61102
dump_pickle(model=model, path=path)
62103

63-
return upload_model_files(
64-
path=path,
104+
if not metadata:
105+
metadata = {}
106+
metadata.update({"litModels_integration": "save_model"})
107+
108+
return upload_model(
109+
model=path,
65110
name=name,
66111
progress_bar=progress_bar,
67112
cloud_account=cloud_account,

tests/integrations/test_real_cloud.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from lightning_sdk.lightning_cloud.rest_client import GridRestClient
1111
from lightning_sdk.utils.resolve import _resolve_teamspace
1212

13-
from litmodels import download_model, load_model, upload_model
13+
from litmodels import download_model, load_model, save_model, upload_model
1414
from litmodels.integrations.duplicate import duplicate_hf_model
1515
from litmodels.integrations.mixins import PickleRegistryMixin, PyTorchRegistryMixin
1616
from litmodels.io.cloud import _list_available_teamspaces
@@ -349,7 +349,7 @@ def test_save_load_tensorflow_keras(tmp_path):
349349

350350
# model name with random hash
351351
teamspace, org_team, model_name = _prepare_variables("tf-keras")
352-
upload_model(f"{org_team}/{model_name}", model=model)
352+
save_model(f"{org_team}/{model_name}", model=model)
353353

354354
# Load the model
355355
model_ = load_model(f"{org_team}/{model_name}", download_dir=str(tmp_path))

tests/test_io_cloud.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch.nn import Module
1111

1212
import litmodels
13-
from litmodels import download_model, load_model, upload_model
13+
from litmodels import download_model, load_model, save_model
1414
from litmodels.io import upload_model_files
1515
from litmodels.io.utils import _KERAS_AVAILABLE
1616
from tests.integrations import LIT_ORG, LIT_TEAMSPACE
@@ -59,7 +59,7 @@ def test_download_wrong_model_name(name, in_studio, monkeypatch):
5959
@pytest.mark.parametrize(
6060
("model", "model_path", "verbose"),
6161
[
62-
("path/to/checkpoint", "path/to/checkpoint", False),
62+
# ("path/to/checkpoint", "path/to/checkpoint", False),
6363
# (BoringModel(), "%s/BoringModel.ckpt"),
6464
(torch_jit.script(Module()), f"%s{os.path.sep}RecursiveScriptModule.ts", True),
6565
(Module(), f"%s{os.path.sep}Module.pth", True),
@@ -71,7 +71,7 @@ def test_upload_model(mock_upload_model, tmp_path, model, model_path, verbose):
7171
mock_upload_model.return_value.name = "org-name/teamspace/model-name"
7272

7373
# The lit-logger function is just a wrapper around the SDK function
74-
upload_model(
74+
save_model(
7575
model=model,
7676
name="org-name/teamspace/model-name",
7777
cloud_account="cluster_id",
@@ -84,7 +84,7 @@ def test_upload_model(mock_upload_model, tmp_path, model, model_path, verbose):
8484
name="org-name/teamspace/model-name",
8585
cloud_account="cluster_id",
8686
progress_bar=True,
87-
metadata={"litModels": litmodels.__version__},
87+
metadata={"litModels": litmodels.__version__, "litModels_integration": "save_model"},
8888
)
8989

9090

0 commit comments

Comments
 (0)