Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,17 @@ Save model:

```python
import torch
from litmodels import load_model, upload_model
from litmodels import save_model

model = torch.nn.Module()
upload_model(model=model, name="your_org/your_team/torch-model")
save_model(model=model, name="your_org/your_team/torch-model")
```

Load model:

```python
from litmodels import load_model

model_ = load_model(name="your_org/your_team/torch-model")
```

Expand Down Expand Up @@ -131,7 +133,7 @@ Save model:
```python
from tensorflow import keras

from litmodels import upload_model
from litmodels import save_model

# Define the model
model = keras.Sequential(
Expand All @@ -145,7 +147,7 @@ model = keras.Sequential(
model.compile(optimizer="adam", loss="categorical_crossentropy")

# Save the model
upload_model("lightning-ai/jirka/sample-tf-keras-model", model=model)
save_model("lightning-ai/jirka/sample-tf-keras-model", model=model)
```

Load model:
Expand All @@ -167,7 +169,7 @@ Save model:

```python
from sklearn import datasets, model_selection, svm
from litmodels import upload_model
from litmodels import save_model

# Load example dataset
iris = datasets.load_iris()
Expand All @@ -183,7 +185,7 @@ model = svm.SVC()
model.fit(X_train, y_train)

# Upload the saved model using litmodels
upload_model(model=model, name="your_org/your_team/sklearn-svm-model")
save_model(model=model, name="your_org/your_team/sklearn-svm-model")
```

Use model:
Expand Down
4 changes: 2 additions & 2 deletions examples/demo-tensorflow-keras.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from tensorflow import keras

from litmodels import load_model, upload_model
from litmodels import load_model, save_model

if __name__ == "__main__":
# Define the model
Expand All @@ -13,7 +13,7 @@
model.compile(optimizer="adam", loss="categorical_crossentropy")

# Save the model
upload_model("lightning-ai/jirka/sample-tf-keras-model", model=model)
save_model("lightning-ai/jirka/sample-tf-keras-model", model=model)

# Load the model
model_ = load_model("lightning-ai/jirka/sample-tf-keras-model", download_dir="./my-model")
4 changes: 2 additions & 2 deletions src/litmodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
_PACKAGE_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)

from litmodels.io import download_model, load_model, upload_model
from litmodels.io import download_model, load_model, save_model, upload_model

__all__ = ["download_model", "upload_model", "load_model"]
__all__ = ["download_model", "upload_model", "load_model", "save_model"]
6 changes: 3 additions & 3 deletions src/litmodels/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Root package for Input/output."""

from litmodels.io.cloud import download_model_files, upload_model_files
from litmodels.io.gateway import download_model, load_model, upload_model
from litmodels.io.cloud import download_model_files, upload_model_files # noqa: F401
from litmodels.io.gateway import download_model, load_model, save_model, upload_model

__all__ = ["download_model", "upload_model", "download_model_files", "upload_model_files", "load_model"]
__all__ = ["download_model", "upload_model", "load_model", "save_model"]
59 changes: 52 additions & 7 deletions src/litmodels/io/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,44 @@

def upload_model(
name: str,
model: Union[str, Path, "torch.nn.Module", Any],
model: Union[str, Path],
progress_bar: bool = True,
cloud_account: Optional[str] = None,
verbose: Union[bool, int] = 1,
metadata: Optional[Dict[str, str]] = None,
) -> "UploadedModelInfo":
"""Upload a checkpoint to the model store.

Args:
name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
where entity is either your username or the name of an organization you are part of.
model: The model to upload. Can be a path to a checkpoint file or a folder.
progress_bar: Whether to show a progress bar for the upload.
cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined
automatically.
verbose: Whether to print some additional information about the uploaded model.
metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.

"""
if not isinstance(model, (str, Path)):
raise ValueError(
"The `model` argument should be a path to a file or folder, not an python object."
" For smooth integrations with PyTorch model, Lightning model and many more, use `save_model` instead."
)

return upload_model_files(
path=model,
name=name,
progress_bar=progress_bar,
cloud_account=cloud_account,
verbose=verbose,
metadata=metadata,
)


def save_model(
name: str,
model: Union["torch.nn.Module", Any],
progress_bar: bool = True,
cloud_account: Optional[str] = None,
staging_dir: Optional[str] = None,
Expand All @@ -30,7 +67,7 @@ def upload_model(
Args:
name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
where entity is either your username or the name of an organization you are part of.
model: The model to upload. Can be a path to a checkpoint file, a PyTorch model, or a Lightning model.
model: The model to upload. Can be a PyTorch model, or a Lightning model a.
progress_bar: Whether to show a progress bar for the upload.
cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined
automatically.
Expand All @@ -40,14 +77,18 @@ def upload_model(
metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.

"""
if isinstance(model, (str, Path)):
raise ValueError(
"The `model` argument should be a PyTorch model or a Lightning model, not a path to a file."
" With file or folder path use `upload_model` instead."
)

