diff --git a/.github/workflows/ci-checks.yml b/.github/workflows/ci-checks.yml index 6e965e4..3839d06 100644 --- a/.github/workflows/ci-checks.yml +++ b/.github/workflows/ci-checks.yml @@ -18,6 +18,7 @@ jobs: uses: Lightning-AI/utilities/.github/workflows/check-typing.yml@main with: actions-ref: main + extra-typing: "typing" check-schema: uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@main @@ -33,7 +34,7 @@ jobs: testing-matrix: | { "os": ["ubuntu-latest", "macos-latest", "windows-latest"], - "python-version": ["3.8", "3.10"] + "python-version": ["3.9", "3.12"] } check-docs: diff --git a/_requirements/test.txt b/_requirements/test.txt index 4b56175..7dc51ac 100644 --- a/_requirements/test.txt +++ b/_requirements/test.txt @@ -2,6 +2,5 @@ coverage >=5.0 pytest >=6.0 pytest-cov pytest-mock -mypy ==1.13.0 pytorch-lightning >=2.0 diff --git a/_requirements/typing.txt b/_requirements/typing.txt new file mode 100644 index 0000000..04a9506 --- /dev/null +++ b/_requirements/typing.txt @@ -0,0 +1 @@ +mypy ==1.13.0 diff --git a/src/litmodels/__init__.py b/src/litmodels/__init__.py index 8e7a118..9584ebd 100644 --- a/src/litmodels/__init__.py +++ b/src/litmodels/__init__.py @@ -7,6 +7,6 @@ _PACKAGE_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) -from litmodels.cloud_io import download_model, upload_model +from litmodels.cloud_io import download_model, upload_model, upload_model_files -__all__ = ["download_model", "upload_model"] +__all__ = ["download_model", "upload_model", "upload_model_files"] diff --git a/src/litmodels/cloud_io.py b/src/litmodels/cloud_io.py index 2dad671..12c691b 100644 --- a/src/litmodels/cloud_io.py +++ b/src/litmodels/cloud_io.py @@ -2,12 +2,31 @@ # Licensed under the Apache License, Version 2.0 (the "License"); # http://www.apache.org/licenses/LICENSE-2.0 # - -from typing import Optional, Tuple +import os +import tempfile +from pathlib import Path +from typing import TYPE_CHECKING, Optional, Tuple, Union from lightning_sdk.api.teamspace_api import UploadedModelInfo from lightning_sdk.teamspace import Teamspace from lightning_sdk.utils import resolve as sdk_resolvers +from lightning_utilities import module_available + +if TYPE_CHECKING: + from torch.nn import Module + +if module_available("torch"): + import torch + from torch.nn import Module +else: + torch = None + +# if module_available("lightning"): +# from lightning import LightningModule +# elif module_available("pytorch_lightning"): +# from pytorch_lightning import LightningModule +# else: +# LightningModule = None def _parse_name(name: str) -> Tuple[str, str, str]: @@ -45,6 +64,48 @@ def _get_teamspace(name: str, organization: str) -> Teamspace: def upload_model( + model: Union[str, Path, "Module"], + name: str, + progress_bar: bool = True, + cluster_id: Optional[str] = None, + staging_dir: Optional[str] = None, +) -> UploadedModelInfo: + """Upload a checkpoint to the model store. + + Args: + model: The model to upload. Can be a path to a checkpoint file, a PyTorch model, or a Lightning model. + name: Name tag of the model to upload. Must be in the format 'organization/teamspace/modelname' + where entity is either your username or the name of an organization you are part of. + progress_bar: Whether to show a progress bar for the upload. + cluster_id: The name of the cluster to use. Only required if it can't be determined + automatically. + staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will + be created and used. + + """ + if not staging_dir: + staging_dir = tempfile.mkdtemp() + # if LightningModule and isinstance(model, LightningModule): + # path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt") + # model.save_checkpoint(path) + if torch and isinstance(model, Module): + path = os.path.join(staging_dir, f"{model.__class__.__name__}.pth") + torch.save(model.state_dict(), path) + elif isinstance(model, str): + path = model + elif isinstance(model, Path): + path = str(model) + else: + raise ValueError(f"Unsupported model type {type(model)}") + return upload_model_files( + path=path, + name=name, + progress_bar=progress_bar, + cluster_id=cluster_id, + ) + + +def upload_model_files( path: str, name: str, progress_bar: bool = True, diff --git a/tests/test_cloud_io.py b/tests/test_cloud_io.py index ee5c9f2..f4c595e 100644 --- a/tests/test_cloud_io.py +++ b/tests/test_cloud_io.py @@ -1,30 +1,42 @@ +import os from unittest import mock import pytest -from litmodels.cloud_io import download_model, upload_model +from litmodels.cloud_io import download_model, upload_model, upload_model_files +from torch.nn import Module @pytest.mark.parametrize("name", ["org/model", "model-name", "/too/many/slashes"]) def test_wrong_model_name(name): with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"): - upload_model(path="path/to/checkpoint", name=name) + upload_model_files(path="path/to/checkpoint", name=name) with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"): download_model(name=name) -def test_upload_model(mocker): +@pytest.mark.parametrize( + ("model", "model_path"), + [ + ("path/to/checkpoint", "path/to/checkpoint"), + # (BoringModel(), "%s/BoringModel.ckpt"), + (Module(), f"%s{os.path.sep}Module.pth"), + ], +) +def test_upload_model(mocker, tmpdir, model, model_path): # mocking the _get_teamspace to return another mock ts_mock = mock.MagicMock() mocker.patch("litmodels.cloud_io._get_teamspace", return_value=ts_mock) # The lit-logger function is just a wrapper around the SDK function upload_model( - path="path/to/checkpoint", + model, name="org-name/teamspace/model-name", cluster_id="cluster_id", + staging_dir=tmpdir, ) + expected_path = model_path % str(tmpdir) if "%" in model_path else model_path ts_mock.upload_model.assert_called_once_with( - path="path/to/checkpoint", + path=expected_path, name="model-name", cluster_id="cluster_id", progress_bar=True,