Skip to content

Commit 3357462

Browse files
authored
simple upload & load TorchScript model (#66)
* simple upload TorchScript model * load and .ts
1 parent 8c20cc6 commit 3357462

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

src/litmodels/io/gateway.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ def upload_model(
4545
# if LightningModule and isinstance(model, LightningModule):
4646
# path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt")
4747
# model.save_checkpoint(path)
48-
if torch and isinstance(model, Module):
48+
if torch and isinstance(model, torch.jit.ScriptModule):
49+
path = os.path.join(staging_dir, f"{model.__class__.__name__}.ts")
50+
model.save(path)
51+
elif torch and isinstance(model, Module):
4952
path = os.path.join(staging_dir, f"{model.__class__.__name__}.pth")
5053
torch.save(model.state_dict(), path)
5154
elif isinstance(model, str):
@@ -109,4 +112,6 @@ def load_model(name: str, download_dir: str = ".") -> Any:
109112
model_path = Path(os.path.join(download_dir, download_paths[0]))
110113
if model_path.suffix.lower() == ".pkl":
111114
return joblib.load(model_path)
115+
if model_path.suffix.lower() == ".ts":
116+
return torch.jit.load(model_path)
112117
raise NotImplementedError(f"Loading model from {model_path.suffix} is not supported yet.")

tests/test_io_cloud.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import joblib
55
import pytest
6+
import torch
7+
import torch.jit as torch_jit
68
from litmodels import download_model, load_model, upload_model
79
from litmodels.io import upload_model_files
810
from sklearn import svm
@@ -30,6 +32,7 @@ def test_download_wrong_model_name(name):
3032
[
3133
("path/to/checkpoint", "path/to/checkpoint", False),
3234
# (BoringModel(), "%s/BoringModel.ckpt"),
35+
(torch_jit.script(Module()), f"%s{os.path.sep}RecursiveScriptModule.ts", True),
3336
(Module(), f"%s{os.path.sep}Module.pth", True),
3437
(svm.SVC(), f"%s{os.path.sep}SVC.pkl", 1),
3538
],
@@ -68,7 +71,7 @@ def test_download_model(mock_download_model):
6871

6972

7073
@mock.patch("litmodels.io.cloud.sdk_download_model")
71-
def test_load_model(mock_download_model, tmp_path):
74+
def test_load_model_pickle(mock_download_model, tmp_path):
7275
# create a dummy model file
7376
model_file = tmp_path / "dummy_model.pkl"
7477
test_data = svm.SVC()
@@ -84,3 +87,22 @@ def test_load_model(mock_download_model, tmp_path):
8487
name="org-name/teamspace/model-name", download_dir=str(tmp_path), progress_bar=True
8588
)
8689
assert isinstance(model, svm.SVC)
90+
91+
92+
@mock.patch("litmodels.io.cloud.sdk_download_model")
93+
def test_load_model_torch_jit(mock_download_model, tmp_path):
94+
# create a dummy model file
95+
model_file = tmp_path / "dummy_model.ts"
96+
test_data = torch_jit.script(Module())
97+
test_data.save(model_file)
98+
mock_download_model.return_value = [str(model_file.name)]
99+
100+
# The lit-logger function is just a wrapper around the SDK function
101+
model = load_model(
102+
name="org-name/teamspace/model-name",
103+
download_dir=str(tmp_path),
104+
)
105+
mock_download_model.assert_called_once_with(
106+
name="org-name/teamspace/model-name", download_dir=str(tmp_path), progress_bar=True
107+
)
108+
assert isinstance(model, torch.jit.ScriptModule)

0 commit comments

Comments
 (0)