Skip to content

Commit 9e07f13

Browse files
committed
update args & docs
1 parent 1917842 commit 9e07f13

File tree

4 files changed

+87
-56
lines changed

4 files changed

+87
-56
lines changed

src/litmodels/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@
77
_PACKAGE_ROOT = os.path.dirname(__file__)
88
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
99

10-
from litmodels.cloud_io import download_model, upload_model, upload_model_files
10+
from litmodels.io import download_model, upload_model
1111

12-
__all__ = ["download_model", "upload_model", "upload_model_files"]
12+
__all__ = ["download_model", "upload_model"]

src/litmodels/io/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from litmodels.io.gateway import download_model, upload_model
2+
3+
__all__ = ["download_model", "upload_model"]
Lines changed: 8 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,18 @@
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# http://www.apache.org/licenses/LICENSE-2.0
44
#
5-
import os
6-
import tempfile
7-
from pathlib import Path
8-
from typing import TYPE_CHECKING, Optional, Tuple, Union
5+
from typing import TYPE_CHECKING, Optional, Tuple
96

107
from lightning_sdk.api.teamspace_api import UploadedModelInfo
118
from lightning_sdk.teamspace import Teamspace
129
from lightning_sdk.utils import resolve as sdk_resolvers
1310
from lightning_utilities import module_available
1411

1512
if TYPE_CHECKING:
16-
from torch.nn import Module
13+
pass
1714

1815
if module_available("torch"):
19-
import torch
20-
from torch.nn import Module
16+
pass
2117
else:
2218
torch = None
2319

@@ -63,60 +59,18 @@ def _get_teamspace(name: str, organization: str) -> Teamspace:
6359
return Teamspace(**teamspaces[requested_teamspace])
6460

6561

66-
def upload_model(
67-
model: Union[str, Path, "Module"],
68-
name: str,
69-
progress_bar: bool = True,
70-
cluster_id: Optional[str] = None,
71-
staging_dir: Optional[str] = None,
72-
) -> UploadedModelInfo:
73-
"""Upload a checkpoint to the model store.
74-
75-
Args:
76-
model: The model to upload. Can be a path to a checkpoint file, a PyTorch model, or a Lightning model.
77-
name: Name tag of the model to upload. Must be in the format 'organization/teamspace/modelname'
78-
where entity is either your username or the name of an organization you are part of.
79-
progress_bar: Whether to show a progress bar for the upload.
80-
cluster_id: The name of the cluster to use. Only required if it can't be determined
81-
automatically.
82-
staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will
83-
be created and used.
84-
85-
"""
86-
if not staging_dir:
87-
staging_dir = tempfile.mkdtemp()
88-
# if LightningModule and isinstance(model, LightningModule):
89-
# path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt")
90-
# model.save_checkpoint(path)
91-
if torch and isinstance(model, Module):
92-
path = os.path.join(staging_dir, f"{model.__class__.__name__}.pth")
93-
torch.save(model.state_dict(), path)
94-
elif isinstance(model, str):
95-
path = model
96-
elif isinstance(model, Path):
97-
path = str(model)
98-
else:
99-
raise ValueError(f"Unsupported model type {type(model)}")
100-
return upload_model_files(
101-
path=path,
102-
name=name,
103-
progress_bar=progress_bar,
104-
cluster_id=cluster_id,
105-
)
106-
107-
10862
def upload_model_files(
109-
path: str,
11063
name: str,
64+
path: str,
11165
progress_bar: bool = True,
11266
cluster_id: Optional[str] = None,
11367
) -> UploadedModelInfo:
11468
"""Upload a local checkpoint file to the model store.
11569
11670
Args:
117-
path: Path to the model file to upload.
118-
name: Name tag of the model to upload. Must be in the format 'organization/teamspace/modelname'
71+
name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
11972
where entity is either your username or the name of an organization you are part of.
73+
path: Path to the model file to upload.
12074
progress_bar: Whether to show a progress bar for the upload.
12175
cluster_id: The name of the cluster to use. Only required if it can't be determined
12276
automatically.
@@ -132,15 +86,15 @@ def upload_model_files(
13286
)
13387

13488

135-
def download_model(
89+
def download_model_file(
13690
name: str,
13791
download_dir: str = ".",
13892
progress_bar: bool = True,
13993
) -> str:
14094
"""Download a checkpoint from the model store.
14195
14296
Args:
143-
name: Name tag of the model to download. Must be in the format 'organization/teamspace/modelname'
97+
name: Name of the model to download. Must be in the format 'organization/teamspace/modelname'
14498
where entity is either your username or the name of an organization you are part of.
14599
download_dir: A path to directory where the model should be downloaded. Defaults
146100
to the current working directory.

src/litmodels/io/gateway.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import os
2+
import tempfile
3+
from pathlib import Path
4+
from typing import Union, Optional
5+
6+
from lightning_sdk.api.teamspace_api import UploadedModelInfo
7+
from torch.nn import Module
8+
9+
from litmodels.io.cloud import torch, upload_model_files, download_model_file
10+
11+
12+
def upload_model(
13+
name: str,
14+
model: Union[str, Path, "Module"],
15+
progress_bar: bool = True,
16+
cluster_id: Optional[str] = None,
17+
staging_dir: Optional[str] = None,
18+
) -> UploadedModelInfo:
19+
"""Upload a checkpoint to the model store.
20+
21+
Args:
22+
name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
23+
where entity is either your username or the name of an organization you are part of.
24+
model: The model to upload. Can be a path to a checkpoint file, a PyTorch model, or a Lightning model.
25+
progress_bar: Whether to show a progress bar for the upload.
26+
cluster_id: The name of the cluster to use. Only required if it can't be determined
27+
automatically.
28+
staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will
29+
be created and used.
30+
31+
"""
32+
if not staging_dir:
33+
staging_dir = tempfile.mkdtemp()
34+
# if LightningModule and isinstance(model, LightningModule):
35+
# path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt")
36+
# model.save_checkpoint(path)
37+
if torch and isinstance(model, Module):
38+
path = os.path.join(staging_dir, f"{model.__class__.__name__}.pth")
39+
torch.save(model.state_dict(), path)
40+
elif isinstance(model, str):
41+
path = model
42+
elif isinstance(model, Path):
43+
path = str(model)
44+
else:
45+
raise ValueError(f"Unsupported model type {type(model)}")
46+
return upload_model_files(
47+
path=path,
48+
name=name,
49+
progress_bar=progress_bar,
50+
cluster_id=cluster_id,
51+
)
52+
53+
def download_model(
54+
name: str,
55+
download_dir: str = ".",
56+
progress_bar: bool = True,
57+
) -> str:
58+
"""Download a checkpoint from the model store.
59+
60+
Args:
61+
name: Name of the model to download. Must be in the format 'organization/teamspace/modelname'
62+
where entity is either your username or the name of an organization you are part of.
63+
download_dir: A path to directory where the model should be downloaded. Defaults
64+
to the current working directory.
65+
progress_bar: Whether to show a progress bar for the download.
66+
67+
Returns:
68+
The absolute path to the downloaded model file or folder.
69+
"""
70+
return download_model_file(
71+
name=name,
72+
download_dir=download_dir,
73+
progress_bar=progress_bar,
74+
)

0 commit comments

Comments
 (0)