Skip to content

Commit 0eba032

Browse files
Borda and ethanwharris
authored
feat: enable upload Torch's nn.Module (#14)
Co-authored-by: Ethan Harris <[email protected]>
1 parent 4110d2b commit 0eba032

File tree

6 files changed

+85
-11
lines changed

6 files changed

+85
-11
lines changed

.github/workflows/ci-checks.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ jobs:
1818
uses: Lightning-AI/utilities/.github/workflows/check-typing.yml@main
1919
with:
2020
actions-ref: main
21+
extra-typing: "typing"
2122

2223
check-schema:
2324
uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@main
@@ -33,7 +34,7 @@ jobs:
3334
testing-matrix: |
3435
{
3536
"os": ["ubuntu-latest", "macos-latest", "windows-latest"],
36-
"python-version": ["3.8", "3.10"]
37+
"python-version": ["3.9", "3.12"]
3738
}
3839
3940
check-docs:

_requirements/test.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,5 @@ coverage >=5.0
22
pytest >=6.0
33
pytest-cov
44
pytest-mock
5-
mypy ==1.13.0
65

76
pytorch-lightning >=2.0

_requirements/typing.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
mypy ==1.13.0

src/litmodels/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@
77
_PACKAGE_ROOT = os.path.dirname(__file__)
88
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
99

10-
from litmodels.cloud_io import download_model, upload_model
10+
from litmodels.cloud_io import download_model, upload_model, upload_model_files
1111

12-
__all__ = ["download_model", "upload_model"]
12+
__all__ = ["download_model", "upload_model", "upload_model_files"]

src/litmodels/cloud_io.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,31 @@
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# http://www.apache.org/licenses/LICENSE-2.0
44
#
5-
6-
from typing import Optional, Tuple
5+
import os
6+
import tempfile
7+
from pathlib import Path
8+
from typing import TYPE_CHECKING, Optional, Tuple, Union
79

810
from lightning_sdk.api.teamspace_api import UploadedModelInfo
911
from lightning_sdk.teamspace import Teamspace
1012
from lightning_sdk.utils import resolve as sdk_resolvers
13+
from lightning_utilities import module_available
14+
15+
if TYPE_CHECKING:
16+
from torch.nn import Module
17+
18+
if module_available("torch"):
19+
import torch
20+
from torch.nn import Module
21+
else:
22+
torch = None
23+
24+
# if module_available("lightning"):
25+
# from lightning import LightningModule
26+
# elif module_available("pytorch_lightning"):
27+
# from pytorch_lightning import LightningModule
28+
# else:
29+
# LightningModule = None
1130

1231

1332
def _parse_name(name: str) -> Tuple[str, str, str]:
@@ -45,6 +64,48 @@ def _get_teamspace(name: str, organization: str) -> Teamspace:
4564

4665

4766
def upload_model(
    model: Union[str, Path, "Module"],
    name: str,
    progress_bar: bool = True,
    cluster_id: Optional[str] = None,
    staging_dir: Optional[str] = None,
) -> UploadedModelInfo:
    """Upload a checkpoint file or an in-memory PyTorch model to the model store.

    Args:
        model: The model to upload. Either a path (``str`` or ``Path``) to an existing checkpoint file,
            or a ``torch.nn.Module`` instance. A module is serialized with ``torch.save`` as its
            ``state_dict`` into ``<ClassName>.pth`` before being uploaded.
        name: Name tag of the model to upload. Must be in the format 'organization/teamspace/modelname'
            where organization is either your username or the name of an organization you are part of.
        progress_bar: Whether to show a progress bar for the upload.
        cluster_id: The name of the cluster to use. Only required if it can't be determined
            automatically.
        staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will
            be created, used for the upload and removed afterwards. A caller-provided directory is never deleted.

    Returns:
        The value returned by ``upload_model_files`` for the uploaded file.

    Raises:
        ValueError: If ``model`` is neither a path nor a ``torch.nn.Module`` (or torch is not installed).

    """
    # Existing files need no staging: forward the path directly so that no
    # temporary directory is created (and leaked) for this case.
    if isinstance(model, (str, Path)):
        return upload_model_files(
            path=str(model),
            name=name,
            progress_bar=progress_bar,
            cluster_id=cluster_id,
        )

    # `torch` is None when the package is missing; the short-circuit keeps the
    # then-undefined `Module` name from being evaluated.
    if not (torch and isinstance(model, Module)):
        raise ValueError(f"Unsupported model type {type(model)}")

    def _stage_and_upload(directory: Union[str, "os.PathLike[str]"]) -> UploadedModelInfo:
        # Serialize only the weights, named after the model class.
        path = os.path.join(directory, f"{model.__class__.__name__}.pth")
        torch.save(model.state_dict(), path)
        return upload_model_files(
            path=path,
            name=name,
            progress_bar=progress_bar,
            cluster_id=cluster_id,
        )

    if staging_dir:
        return _stage_and_upload(staging_dir)

    # No staging directory given: use a self-cleaning temporary one.
    # NOTE(review): assumes the upload has finished by the time
    # `upload_model_files` returns - confirm against the SDK.
    with tempfile.TemporaryDirectory() as tmp_dir:
        return _stage_and_upload(tmp_dir)
106+
107+
108+
def upload_model_files(
48109
path: str,
49110
name: str,
50111
progress_bar: bool = True,

tests/test_cloud_io.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,42 @@
1+
import os
12
from unittest import mock
23

34
import pytest
4-
from litmodels.cloud_io import download_model, upload_model
5+
from litmodels.cloud_io import download_model, upload_model, upload_model_files
6+
from torch.nn import Module
57

68

79
@pytest.mark.parametrize("name", ["org/model", "model-name", "/too/many/slashes"])
810
def test_wrong_model_name(name):
911
with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"):
10-
upload_model(path="path/to/checkpoint", name=name)
12+
upload_model_files(path="path/to/checkpoint", name=name)
1113
with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"):
1214
download_model(name=name)
1315

1416

15-
def test_upload_model(mocker):
17+
@pytest.mark.parametrize(
18+
("model", "model_path"),
19+
[
20+
("path/to/checkpoint", "path/to/checkpoint"),
21+
# (BoringModel(), "%s/BoringModel.ckpt"),
22+
(Module(), f"%s{os.path.sep}Module.pth"),
23+
],
24+
)
25+
def test_upload_model(mocker, tmpdir, model, model_path):
1626
# mocking the _get_teamspace to return another mock
1727
ts_mock = mock.MagicMock()
1828
mocker.patch("litmodels.cloud_io._get_teamspace", return_value=ts_mock)
1929

2030
# The lit-logger function is just a wrapper around the SDK function
2131
upload_model(
22-
path="path/to/checkpoint",
32+
model,
2333
name="org-name/teamspace/model-name",
2434
cluster_id="cluster_id",
35+
staging_dir=tmpdir,
2536
)
37+
expected_path = model_path % str(tmpdir) if "%" in model_path else model_path
2638
ts_mock.upload_model.assert_called_once_with(
27-
path="path/to/checkpoint",
39+
path=expected_path,
2840
name="model-name",
2941
cluster_id="cluster_id",
3042
progress_bar=True,

0 commit comments

Comments
 (0)