Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/ci-checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
uses: Lightning-AI/utilities/.github/workflows/check-typing.yml@main
with:
actions-ref: main
extra-typing: "typing"

check-schema:
uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@main
Expand All @@ -33,7 +34,7 @@ jobs:
testing-matrix: |
{
"os": ["ubuntu-latest", "macos-latest", "windows-latest"],
"python-version": ["3.8", "3.10"]
"python-version": ["3.9", "3.12"]
}

check-docs:
Expand Down
1 change: 0 additions & 1 deletion _requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,5 @@ coverage >=5.0
pytest >=6.0
pytest-cov
pytest-mock
mypy ==1.13.0

pytorch-lightning >=2.0
1 change: 1 addition & 0 deletions _requirements/typing.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mypy ==1.13.0
2 changes: 1 addition & 1 deletion examples/demo-upload-download.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)

# Download the model checkpoint
model_path = litmodels.download_model("jirka/kaggle/boring-model", download_dir="./my-models")
model_path = litmodels.download_model_files("jirka/kaggle/boring-model", download_dir="./my-models")
print(f"Model downloaded to {model_path}")

# Load the model checkpoint
Expand Down
4 changes: 2 additions & 2 deletions examples/train-resume.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import torch.utils.data as data
import torchvision as tv
from lightning import Trainer
from litmodels import download_model
from litmodels import download_model_files
from sample_model import LitAutoEncoder

if __name__ == "__main__":
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
train, val = data.random_split(dataset, [55000, 5000])

model_path = download_model(name="jirka/kaggle/lit-auto-encoder-simple", download_dir="my_models")
model_path = download_model_files(name="jirka/kaggle/lit-auto-encoder-simple", download_dir="my_models")
print(f"model: {model_path}")
# autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint_path=model_path)

Expand Down
4 changes: 2 additions & 2 deletions src/litmodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
_PACKAGE_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)

from litmodels.cloud_io import download_model, upload_model
from litmodels.cloud_io import download_model_files, upload_model, upload_model_files

__all__ = ["download_model", "upload_model"]
__all__ = ["download_model_files", "upload_model", "upload_model_files"]
67 changes: 64 additions & 3 deletions src/litmodels/cloud_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,31 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
#

from typing import Optional, Tuple
import os
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Tuple, Union

from lightning_sdk.api.teamspace_api import UploadedModelInfo
from lightning_sdk.teamspace import Teamspace
from lightning_sdk.utils import resolve as sdk_resolvers
from lightning_utilities import module_available

if TYPE_CHECKING:
from torch.nn import Module

if module_available("torch"):
import torch
from torch.nn import Module
else:
torch = None

# if module_available("lightning"):
# from lightning import LightningModule
# elif module_available("pytorch_lightning"):
# from pytorch_lightning import LightningModule
# else:
# LightningModule = None


def _parse_name(name: str) -> Tuple[str, str, str]:
Expand Down Expand Up @@ -45,6 +64,48 @@ def _get_teamspace(name: str, organization: str) -> Teamspace:


def upload_model(
model: Union[str, Path, "Module"],
name: str,
progress_bar: bool = True,
cluster_id: Optional[str] = None,
staging_dir: Optional[str] = None,
) -> UploadedModelInfo:
"""Upload a local checkpoint file to the model store.

Args:
model: The model to upload. Can be a path to a checkpoint file, a PyTorch model, or a Lightning model.
name: Name tag of the model to upload. Must be in the format 'organization/teamspace/modelname'
where entity is either your username or the name of an organization you are part of.
progress_bar: Whether to show a progress bar for the upload.
cluster_id: The name of the cluster to use. Only required if it can't be determined
automatically.
staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will
be created and used.

"""
if not staging_dir:
staging_dir = tempfile.mkdtemp()
# if LightningModule and isinstance(model, LightningModule):
# path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt")
# model.save_checkpoint(path)
elif torch and isinstance(model, Module):
path = os.path.join(staging_dir, f"{model.__class__.__name__}.pth")
torch.save(model.state_dict(), path)
elif isinstance(model, str):
path = model
elif isinstance(model, Path):
path = str(model)
else:
raise ValueError(f"Unsupported model type {type(model)}")
return upload_model_files(
path=path,
name=name,
progress_bar=progress_bar,
cluster_id=cluster_id,
)


def upload_model_files(
path: str,
name: str,
progress_bar: bool = True,
Expand All @@ -71,7 +132,7 @@ def upload_model(
)


def download_model(
def download_model_files(
name: str,
download_dir: str = ".",
progress_bar: bool = True,
Expand Down
30 changes: 21 additions & 9 deletions tests/test_cloud_io.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,42 @@
import os
from unittest import mock

import pytest
from litmodels.cloud_io import download_model, upload_model
from litmodels.cloud_io import download_model_files, upload_model, upload_model_files
from torch.nn import Module


@pytest.mark.parametrize("name", ["org/model", "model-name", "/too/many/slashes"])
def test_wrong_model_name(name):
with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"):
upload_model(path="path/to/checkpoint", name=name)
upload_model_files(path="path/to/checkpoint", name=name)
with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"):
download_model(name=name)


def test_upload_model(mocker):
download_model_files(name=name)


@pytest.mark.parametrize(
("model", "model_path"),
[
("path/to/checkpoint", "path/to/checkpoint"),
# (BoringModel(), "%s/BoringModel.ckpt"),
(Module(), f"%s{os.path.sep}Module.pth"),
],
)
def test_upload_model(mocker, tmpdir, model, model_path):
# mocking the _get_teamspace to return another mock
ts_mock = mock.MagicMock()
mocker.patch("litmodels.cloud_io._get_teamspace", return_value=ts_mock)

# The lit-logger function is just a wrapper around the SDK function
upload_model(
path="path/to/checkpoint",
model,
name="org-name/teamspace/model-name",
cluster_id="cluster_id",
staging_dir=tmpdir,
)
expected_path = model_path % str(tmpdir) if "%" in model_path else model_path
ts_mock.upload_model.assert_called_once_with(
path="path/to/checkpoint",
path=expected_path,
name="model-name",
cluster_id="cluster_id",
progress_bar=True,
Expand All @@ -36,7 +48,7 @@ def test_download_model(mocker):
ts_mock = mock.MagicMock()
mocker.patch("litmodels.cloud_io._get_teamspace", return_value=ts_mock)
# The lit-logger function is just a wrapper around the SDK function
download_model(
download_model_files(
name="org-name/teamspace/model-name",
download_dir="where/to/download",
)
Expand Down
Loading