diff --git a/README.md b/README.md index ac30544..583be10 100644 --- a/README.md +++ b/README.md @@ -66,15 +66,17 @@ Save model: ```python import torch -from litmodels import load_model, upload_model +from litmodels import save_model model = torch.nn.Module() -upload_model(model=model, name="your_org/your_team/torch-model") +save_model(model=model, name="your_org/your_team/torch-model") ``` Load model: ```python +from litmodels import load_model + model_ = load_model(name="your_org/your_team/torch-model") ``` @@ -131,7 +133,7 @@ Save model: ```python from tensorflow import keras -from litmodels import upload_model +from litmodels import save_model # Define the model model = keras.Sequential( @@ -145,7 +147,7 @@ model = keras.Sequential( model.compile(optimizer="adam", loss="categorical_crossentropy") # Save the model -upload_model("lightning-ai/jirka/sample-tf-keras-model", model=model) +save_model("lightning-ai/jirka/sample-tf-keras-model", model=model) ``` Load model: @@ -167,7 +169,7 @@ Save model: ```python from sklearn import datasets, model_selection, svm -from litmodels import upload_model +from litmodels import save_model # Load example dataset iris = datasets.load_iris() @@ -183,7 +185,7 @@ model = svm.SVC() model.fit(X_train, y_train) # Upload the saved model using litmodels -upload_model(model=model, name="your_org/your_team/sklearn-svm-model") +save_model(model=model, name="your_org/your_team/sklearn-svm-model") ``` Use model: diff --git a/examples/demo-tensorflow-keras.py b/examples/demo-tensorflow-keras.py index 20f3419..a472fb6 100644 --- a/examples/demo-tensorflow-keras.py +++ b/examples/demo-tensorflow-keras.py @@ -1,6 +1,6 @@ from tensorflow import keras -from litmodels import load_model, upload_model +from litmodels import load_model, save_model if __name__ == "__main__": # Define the model @@ -13,7 +13,7 @@ model.compile(optimizer="adam", loss="categorical_crossentropy") # Save the model - upload_model("lightning-ai/jirka/sample-tf-keras-model", model=model) + save_model("lightning-ai/jirka/sample-tf-keras-model", model=model) # Load the model model_ = load_model("lightning-ai/jirka/sample-tf-keras-model", download_dir="./my-model") diff --git a/src/litmodels/__init__.py b/src/litmodels/__init__.py index 3f83645..e4f9786 100644 --- a/src/litmodels/__init__.py +++ b/src/litmodels/__init__.py @@ -7,6 +7,6 @@ _PACKAGE_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) -from litmodels.io import download_model, load_model, upload_model +from litmodels.io import download_model, load_model, save_model, upload_model -__all__ = ["download_model", "upload_model", "load_model"] +__all__ = ["download_model", "upload_model", "load_model", "save_model"] diff --git a/src/litmodels/io/__init__.py b/src/litmodels/io/__init__.py index 9a0975d..5f9a793 100644 --- a/src/litmodels/io/__init__.py +++ b/src/litmodels/io/__init__.py @@ -1,6 +1,6 @@ """Root package for Input/output.""" -from litmodels.io.cloud import download_model_files, upload_model_files -from litmodels.io.gateway import download_model, load_model, upload_model +from litmodels.io.cloud import download_model_files, upload_model_files # noqa: F401 +from litmodels.io.gateway import download_model, load_model, save_model, upload_model -__all__ = ["download_model", "upload_model", "download_model_files", "upload_model_files", "load_model"] +__all__ = ["download_model", "upload_model", "load_model", "save_model"] diff --git a/src/litmodels/io/gateway.py b/src/litmodels/io/gateway.py index 49f18b5..ea909eb 100644 --- a/src/litmodels/io/gateway.py +++ b/src/litmodels/io/gateway.py @@ -18,7 +18,44 @@ def upload_model( name: str, - model: Union[str, Path, "torch.nn.Module", Any], + model: Union[str, Path], + progress_bar: bool = True, + cloud_account: Optional[str] = None, + verbose: Union[bool, int] = 1, + metadata: Optional[Dict[str, str]] = None, +) -> "UploadedModelInfo": + """Upload a checkpoint to the model store. + + Args: + name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname' + where entity is either your username or the name of an organization you are part of. + model: The model to upload. Can be a path to a checkpoint file or a folder. + progress_bar: Whether to show a progress bar for the upload. + cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined + automatically. + verbose: Whether to print some additional information about the uploaded model. + metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used. + + """ + if not isinstance(model, (str, Path)): + raise ValueError( + "The `model` argument should be a path to a file or folder, not an python object." + " For smooth integrations with PyTorch model, Lightning model and many more, use `save_model` instead." + ) + + return upload_model_files( + path=model, + name=name, + progress_bar=progress_bar, + cloud_account=cloud_account, + verbose=verbose, + metadata=metadata, + ) + + +def save_model( + name: str, + model: Union["torch.nn.Module", Any], progress_bar: bool = True, cloud_account: Optional[str] = None, staging_dir: Optional[str] = None, @@ -30,7 +67,7 @@ def upload_model( Args: name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname' where entity is either your username or the name of an organization you are part of. - model: The model to upload. Can be a path to a checkpoint file, a PyTorch model, or a Lightning model. + model: The model to upload. Can be a PyTorch model, or a Lightning model a. progress_bar: Whether to show a progress bar for the upload. cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined automatically. @@ -40,14 +77,18 @@ def upload_model( metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used. """ + if isinstance(model, (str, Path)): + raise ValueError( + "The `model` argument should be a PyTorch model or a Lightning model, not a path to a file." + " With file or folder path use `upload_model` instead." + ) + if not staging_dir: staging_dir = tempfile.mkdtemp() - if isinstance(model, (str, Path)): - path = model # if LightningModule and isinstance(model, LightningModule): # path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt") # model.save_checkpoint(path) - elif _PYTORCH_AVAILABLE and isinstance(model, torch.jit.ScriptModule): + if _PYTORCH_AVAILABLE and isinstance(model, torch.jit.ScriptModule): path = os.path.join(staging_dir, f"{model.__class__.__name__}.ts") model.save(path) elif _PYTORCH_AVAILABLE and isinstance(model, torch.nn.Module): @@ -60,8 +101,12 @@ def upload_model( path = os.path.join(staging_dir, f"{model.__class__.__name__}.pkl") dump_pickle(model=model, path=path) - return upload_model_files( - path=path, + if not metadata: + metadata = {} + metadata.update({"litModels_integration": "save_model"}) + + return upload_model( + model=path, name=name, progress_bar=progress_bar, cloud_account=cloud_account, diff --git a/tests/integrations/test_real_cloud.py b/tests/integrations/test_real_cloud.py index 93c7577..fd9a113 100644 --- a/tests/integrations/test_real_cloud.py +++ b/tests/integrations/test_real_cloud.py @@ -10,7 +10,7 @@ from lightning_sdk.lightning_cloud.rest_client import GridRestClient from lightning_sdk.utils.resolve import _resolve_teamspace -from litmodels import download_model, load_model, upload_model +from litmodels import download_model, load_model, save_model, upload_model from litmodels.integrations.duplicate import duplicate_hf_model from litmodels.integrations.mixins import PickleRegistryMixin, PyTorchRegistryMixin from litmodels.io.cloud import _list_available_teamspaces @@ -349,7 +349,7 @@ def test_save_load_tensorflow_keras(tmp_path): # model name with random hash teamspace, org_team, model_name = _prepare_variables("tf-keras") - upload_model(f"{org_team}/{model_name}", model=model) + save_model(f"{org_team}/{model_name}", model=model) # Load the model model_ = load_model(f"{org_team}/{model_name}", download_dir=str(tmp_path)) diff --git a/tests/test_io_cloud.py b/tests/test_io_cloud.py index 58da01c..5222641 100644 --- a/tests/test_io_cloud.py +++ b/tests/test_io_cloud.py @@ -10,7 +10,7 @@ from torch.nn import Module import litmodels -from litmodels import download_model, load_model, upload_model +from litmodels import download_model, load_model, save_model from litmodels.io import upload_model_files from litmodels.io.utils import _KERAS_AVAILABLE from tests.integrations import LIT_ORG, LIT_TEAMSPACE @@ -59,7 +59,7 @@ def test_download_wrong_model_name(name, in_studio, monkeypatch): @pytest.mark.parametrize( ("model", "model_path", "verbose"), [ - ("path/to/checkpoint", "path/to/checkpoint", False), + # ("path/to/checkpoint", "path/to/checkpoint", False), # (BoringModel(), "%s/BoringModel.ckpt"), (torch_jit.script(Module()), f"%s{os.path.sep}RecursiveScriptModule.ts", True), (Module(), f"%s{os.path.sep}Module.pth", True), @@ -71,7 +71,7 @@ def test_upload_model(mock_upload_model, tmp_path, model, model_path, verbose): mock_upload_model.return_value.name = "org-name/teamspace/model-name" # The lit-logger function is just a wrapper around the SDK function - upload_model( + save_model( model=model, name="org-name/teamspace/model-name", cloud_account="cluster_id", @@ -84,7 +84,7 @@ def test_upload_model(mock_upload_model, tmp_path, model, model_path, verbose): name="org-name/teamspace/model-name", cloud_account="cluster_id", progress_bar=True, - metadata={"litModels": litmodels.__version__}, + metadata={"litModels": litmodels.__version__, "litModels_integration": "save_model"}, )