3 changes: 2 additions & 1 deletion .github/workflows/ci-checks.yml
@@ -18,6 +18,7 @@ jobs:
    uses: Lightning-AI/utilities/.github/workflows/check-typing.yml@main
    with:
      actions-ref: main
      extra-typing: "typing"

  check-schema:
    uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@main
@@ -33,7 +34,7 @@ jobs:
      testing-matrix: |
        {
          "os": ["ubuntu-latest", "macos-latest", "windows-latest"],
          "python-version": ["3.8", "3.10"]
          "python-version": ["3.9", "3.12"]
        }

  check-docs:
1 change: 0 additions & 1 deletion _requirements/test.txt
@@ -2,6 +2,5 @@ coverage >=5.0
pytest >=6.0
pytest-cov
pytest-mock
mypy ==1.13.0

pytorch-lightning >=2.0
1 change: 1 addition & 0 deletions _requirements/typing.txt
@@ -0,0 +1 @@
mypy ==1.13.0
4 changes: 2 additions & 2 deletions src/litmodels/__init__.py
@@ -7,6 +7,6 @@
_PACKAGE_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)

from litmodels.cloud_io import download_model, upload_model
from litmodels.cloud_io import download_model, upload_model, upload_model_files

__all__ = ["download_model", "upload_model"]
__all__ = ["download_model", "upload_model", "upload_model_files"]
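With this change the new helper becomes part of the package's public interface. A minimal import sketch, assuming litmodels is installed:

from litmodels import download_model, upload_model, upload_model_files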
65 changes: 63 additions & 2 deletions src/litmodels/cloud_io.py
@@ -2,12 +2,31 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
#

from typing import Optional, Tuple
import os
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Tuple, Union

from lightning_sdk.api.teamspace_api import UploadedModelInfo
from lightning_sdk.teamspace import Teamspace
from lightning_sdk.utils import resolve as sdk_resolvers
from lightning_utilities import module_available

if TYPE_CHECKING:
    from torch.nn import Module

if module_available("torch"):
    import torch
    from torch.nn import Module
else:
    torch = None

# if module_available("lightning"):
#     from lightning import LightningModule
# elif module_available("pytorch_lightning"):
#     from pytorch_lightning import LightningModule
# else:
#     LightningModule = None


def _parse_name(name: str) -> Tuple[str, str, str]:
@@ -45,6 +64,48 @@ def _get_teamspace(name: str, organization: str) -> Teamspace:


def upload_model(
    model: Union[str, Path, "Module"],
    name: str,
    progress_bar: bool = True,
    cluster_id: Optional[str] = None,
    staging_dir: Optional[str] = None,
) -> UploadedModelInfo:
"""Upload a checkpoint to the model store.

Args:
model: The model to upload. Can be a path to a checkpoint file, a PyTorch model, or a Lightning model.
name: Name tag of the model to upload. Must be in the format 'organization/teamspace/modelname'
where entity is either your username or the name of an organization you are part of.
progress_bar: Whether to show a progress bar for the upload.
cluster_id: The name of the cluster to use. Only required if it can't be determined
automatically.
staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will
be created and used.

"""
    if not staging_dir:
        staging_dir = tempfile.mkdtemp()
    # if LightningModule and isinstance(model, LightningModule):
    #     path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt")
    #     model.save_checkpoint(path)
    if torch and isinstance(model, Module):
        path = os.path.join(staging_dir, f"{model.__class__.__name__}.pth")
        torch.save(model.state_dict(), path)
    elif isinstance(model, str):
        path = model
    elif isinstance(model, Path):
        path = str(model)
    else:
        raise ValueError(f"Unsupported model type {type(model)}")
    return upload_model_files(
        path=path,
        name=name,
        progress_bar=progress_bar,
        cluster_id=cluster_id,
    )


def upload_model_files(
    path: str,
    name: str,
    progress_bar: bool = True,
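For orientation, a minimal usage sketch of the API added above. It assumes litmodels and torch are installed and that you are authenticated against a teamspace; the teamspace and model names are hypothetical placeholders, not values taken from this change.

from torch.nn import Linear

from litmodels.cloud_io import download_model, upload_model, upload_model_files

# Upload an in-memory torch module; upload_model serializes its state_dict
# to a staging directory as <ClassName>.pth before uploading.
upload_model(Linear(4, 2), name="my-org/my-teamspace/demo-model")

# Upload an existing checkpoint file directly, without the serialization step.
upload_model_files(path="checkpoints/last.ckpt", name="my-org/my-teamspace/demo-model")

# Fetch the stored model artifacts back from the model store.
download_model(name="my-org/my-teamspace/demo-model")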
22 changes: 17 additions & 5 deletions tests/test_cloud_io.py
@@ -1,30 +1,42 @@
import os
from unittest import mock

import pytest
from litmodels.cloud_io import download_model, upload_model
from litmodels.cloud_io import download_model, upload_model, upload_model_files
from torch.nn import Module


@pytest.mark.parametrize("name", ["org/model", "model-name", "/too/many/slashes"])
def test_wrong_model_name(name):
    with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"):
        upload_model(path="path/to/checkpoint", name=name)
        upload_model_files(path="path/to/checkpoint", name=name)
    with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"):
        download_model(name=name)


def test_upload_model(mocker):
@pytest.mark.parametrize(
("model", "model_path"),
[
("path/to/checkpoint", "path/to/checkpoint"),
# (BoringModel(), "%s/BoringModel.ckpt"),
(Module(), f"%s{os.path.sep}Module.pth"),
],
)
def test_upload_model(mocker, tmpdir, model, model_path):
    # mocking the _get_teamspace to return another mock
    ts_mock = mock.MagicMock()
    mocker.patch("litmodels.cloud_io._get_teamspace", return_value=ts_mock)

    # The lit-logger function is just a wrapper around the SDK function
    upload_model(
        path="path/to/checkpoint",
        model,
        name="org-name/teamspace/model-name",
        cluster_id="cluster_id",
        staging_dir=tmpdir,
    )
    expected_path = model_path % str(tmpdir) if "%" in model_path else model_path
    ts_mock.upload_model.assert_called_once_with(
        path="path/to/checkpoint",
        path=expected_path,
        name="model-name",
        cluster_id="cluster_id",
        progress_bar=True,