Skip to content

Commit a77779d

Browse files
authored
update args & docs (#16)
1 parent 1917842 commit a77779d

File tree

5 files changed

+100
-67
lines changed

5 files changed

+100
-67
lines changed

src/litmodels/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@
77
_PACKAGE_ROOT = os.path.dirname(__file__)
88
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
99

10-
from litmodels.cloud_io import download_model, upload_model, upload_model_files
10+
from litmodels.io import download_model, upload_model
1111

12-
__all__ = ["download_model", "upload_model", "upload_model_files"]
12+
__all__ = ["download_model", "upload_model"]

src/litmodels/io/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""Root package for Input/output."""
2+
3+
from litmodels.io.cloud import download_model_files, upload_model_files
4+
from litmodels.io.gateway import download_model, upload_model
5+
6+
__all__ = ["download_model", "upload_model", "download_model_files", "upload_model_files"]
Lines changed: 6 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,11 @@
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# http://www.apache.org/licenses/LICENSE-2.0
44
#
5-
import os
6-
import tempfile
7-
from pathlib import Path
8-
from typing import TYPE_CHECKING, Optional, Tuple, Union
5+
from typing import Optional, Tuple
96

107
from lightning_sdk.api.teamspace_api import UploadedModelInfo
118
from lightning_sdk.teamspace import Teamspace
129
from lightning_sdk.utils import resolve as sdk_resolvers
13-
from lightning_utilities import module_available
14-
15-
if TYPE_CHECKING:
16-
from torch.nn import Module
17-
18-
if module_available("torch"):
19-
import torch
20-
from torch.nn import Module
21-
else:
22-
torch = None
2310

2411
# if module_available("lightning"):
2512
# from lightning import LightningModule
@@ -63,60 +50,18 @@ def _get_teamspace(name: str, organization: str) -> Teamspace:
6350
return Teamspace(**teamspaces[requested_teamspace])
6451

6552

66-
def upload_model(
67-
model: Union[str, Path, "Module"],
68-
name: str,
69-
progress_bar: bool = True,
70-
cluster_id: Optional[str] = None,
71-
staging_dir: Optional[str] = None,
72-
) -> UploadedModelInfo:
73-
"""Upload a checkpoint to the model store.
74-
75-
Args:
76-
model: The model to upload. Can be a path to a checkpoint file, a PyTorch model, or a Lightning model.
77-
name: Name tag of the model to upload. Must be in the format 'organization/teamspace/modelname'
78-
where entity is either your username or the name of an organization you are part of.
79-
progress_bar: Whether to show a progress bar for the upload.
80-
cluster_id: The name of the cluster to use. Only required if it can't be determined
81-
automatically.
82-
staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will
83-
be created and used.
84-
85-
"""
86-
if not staging_dir:
87-
staging_dir = tempfile.mkdtemp()
88-
# if LightningModule and isinstance(model, LightningModule):
89-
# path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt")
90-
# model.save_checkpoint(path)
91-
if torch and isinstance(model, Module):
92-
path = os.path.join(staging_dir, f"{model.__class__.__name__}.pth")
93-
torch.save(model.state_dict(), path)
94-
elif isinstance(model, str):
95-
path = model
96-
elif isinstance(model, Path):
97-
path = str(model)
98-
else:
99-
raise ValueError(f"Unsupported model type {type(model)}")
100-
return upload_model_files(
101-
path=path,
102-
name=name,
103-
progress_bar=progress_bar,
104-
cluster_id=cluster_id,
105-
)
106-
107-
10853
def upload_model_files(
109-
path: str,
11054
name: str,
55+
path: str,
11156
progress_bar: bool = True,
11257
cluster_id: Optional[str] = None,
11358
) -> UploadedModelInfo:
11459
"""Upload a local checkpoint file to the model store.
11560
11661
Args:
117-
path: Path to the model file to upload.
118-
name: Name tag of the model to upload. Must be in the format 'organization/teamspace/modelname'
62+
name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
11963
where entity is either your username or the name of an organization you are part of.
64+
path: Path to the model file to upload.
12065
progress_bar: Whether to show a progress bar for the upload.
12166
cluster_id: The name of the cluster to use. Only required if it can't be determined
12267
automatically.
@@ -132,15 +77,15 @@ def upload_model_files(
13277
)
13378

13479

135-
def download_model(
80+
def download_model_files(
13681
name: str,
13782
download_dir: str = ".",
13883
progress_bar: bool = True,
13984
) -> str:
14085
"""Download a checkpoint from the model store.
14186
14287
Args:
143-
name: Name tag of the model to download. Must be in the format 'organization/teamspace/modelname'
88+
name: Name of the model to download. Must be in the format 'organization/teamspace/modelname'
14489
where entity is either your username or the name of an organization you are part of.
14590
download_dir: A path to directory where the model should be downloaded. Defaults
14691
to the current working directory.

src/litmodels/io/gateway.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import os
2+
import tempfile
3+
from pathlib import Path
4+
from typing import Optional, Union
5+
6+
from lightning_sdk.api.teamspace_api import UploadedModelInfo
7+
from lightning_utilities import module_available
8+
9+
from litmodels.io.cloud import download_model_files, upload_model_files
10+
11+
if module_available("torch"):
12+
import torch
13+
from torch.nn import Module
14+
else:
15+
torch = None
16+
17+
18+
def upload_model(
19+
name: str,
20+
model: Union[str, Path, "Module"],
21+
progress_bar: bool = True,
22+
cluster_id: Optional[str] = None,
23+
staging_dir: Optional[str] = None,
24+
) -> UploadedModelInfo:
25+
"""Upload a checkpoint to the model store.
26+
27+
Args:
28+
name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
29+
where entity is either your username or the name of an organization you are part of.
30+
model: The model to upload. Can be a path to a checkpoint file, a PyTorch model, or a Lightning model.
31+
progress_bar: Whether to show a progress bar for the upload.
32+
cluster_id: The name of the cluster to use. Only required if it can't be determined
33+
automatically.
34+
staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will
35+
be created and used.
36+
37+
"""
38+
if not staging_dir:
39+
staging_dir = tempfile.mkdtemp()
40+
# if LightningModule and isinstance(model, LightningModule):
41+
# path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt")
42+
# model.save_checkpoint(path)
43+
if torch and isinstance(model, Module):
44+
path = os.path.join(staging_dir, f"{model.__class__.__name__}.pth")
45+
torch.save(model.state_dict(), path)
46+
elif isinstance(model, str):
47+
path = model
48+
elif isinstance(model, Path):
49+
path = str(model)
50+
else:
51+
raise ValueError(f"Unsupported model type {type(model)}")
52+
return upload_model_files(
53+
path=path,
54+
name=name,
55+
progress_bar=progress_bar,
56+
cluster_id=cluster_id,
57+
)
58+
59+
60+
def download_model(
61+
name: str,
62+
download_dir: str = ".",
63+
progress_bar: bool = True,
64+
) -> str:
65+
"""Download a checkpoint from the model store.
66+
67+
Args:
68+
name: Name of the model to download. Must be in the format 'organization/teamspace/modelname'
69+
where entity is either your username or the name of an organization you are part of.
70+
download_dir: A path to directory where the model should be downloaded. Defaults
71+
to the current working directory.
72+
progress_bar: Whether to show a progress bar for the download.
73+
74+
Returns:
75+
The absolute path to the downloaded model file or folder.
76+
"""
77+
return download_model_files(
78+
name=name,
79+
download_dir=download_dir,
80+
progress_bar=progress_bar,
81+
)
Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from unittest import mock
33

44
import pytest
5-
from litmodels.cloud_io import download_model, upload_model, upload_model_files
5+
from litmodels import download_model, upload_model
6+
from litmodels.io import upload_model_files
67
from torch.nn import Module
78

89

@@ -25,11 +26,11 @@ def test_wrong_model_name(name):
2526
def test_upload_model(mocker, tmpdir, model, model_path):
2627
# mocking the _get_teamspace to return another mock
2728
ts_mock = mock.MagicMock()
28-
mocker.patch("litmodels.cloud_io._get_teamspace", return_value=ts_mock)
29+
mocker.patch("litmodels.io.cloud._get_teamspace", return_value=ts_mock)
2930

3031
# The lit-logger function is just a wrapper around the SDK function
3132
upload_model(
32-
model,
33+
model=model,
3334
name="org-name/teamspace/model-name",
3435
cluster_id="cluster_id",
3536
staging_dir=tmpdir,
@@ -46,7 +47,7 @@ def test_upload_model(mocker, tmpdir, model, model_path):
4647
def test_download_model(mocker):
4748
# mocking the _get_teamspace to return another mock
4849
ts_mock = mock.MagicMock()
49-
mocker.patch("litmodels.cloud_io._get_teamspace", return_value=ts_mock)
50+
mocker.patch("litmodels.io.cloud._get_teamspace", return_value=ts_mock)
5051
# The lit-logger function is just a wrapper around the SDK function
5152
download_model(
5253
name="org-name/teamspace/model-name",

0 commit comments

Comments
 (0)