if not staging_dir:
staging_dir = tempfile.mkdtemp()
if isinstance(model, (str, Path)):
path = model
# if LightningModule and isinstance(model, LightningModule):
# path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt")
# model.save_checkpoint(path)
elif _PYTORCH_AVAILABLE and isinstance(model, torch.jit.ScriptModule):
if _PYTORCH_AVAILABLE and isinstance(model, torch.jit.ScriptModule):
path = os.path.join(staging_dir, f"{model.__class__.__name__}.ts")
model.save(path)
elif _PYTORCH_AVAILABLE and isinstance(model, torch.nn.Module):
Expand All @@ -60,8 +101,12 @@ def upload_model(
path = os.path.join(staging_dir, f"{model.__class__.__name__}.pkl")
dump_pickle(model=model, path=path)

return upload_model_files(
path=path,
if not metadata:
metadata = {}
metadata.update({"litModels_integration": "save_model"})

return upload_model(
model=path,
name=name,
progress_bar=progress_bar,
cloud_account=cloud_account,
Expand Down
4 changes: 2 additions & 2 deletions tests/integrations/test_real_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from lightning_sdk.lightning_cloud.rest_client import GridRestClient
from lightning_sdk.utils.resolve import _resolve_teamspace

from litmodels import download_model, load_model, upload_model
from litmodels import download_model, load_model, save_model, upload_model
from litmodels.integrations.duplicate import duplicate_hf_model
from litmodels.integrations.mixins import PickleRegistryMixin, PyTorchRegistryMixin
from litmodels.io.cloud import _list_available_teamspaces
Expand Down Expand Up @@ -349,7 +349,7 @@ def test_save_load_tensorflow_keras(tmp_path):

# model name with random hash
teamspace, org_team, model_name = _prepare_variables("tf-keras")
upload_model(f"{org_team}/{model_name}", model=model)
save_model(f"{org_team}/{model_name}", model=model)

# Load the model
model_ = load_model(f"{org_team}/{model_name}", download_dir=str(tmp_path))
Expand Down
8 changes: 4 additions & 4 deletions tests/test_io_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torch.nn import Module

import litmodels
from litmodels import download_model, load_model, upload_model
from litmodels import download_model, load_model, save_model
from litmodels.io import upload_model_files
from litmodels.io.utils import _KERAS_AVAILABLE
from tests.integrations import LIT_ORG, LIT_TEAMSPACE
Expand Down Expand Up @@ -59,7 +59,7 @@ def test_download_wrong_model_name(name, in_studio, monkeypatch):
@pytest.mark.parametrize(
("model", "model_path", "verbose"),
[
("path/to/checkpoint", "path/to/checkpoint", False),
# ("path/to/checkpoint", "path/to/checkpoint", False),
# (BoringModel(), "%s/BoringModel.ckpt"),
(torch_jit.script(Module()), f"%s{os.path.sep}RecursiveScriptModule.ts", True),
(Module(), f"%s{os.path.sep}Module.pth", True),
Expand All @@ -71,7 +71,7 @@ def test_upload_model(mock_upload_model, tmp_path, model, model_path, verbose):
mock_upload_model.return_value.name = "org-name/teamspace/model-name"

# The lit-logger function is just a wrapper around the SDK function
upload_model(
save_model(
model=model,
name="org-name/teamspace/model-name",
cloud_account="cluster_id",
Expand All @@ -84,7 +84,7 @@ def test_upload_model(mock_upload_model, tmp_path, model, model_path, verbose):
name="org-name/teamspace/model-name",
cloud_account="cluster_id",
progress_bar=True,
metadata={"litModels": litmodels.__version__},
metadata={"litModels": litmodels.__version__, "litModels_integration": "save_model"},
)


Expand Down
Loading