Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/ci-checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
uses: Lightning-AI/utilities/.github/workflows/check-typing.yml@main
with:
actions-ref: main
extra-typing: "typing"

check-schema:
uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@main
Expand All @@ -33,7 +34,7 @@ jobs:
testing-matrix: |
{
"os": ["ubuntu-latest", "macos-latest", "windows-latest"],
"python-version": ["3.8", "3.10"]
"python-version": ["3.9", "3.12"]
}

check-docs:
Expand Down
1 change: 0 additions & 1 deletion _requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,5 @@ coverage >=5.0
pytest >=6.0
pytest-cov
pytest-mock
mypy ==1.13.0

pytorch-lightning >=2.0
1 change: 1 addition & 0 deletions _requirements/typing.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mypy ==1.13.0
4 changes: 2 additions & 2 deletions src/litmodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
_PACKAGE_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)

from litmodels.cloud_io import download_model, upload_model
from litmodels.cloud_io import download_model, upload_model, upload_model_files

__all__ = ["download_model", "upload_model"]
__all__ = ["download_model", "upload_model", "upload_model_files"]
65 changes: 63 additions & 2 deletions src/litmodels/cloud_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,31 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
#

from typing import Optional, Tuple
import os
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Tuple, Union

from lightning_sdk.api.teamspace_api import UploadedModelInfo
from lightning_sdk.teamspace import Teamspace
from lightning_sdk.utils import resolve as sdk_resolvers
from lightning_utilities import module_available

if TYPE_CHECKING:
from torch.nn import Module

if module_available("torch"):
import torch
from torch.nn import Module
else:
torch = None

# if module_available("lightning"):
# from lightning import LightningModule
# elif module_available("pytorch_lightning"):
# from pytorch_lightning import LightningModule
# else:
# LightningModule = None


def _parse_name(name: str) -> Tuple[str, str, str]:
Expand Down Expand Up @@ -45,6 +64,48 @@ def _get_teamspace(name: str, organization: str) -> Teamspace:


def upload_model(
    model: Union[str, Path, "Module"],
    name: str,
    progress_bar: bool = True,
    cluster_id: Optional[str] = None,
    staging_dir: Optional[str] = None,
) -> UploadedModelInfo:
    """Upload a checkpoint file or an in-memory PyTorch model to the model store.

    Args:
        model: The model to upload. Either a path to a checkpoint file (``str`` or ``Path``) or a
            ``torch.nn.Module``, whose ``state_dict`` is saved into ``staging_dir`` before uploading.
        name: Name tag of the model to upload. Must be in the format 'organization/teamspace/modelname'
            where organization is either your username or the name of an organization you are part of.
        progress_bar: Whether to show a progress bar for the upload.
        cluster_id: The name of the cluster to use. Only required if it can't be determined
            automatically.
        staging_dir: A directory where an in-memory model is saved before uploading. Only used when
            ``model`` is a module; if not provided, a temporary directory is created and used.

    Returns:
        The value returned by ``upload_model_files`` for the resolved checkpoint path.

    Raises:
        ValueError: If ``model`` is neither a path nor a ``torch.nn.Module``.

    """
    if isinstance(model, (str, Path)):
        # The checkpoint already lives on disk, so no staging directory is created for it.
        path = str(model)
    elif torch and isinstance(model, Module):
        # `torch` is None when PyTorch is not installed; the short-circuit keeps `Module` from
        # being evaluated in that case.
        # NOTE: Lightning-specific checkpointing (`save_checkpoint` to a `.ckpt` file) is not enabled;
        # modules are stored through their `state_dict` only.
        if not staging_dir:
            staging_dir = tempfile.mkdtemp()
        path = os.path.join(staging_dir, f"{model.__class__.__name__}.pth")
        torch.save(model.state_dict(), path)
    else:
        raise ValueError(f"Unsupported model type {type(model)}")
    return upload_model_files(
        path=path,
        name=name,
        progress_bar=progress_bar,
        cluster_id=cluster_id,
    )


def upload_model_files(
path: str,
name: str,
progress_bar: bool = True,
Expand Down
22 changes: 17 additions & 5 deletions tests/test_cloud_io.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,42 @@
import os
from unittest import mock

import pytest
from litmodels.cloud_io import download_model, upload_model
from litmodels.cloud_io import download_model, upload_model, upload_model_files
from torch.nn import Module


@pytest.mark.parametrize("name", ["org/model", "model-name", "/too/many/slashes"])
def test_wrong_model_name(name):
    """Check that malformed model names raise ``ValueError`` on both upload and download."""
    # `upload_model` takes the model as its first argument and has no `path` keyword, so the
    # path-based entry point `upload_model_files` is the one exercised here.
    with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"):
        upload_model_files(path="path/to/checkpoint", name=name)
    with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"):
        download_model(name=name)


def test_upload_model(mocker):
@pytest.mark.parametrize(
("model", "model_path"),
[
("path/to/checkpoint", "path/to/checkpoint"),
# (BoringModel(), "%s/BoringModel.ckpt"),
(Module(), f"%s{os.path.sep}Module.pth"),
],
)
def test_upload_model(mocker, tmpdir, model, model_path):
# mocking the _get_teamspace to return another mock
ts_mock = mock.MagicMock()
mocker.patch("litmodels.cloud_io._get_teamspace", return_value=ts_mock)

# The lit-logger function is just a wrapper around the SDK function
upload_model(
path="path/to/checkpoint",
model,
name="org-name/teamspace/model-name",
cluster_id="cluster_id",
staging_dir=tmpdir,
)
expected_path = model_path % str(tmpdir) if "%" in model_path else model_path
ts_mock.upload_model.assert_called_once_with(
path="path/to/checkpoint",
path=expected_path,
name="model-name",
cluster_id="cluster_id",
progress_bar=True,
Expand Down